Unify Conv implementation (#54)

* fixed readme for new git repo
* conv with bias as an optional input

parent 1777c07b1e
commit 653fa69102
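Illustration (not part of the diff): based on the updated lit tests in this commit, a bias-less convolution is now written by passing a `none`-typed unit constant as the third operand of `onnx.Conv`, while a real bias is passed as a rank-1 tensor whose size matches the weight's first dimension. A minimal sketch:

```mlir
// Sketch only; shapes are taken from the tests in this commit.
func @conv_examples(%x : tensor<1x2x32x64xf32>, %w : tensor<5x2x6x7xf32>,
                    %b : tensor<5xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
  // No bias: the optional B operand is a unit constant of type none.
  %none = constant unit
  %0 = "onnx.Conv"(%x, %w, %none) {auto_pad = "NOTSET", group = 1 : i64}
      : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
  // With bias: B has as many elements as W has output channels (M = 5 here).
  %1 = "onnx.Conv"(%x, %w, %b) {auto_pad = "NOTSET", group = 1 : i64}
      : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, tensor<5xf32>) -> tensor<*xf32>
  return %0, %1 : tensor<*xf32>, tensor<*xf32>
}
```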
@@ -636,35 +636,6 @@ ONNX ConvInteger operation

 1. `y`: memref of any type values or tensor of any type values

-### onnx.ConvNoBias (ONNXConvNoBiasOp)
-ONNX Conv operation with no Bias operand.
-
-#### Description:
-
-"The convolution operator consumes an input tensor and a filter, and"
-"computes the output."
-
-#### Operands:
-
-1. `X`: memref of any type values or tensor of any type values
-1. `W`: memref of any type values or tensor of any type values
-
-#### Attributes:
-
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `auto_pad` | `StringAttr` | string attribute attribute |
-| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `group` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `kernel_shape` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `pads` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute |
-
-#### Results:
-
-1. `o_Y`: memref of any type values or tensor of any type values
-
 ### onnx.Conv (ONNXConvOp)
 ONNX Conv operation

@@ -32,7 +32,6 @@ special_attr_defaults = dict([

 # Special operation importing handlers.
 special_op_handler = dict([
-    ("Conv", "ImportNodeConv"),
     ("MaxPool", "ImportNodeMaxPool"),
     ("BatchNormalization", "ImportNodeBatchNormalization"),
     ("Pad", "ImportNodePad"),

@@ -47,11 +46,11 @@ OpsWithShapeInference = [
     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'AveragePool', 'Abs'
+    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv'
 ]

 # Operations supporting canonicalization.
-OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm']
+OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']

 # Operations who have operands that, if produced by constant operations, should
 # be promoted to become an attribute (via attribute promotion).

@@ -303,30 +303,6 @@ private:
     buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
   }

-  /*!
-   * Special handle for Conv operations.
-   * c++ does not allow template specialization inside a class scope
-   * a specialized function is used
-   */
-  void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
-    // Conv has attribute dilations, kernel_shape, pads, the default value of
-    // which is determined by the shape of first argument. However, since the
-    // shape is unknown now, these attributes can be not generated auto
-    // dilations_attr = get_attr_ints(node, "dilations",
-    // std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
-    // 1));
-    // attributes.push_back(dilations_attr)
-    // similar situation for pads, strides in AveragePool
-    // axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
-    // TODO: fix this after type inference
-    int nOps = node.input().size();
-
-    if (nOps == 2)
-      buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
-    else
-      buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
-  }
-
   /*!
    * Special handle for MaxPool operations.
   */

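Context for the removed handler (illustration, not part of the diff): ImportNodeConv existed only to pick between ONNXConvNoBiasOp and ONNXConvOp from the operand count, and it could not fill in attributes whose defaults depend on the input rank. With the ops unified, the importer emits a plain onnx.Conv and the shape-inference pass adds the rank-dependent defaults, as the updated lit tests later in this diff show:

```mlir
// Sketch (before/after shape inference), taken from test_conv_no_bias_1 below.
// As imported: only the user-supplied attributes are present.
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64}
    : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
// After shape inference: defaults derived from the 4-D input (2 spatial dims).
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1],
    group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]}
    : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x27x58xf32>
```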
@@ -50,7 +50,7 @@ if (opName == "Constant")
 if (opName == "ConstantOfShape")
   return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Conv")
-  return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "ConvInteger")
   return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 if (opName == "ConvTranspose")

@@ -12,18 +12,19 @@

 using namespace mlir;

-struct ONNXConvNoBiasOpLowering : public ConversionPattern {
-  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
+struct ONNXConvOpLowering : public ConversionPattern {
+  ONNXConvOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXConvOp::getOperationName(), 1, ctx) {}

   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
+    ONNXConvOpOperandAdaptor operandAdaptor(operands);
     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertToMemRefType(*op->result_type_begin());
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
-    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
+    ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);

     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);

@@ -32,12 +33,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
           memRefType, loc, rewriter, insertDealloc, {operands[0]});

     auto resultShape = memRefType.getShape();
-    auto &inputOperand = operands[0];
+    auto inputOperand = operandAdaptor.X();
     auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
-    auto &kernelOperand = operands[1];
+    auto kernelOperand = operandAdaptor.W();
     auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
+    auto biasOperand = operandAdaptor.B();
+    bool hasBias = !biasOperand.getType().isa<NoneType>();

-    // R = ConvNoBias(D, K)
+    // R = Conv(D, K)
     //
     // The input/output shapes will look like this:
     //

@@ -169,8 +172,23 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {

         // 3.4 Emit inner loop nest.
         innerLoops.createIterateOp();
-        rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());

+        // Emit the bias, if needed.
+        if (hasBias) {
+          auto loadResult =
+              rewriter.create<LoadOp>(loc, alloc, resultIndices);
+          SmallVector<Value, 4> biasIndices;
+          biasIndices.emplace_back(kernel);
+          auto loadBias =
+              rewriter.create<LoadOp>(loc, biasOperand, kernel);
+          auto resultWithBias = rewriter.create<MulFOp>(
+              loc, loadResult, loadBias);
+          // Store initializer value into output location.
+          rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
+        }
+
+        //
+        rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
         {
           // 4. Emit inner loop body
           // R[n][kernel][r1][r2] =

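Sketch of where this bias step lands in the lowered Krnl loops (condensed from the test_conv_bias_no_pad CHECK lines added later in this diff; the index names %n, %m, %r1, %r2 are illustrative):

```mlir
// After the innermost reduction loops have accumulated X * W into the output
// element, the lowering re-loads that element, combines it with the bias for
// the current output channel %m, and stores it back:
%b1 = load %res[%n, %m, %r1, %r2] : memref<1x5x27x58xf32>
%b2 = load %bias[%m] : memref<5xf32>
%b3 = mulf %b1, %b2 : f32
store %b3, %res[%n, %m, %r1, %r2] : memref<1x5x27x58xf32>
```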
@@ -238,5 +256,5 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {

 void populateLoweringONNXConvOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXConvNoBiasOpLowering>(ctx);
+  patterns.insert<ONNXConvOpLowering>(ctx);
 }

@@ -1022,14 +1022,18 @@ void ONNXReduceSumOp::inferShapes() {
 // - kernelShape: inferred from weight matrix if not defined by user;
 // - pads: set to proper value, 0 if not defined by user.

-void ONNXConvNoBiasOp::inferShapes() {
-  // Generic shape for data input X and weight tensor W:
+void ONNXConvOp::inferShapes() {
+  // Generic shape for data input X, weight tensor W, and optional bias B
   // X: (N x C x D1 x D2 ... x Dn)
   // W: (M x C/group x k1 x k2 x ... x kn)
+  // B: (M) Optional
+
+  bool hasBias = !B().getType().isa<NoneType>();

   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>() ||
-      !W().getType().isa<RankedTensorType>())
+      !W().getType().isa<RankedTensorType>() ||
+      (hasBias && !B().getType().isa<RankedTensorType>()))
     return;

   auto xTy = X().getType().cast<RankedTensorType>();

@@ -1047,7 +1051,7 @@ void ONNXConvNoBiasOp::inferShapes() {
     emitError("Weight size not compatible with data size");

   // Group is a required attribute and should have default value of 1.
-  int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
+  int64_t group = ONNXConvOp::group().getSExtValue();

   // Check if the attribute actually exists. If it does not then add it.
   if (!groupAttr())

@@ -1058,6 +1062,16 @@ void ONNXConvNoBiasOp::inferShapes() {
       xShape[1] != (weightShape[1] * group))
     emitError("Channel dimension mismatch");

+  // Check the size of bias.
+  if (hasBias) {
+    auto bTx = B().getType().cast<RankedTensorType>();
+    auto bShape = bTx.getShape();
+    if (bShape.size() != 1)
+      emitError("bias should be one dimensional");
+    if (bShape[0] != weightShape[0])
+      emitError("bias should have same dimensions as weight's first dimension");
+  }
+
   // Note: the value of the group attribut only impacts the way the
   // computation is carried out and not the actual output size.

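For reference (not part of the diff): the spatial output sizes that this inferShapes computes follow the usual ONNX convolution rule, which the expected shapes in the tests below are consistent with. Per spatial axis i, with input size D, begin/end pads p, dilation d, stride s, and kernel size k:

```latex
D_i^{\text{out}} = \left\lfloor \frac{D_i + p_i^{\text{begin}} + p_i^{\text{end}} - \bigl((k_i - 1)\,d_i + 1\bigr)}{s_i} \right\rfloor + 1
% Worked example from test_conv_no_bias_7 below: D = (32, 64), k = (6, 7),
% s = (2, 3), d = (1, 1), no padding:
%   H_out = floor((32 - 6) / 2) + 1 = 14,  W_out = floor((64 - 7) / 3) + 1 = 20
% giving tensor<1x5x14x20xf32> (batch 1, M = 5 output channels).
```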
@@ -95,25 +95,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 // or outputs. This decision affects only ONNX operations with optional
 // arguments not ONNX operations with variadic operands.

-def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let hasCanonicalizer = 1;
-  let summary = "ONNX Conv operation with no Bias operand.";
-  let description = [{
-    "The convolution operator consumes an input tensor and a filter, and"
-    "computes the output."
-  }];
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
-    AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
-    DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
-    OptionalAttr<I64ArrayAttr>:$dilations,
-    DefaultValuedAttr<I64Attr, "1">:$group,
-    OptionalAttr<I64ArrayAttr>:$kernel_shape,
-    OptionalAttr<I64ArrayAttr>:$pads,
-    OptionalAttr<I64ArrayAttr>:$strides);
-  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
-}
-
 def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;

@@ -363,7 +363,8 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
 }

 def ONNXConvOp:ONNX_Op<"Conv",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX Conv operation";
   let description = [{
     "The convolution operator consumes an input tensor and a filter, and"

@@ -72,7 +72,7 @@ ArrayAttr insertZerosForNonPaddedDims(

 //===----------------------------------------------------------------------===//
 // Rewrite:
-// %0 = onnx.ConvNoBiasOp(%D : tensor<DShape>, %K)
+// %0 = onnx.Conv(%D : tensor<DShape>, %K)
 //     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
 //     tensor<OutShape>
 //

@@ -80,14 +80,14 @@ ArrayAttr insertZerosForNonPaddedDims(
 // %0 = onnx.PadConstantValuePasOp(%D)
 //     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
 //     tensor<DPaddedShape>
-// %1 = onnx.ConvNoBias(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
+// %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
 //     tensor<OutShape>
 //===----------------------------------------------------------------------===//
 struct SplitConvOpPattern : public RewritePattern {
   SplitConvOpPattern(MLIRContext *context)
-      : RewritePattern(ONNXConvNoBiasOp::getOperationName(),
+      : RewritePattern(ONNXConvOp::getOperationName(),
           {ONNXPadConstantValuePadOp::getOperationName(),
-              ONNXConvNoBiasOp::getOperationName()},
+              ONNXConvOp::getOperationName()},
           1, context) {}

   PatternMatchResult matchAndRewrite(Operation *op,

@@ -95,7 +95,7 @@ struct SplitConvOpPattern : public RewritePattern {
     auto loc = op->getLoc();

     // If convolution does not use padding then no rewrite is required.
-    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
+    ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
     auto padsAttribute = convOp.padsAttr();
     if (!padsAttribute)
       return matchFailure();

@@ -155,8 +155,9 @@ struct SplitConvOpPattern : public RewritePattern {

     SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
     auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-    ONNXConvNoBiasOp newConvOp = rewriter.create<ONNXConvNoBiasOp>(
+    ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>(
         loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
+        convOp.getOperands()[2],
         convOp.auto_padAttr(), convOp.dilationsAttr(),
         convOp.groupAttr(), convOp.kernel_shapeAttr(),
         rewriter.getI64ArrayAttr(newConvPads),

@@ -173,8 +174,8 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
 }
-/// on the ONNXConvNoBiasOp.
-void ONNXConvNoBiasOp::getCanonicalizationPatterns(
+/// on the ONNXConvOp.
+void ONNXConvOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SplitConvOpPattern>(context);
 }

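To make the pads bookkeeping in SplitConvOpPattern concrete (illustration only, consistent with the test_conv_split test below): a 2-D Conv with pads [b0, b1, e0, e1] = [2, 3, 4, 5] is split into a 4-D pad of [0, 0, 2, 3, 0, 0, 4, 5] (no padding on the batch and channel dimensions) followed by a Conv whose pads are all zero; the optional bias operand is simply carried over to the new Conv:

```mlir
%cst = constant unit
%0 = "onnx.PadConstantValuePad"(%arg0)
    {constant_value = 0.000000e+00 : f32, mode = "constant",
     pads = [0, 0, 2, 3, 0, 0, 4, 5]}
    : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32>
%1 = "onnx.Conv"(%0, %arg1, %cst)
    {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]}
    : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
```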
@@ -113,7 +113,7 @@ public:
             op->getName().getStringRef() != "onnx.ReduceSum" &&
             op->getName().getStringRef() != "onnx.Softmax" &&
             op->getName().getStringRef() != "onnx.Sqrt" &&
-            op->getName().getStringRef() != "onnx.ConvNoBias" &&
+            op->getName().getStringRef() != "onnx.Conv" &&
             op->getName().getStringRef() != "onnx.PadConstantPad" &&
             op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
             op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&

@@ -48,10 +48,12 @@ func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {

 // CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> {
 func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-NEXT: %cst = constant unit
   // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32>
-  // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %1 = "onnx.Conv"(%0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
   // CHECK-NEXT: return %1 : tensor<*xf32>
 }

@@ -1149,7 +1149,8 @@ func @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32
 }

 func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_no_pad

@@ -1191,8 +1192,55 @@ func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2
   // CHECK: return [[RES]] : memref<1x5x27x58xf32>
 }

+func @test_conv_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, tensor<5xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_conv_bias_no_pad
+  // CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
+  // CHECK: [[CONST0:%.+]] = constant 5 : index
+  // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CONST2:%.+]] = constant 2 : index
+  // CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg3 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg4 = 0 to 5) {
+  // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {
+  // CHECK: store [[CONST1]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg7 = 0 to 2, [[INNER_LOOPS]]#1 -> %arg8 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg9 = 0 to 7) {
+  // CHECK: [[R1PLUSK1:%.+]] = addi %arg5, %arg8 : index
+  // CHECK: [[R2PLUSK2:%.+]] = addi %arg6, %arg9 : index
+  // CHECK: [[DATA:%.+]] = load %arg0[%arg3, %arg7, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x2x32x64xf32>
+  // CHECK: [[KERNEL:%.+]] = load %arg1[%arg4, %arg7, %arg8, %arg9] : memref<5x2x6x7xf32>
+  // CHECK: [[ACC_RES:%.+]] = load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32
+  // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32
+  // CHECK: store [[ADD]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: }
+  // CHECK: [[BIAS1:%.+]] = load [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[BIAS2:%.+]] = load %arg2[%arg4] : memref<5xf32>
+  // CHECK: [[BIAS3:%.+]] = mulf [[BIAS1]], [[BIAS2]] : f32
+  // CHECK: store [[BIAS3]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<1x5x27x58xf32>
+}
+
 func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x3x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_no_pad_w_group

@@ -1239,7 +1287,8 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te
 }

 func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_no_pad_w_strides

@@ -140,39 +140,42 @@ func @test_matmul_10(%arg0 : tensor<?x42x32xf32>, %arg1 : tensor<32xf32>) -> ten
 }

 //===----------------------------------------------------------------------===//
-/// Test shape inference for ConvNoBias operation and all its attributes.
+/// Test shape inference for Conv (first with no bias) operation and all its attributes.
 //===----------------------------------------------------------------------===//

 /// Default and required attributes for 1-D convolution.

 func @test_conv_no_bias_0(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_0
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<1x5x27xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<1x5x27xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32>
 }

 /// Default and required attributes.

 func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_1
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x27x58xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32>
 }

 /// kernel_shape attribute.

 func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_2
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x25x56xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32>
 }

@@ -180,53 +183,58 @@ func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x
 /// Use pads to make output size equal to input size by adding K - 1 to the result.

 func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_3
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
 }

 /// auto_pad set to SAME_UPPER and SAME_LOWER.

 func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_4
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
 }

 func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_5
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
 }

 /// auto_pad set to VALID.

 func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_6
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x27x55xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32>
 }

 /// With strides attribute.

 func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_7
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x14x20xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32>
 }

@@ -234,45 +242,61 @@ func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x
 /// The auto_pad will pas as if stride is equal to 1.

 func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_8
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x16x22xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32>
 }

 /// dilations attribute.

 func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_9
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x22x46xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x22x46xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xf32>
 }

 /// dilations attribute with stride.

 func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_conv_no_bias_10
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x11x23xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x11x23xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xf32>
 }

 /// dilations attribute with auto_pad set to SAME_UPPER.

 func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
-}

   // CHECK-LABEL: test_conv_no_bias_11
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
+}
+
+// Test convolution with bias input.
+
+func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_conv_12
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<1x5x27xf32>
+  // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32>
+}

 //===----------------------------------------------------------------------===//
 /// Test shape inference for PadConstantValuePad.
 //===----------------------------------------------------------------------===//