From 653fa69102755decaf2abafee92d0fad6e694701 Mon Sep 17 00:00:00 2001
From: Alexandre Eichenberger
Date: Thu, 26 Mar 2020 11:03:19 -0400
Subject: [PATCH] Unify Conv implementation (#54)

* fixed readme for new git repo

* conv with bias as an optional input
---
 doc/Dialects/onnx.md                       | 29 ---------
 doc/gen_doc.py                             |  5 +-
 src/Builder/FrontendDialectTransformer.cpp | 24 -------
 src/Builder/OpBuildTable.inc               |  2 +-
 src/Conversion/ONNXToKrnl/NN/Conv.cpp      | 36 +++++++---
 src/Dialect/ONNX/ONNXOps.cpp               | 22 +++++--
 src/Dialect/ONNX/ONNXOps.td                | 19 ------
 src/Dialect/ONNX/ONNXOps.td.inc            |  3 +-
 src/Transform/ONNX/ONNXRewrite.cpp         | 17 ++---
 src/Transform/ONNX/ShapeInferencePass.cpp  |  2 +-
 test/mlir/onnx/onnx_canonicalization.mlir  |  6 +-
 test/mlir/onnx/onnx_lowering.mlir          | 55 +++++++++++++++-
 test/mlir/onnx/onnx_shape_inference.mlir   | 76 ++++++++++++++--------
 13 files changed, 166 insertions(+), 130 deletions(-)

diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md
index b7cb820..6170f9e 100644
--- a/doc/Dialects/onnx.md
+++ b/doc/Dialects/onnx.md
@@ -636,35 +636,6 @@ ONNX ConvInteger operation
 
 1. `y`: memref of any type values or tensor of any type values
 
-### onnx.ConvNoBias (ONNXConvNoBiasOp)
-ONNX Conv operation with no Bias operand.
-
-#### Description:
-
-
-"The convolution operator consumes an input tensor and a filter, and"
-"computes the output."
-
-#### Operands:
-
-1. `X`: memref of any type values or tensor of any type values
-1. `W`: memref of any type values or tensor of any type values
-
-#### Attributes:
-
-| Attribute | MLIR Type | Description |
-| :-------: | :-------: | ----------- |
-| `auto_pad` | `StringAttr` | string attribute attribute |
-| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `group` | `IntegerAttr` | 64-bit integer attribute attribute |
-| `kernel_shape` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `pads` | `ArrayAttr` | 64-bit integer array attribute attribute |
-| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute |
-
-#### Results:
-
-1. `o_Y`: memref of any type values or tensor of any type values
-
 ### onnx.Conv (ONNXConvOp)
 ONNX Conv operation
 
diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index f9b99ab..61e06b4 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -32,7 +32,6 @@ special_attr_defaults = dict([
 
 # Special operation importing handlers.
 special_op_handler = dict([
-    ("Conv", "ImportNodeConv"),
     ("MaxPool", "ImportNodeMaxPool"),
     ("BatchNormalization", "ImportNodeBatchNormalization"),
     ("Pad", "ImportNodePad"),
@@ -47,11 +46,11 @@ OpsWithShapeInference = [
     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'AveragePool', 'Abs'
+    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv'
 ]
 
 # Operations supporting canonicalization.
-OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm']
+OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
 
 # Operations who have operands that, if produced by constant operations, should
 # be promoted to become an attribute (via attribute promotion).
diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp
index 77691a1..91d25a6 100644
--- a/src/Builder/FrontendDialectTransformer.cpp
+++ b/src/Builder/FrontendDialectTransformer.cpp
@@ -303,30 +303,6 @@ private:
     buildOutputAndOperation(node, inputs, nIn, nOut);
   }
 
-  /*!
-   * Special handle for Conv operations.
-   * c++ does not allow template specialization inside a class scope
-   * a specialized function is used
-   */
-  void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
-    // Conv has attribute dilations, kernel_shape, pads, the default value of
-    // which is determined by the shape of first argument. However, since the
-    // shape is unknown now, these attributes can be not generated auto
-    // dilations_attr = get_attr_ints(node, "dilations",
-    //    std::vector(inputs[0]->getType().cast.getDims()-2,
-    //    1));
-    // attributes.push_back(dilations_attr)
-    // similar situation for pads, strides in AveragePool
-    // axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
-    // TODO: fix this after type inference
-    int nOps = node.input().size();
-
-    if (nOps == 2)
-      buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
-    else
-      buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
-  }
-
   /*!
    * Special handle for MaxPool operations.
    */
diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc
index 32328e8..d23416a 100644
--- a/src/Builder/OpBuildTable.inc
+++ b/src/Builder/OpBuildTable.inc
@@ -50,7 +50,7 @@ if (opName == "Constant")
 if (opName == "ConstantOfShape")
   return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Conv")
-  return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "ConvInteger")
   return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 if (opName == "ConvTranspose")
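With ImportNodeConv gone, Conv is imported through the generic buildOperation path above with expected_num_operands = 3. An ONNX node that supplies only X and W must therefore still yield a third operand; the tests later in this patch show it as a NoneType value (%cst = constant unit). The sketch below illustrates that padding step under the assumption that an OpBuilder is at hand and that the std dialect's ConstantOp accepts a unit attribute; the helper name padMissingOptionalInputs is illustrative, not an actual onnx-mlir function.

    // Sketch (assumed API, not verbatim onnx-mlir code): pad a node's operand
    // list up to the expected arity with NoneType placeholders so trailing
    // optional inputs, such as Conv's bias B, are always present. Consumers
    // then detect an absent input with value.getType().isa<NoneType>(), as
    // the Conv lowering below does.
    static void padMissingOptionalInputs(std::vector<mlir::Value> &inputs,
                                         size_t expectedNumOperands,
                                         mlir::OpBuilder &builder,
                                         mlir::Location loc) {
      while (inputs.size() < expectedNumOperands)
        inputs.emplace_back(
            builder.create<mlir::ConstantOp>(loc, builder.getUnitAttr()));
    }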
diff --git a/src/Conversion/ONNXToKrnl/NN/Conv.cpp b/src/Conversion/ONNXToKrnl/NN/Conv.cpp
index d75f6f7..db29107 100644
--- a/src/Conversion/ONNXToKrnl/NN/Conv.cpp
+++ b/src/Conversion/ONNXToKrnl/NN/Conv.cpp
@@ -12,18 +12,19 @@
 using namespace mlir;
 
-struct ONNXConvNoBiasOpLowering : public ConversionPattern {
-  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
+struct ONNXConvOpLowering : public ConversionPattern {
+  ONNXConvOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXConvOp::getOperationName(), 1, ctx) {}
 
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                      ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
+    ONNXConvOpOperandAdaptor operandAdaptor(operands);
 
     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertToMemRefType(*op->result_type_begin());
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
-    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
+    ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
 
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
@@ -32,12 +33,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
           memRefType, loc, rewriter, insertDealloc, {operands[0]});
 
     auto resultShape = memRefType.getShape();
-    auto &inputOperand = operands[0];
+    auto inputOperand = operandAdaptor.X();
     auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
-    auto &kernelOperand = operands[1];
+    auto kernelOperand = operandAdaptor.W();
     auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
+    auto biasOperand = operandAdaptor.B();
+    bool hasBias = !biasOperand.getType().isa<NoneType>();
 
-    // R = ConvNoBias(D, K)
+    // R = Conv(D, K)
     //
     // The input/output shapes will look like this:
     //
@@ -169,8 +172,23 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
         // 3.4 Emit inner loop nest.
         innerLoops.createIterateOp();
-        rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
 
+        // Emit the bias, if needed.
+        if (hasBias) {
+          auto loadResult =
+              rewriter.create<LoadOp>(loc, alloc, resultIndices);
+          SmallVector<Value, 4> biasIndices;
+          biasIndices.emplace_back(kernel);
+          auto loadBias =
+              rewriter.create<LoadOp>(loc, biasOperand, kernel);
+          auto resultWithBias = rewriter.create<MulFOp>(
+              loc, loadResult, loadBias);
+          // Store initializer value into output location.
+          rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
+        }
+
+        //
+        rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
         {
           // 4. Emit inner loop body
           //       R[n][kernel][r1][r2] =
@@ -238,5 +256,5 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
 
 void populateLoweringONNXConvOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXConvNoBiasOpLowering>(ctx);
+  patterns.insert<ONNXConvOpLowering>(ctx);
 }
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index d976e4c..6b6a29b 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -1022,14 +1022,18 @@ void ONNXReduceSumOp::inferShapes() {
 //   - kernelShape: inferred from weight matrix if not defined by user;
 //   - pads: set to proper value, 0 if not defined by user.
 
-void ONNXConvNoBiasOp::inferShapes() {
-  // Generic shape for data input X and weight tensor W:
+void ONNXConvOp::inferShapes() {
+  // Generic shape for data input X, weight tensor W, and optional bias B
   //   X: (N x C x D1 x D2 ... x Dn)
   //   W: (M x C/group x k1 x k2 x ... x kn)
+  //   B: (M) Optional
+
+  bool hasBias = !B().getType().isa<NoneType>();
 
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>() ||
-      !W().getType().isa<RankedTensorType>())
+      !W().getType().isa<RankedTensorType>() ||
+      (hasBias && !B().getType().isa<RankedTensorType>()))
     return;
 
   auto xTy = X().getType().cast<RankedTensorType>();
@@ -1047,7 +1051,7 @@ void ONNXConvNoBiasOp::inferShapes() {
     emitError("Weight size not compatible with data size");
 
   // Group is a required attribute and should have default value of 1.
-  int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
+  int64_t group = ONNXConvOp::group().getSExtValue();
 
   // Check if the attribute actually exists. If it does not then add it.
   if (!groupAttr())
@@ -1058,6 +1062,16 @@
       xShape[1] != (weightShape[1] * group))
     emitError("Channel dimension mismatch");
 
+  // Check the size of bias.
+  if (hasBias) {
+    auto bTx = B().getType().cast<RankedTensorType>();
+    auto bShape = bTx.getShape();
+    if (bShape.size() != 1)
+      emitError("bias should be one dimensional");
+    if (bShape[0] != weightShape[0])
+      emitError("bias should have same dimensions as weight's first dimension");
+  }
+
   // Note: the value of the group attribut only impacts the way the
   // computation is carried out and not the actual output size.
 
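Every static output shape this function produces, and every expected shape in the test updates below, follows the standard convolution size formula. As a reference for checking the CHECK lines, here is the per-dimension arithmetic as a small sketch (it mirrors, but is not, the inferShapes code):

    // Output spatial extent along one dimension; integer division floors.
    // Example from the tests: input 32x64, kernel 6x7, unit stride/dilation,
    // zero pads -> (32 - 6)/1 + 1 = 27 and (64 - 7)/1 + 1 = 58, i.e. the
    // tensor<1x5x27x58xf32> result type.
    static int64_t convOutDim(int64_t inDim, int64_t kernel, int64_t padBegin,
                              int64_t padEnd, int64_t stride, int64_t dilation) {
      int64_t effectiveKernel = dilation * (kernel - 1) + 1;
      return (inDim + padBegin + padEnd - effectiveKernel) / stride + 1;
    }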
diff --git a/src/Dialect/ONNX/ONNXOps.td b/src/Dialect/ONNX/ONNXOps.td
index e48973b..27a8882 100644
--- a/src/Dialect/ONNX/ONNXOps.td
+++ b/src/Dialect/ONNX/ONNXOps.td
@@ -95,25 +95,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 // or outputs. This decision affects only ONNX operations with optional
 // arguments not ONNX operations with variadic operands.
 
-def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let hasCanonicalizer = 1;
-  let summary = "ONNX Conv operation with no Bias operand.";
-  let description = [{
-    "The convolution operator consumes an input tensor and a filter, and"
-    "computes the output."
-  }];
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
-                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
-                       DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
-                       OptionalAttr<I64ArrayAttr>:$dilations,
-                       DefaultValuedAttr<I64Attr, "1">:$group,
-                       OptionalAttr<I64ArrayAttr>:$kernel_shape,
-                       OptionalAttr<I64ArrayAttr>:$pads,
-                       OptionalAttr<I64ArrayAttr>:$strides);
-  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
-}
-
 def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 99b1f5f..e92e2e4 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -363,7 +363,8 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
 }
 
 def ONNXConvOp:ONNX_Op<"Conv",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX Conv operation";
   let description = [{
     "The convolution operator consumes an input tensor and a filter, and"
diff --git a/src/Transform/ONNX/ONNXRewrite.cpp b/src/Transform/ONNX/ONNXRewrite.cpp
index 3985bdf..d7619f0 100644
--- a/src/Transform/ONNX/ONNXRewrite.cpp
+++ b/src/Transform/ONNX/ONNXRewrite.cpp
@@ -72,7 +72,7 @@ ArrayAttr insertZerosForNonPaddedDims(
 //===----------------------------------------------------------------------===//
 // Rewrite:
-// %0 = onnx.ConvNoBiasOp(%D : tensor, %K)
+// %0 = onnx.Conv(%D : tensor, %K)
 //     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
 //     tensor
 //
@@ -80,14 +80,14 @@ ArrayAttr insertZerosForNonPaddedDims(
 // %0 = onnx.PadConstantValuePasOp(%D)
 //     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
 //     tensor
-// %1 = onnx.ConvNoBias(%0 : tensor, %K) {pads = [0, ..., 0]} ->
+// %1 = onnx.Conv(%0 : tensor, %K) {pads = [0, ..., 0]} ->
 //     tensor
 //===----------------------------------------------------------------------===//
 struct SplitConvOpPattern : public RewritePattern {
   SplitConvOpPattern(MLIRContext *context)
-      : RewritePattern(ONNXConvNoBiasOp::getOperationName(),
+      : RewritePattern(ONNXConvOp::getOperationName(),
                        {ONNXPadConstantValuePadOp::getOperationName(),
-                        ONNXConvNoBiasOp::getOperationName()},
+                        ONNXConvOp::getOperationName()},
                        1, context) {}
 
   PatternMatchResult matchAndRewrite(Operation *op,
@@ -95,7 +95,7 @@ struct SplitConvOpPattern : public RewritePattern {
     auto loc = op->getLoc();
 
     // If convolution does not use padding then no rewrite is required.
-    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
+    ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
     auto padsAttribute = convOp.padsAttr();
     if (!padsAttribute)
       return matchFailure();
@@ -155,8 +155,9 @@ struct SplitConvOpPattern : public RewritePattern {
     SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0);
 
     auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-    ONNXConvNoBiasOp newConvOp = rewriter.create<ONNXConvNoBiasOp>(
+    ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>(
         loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1],
+        convOp.getOperands()[2],
         convOp.auto_padAttr(), convOp.dilationsAttr(),
         convOp.groupAttr(), convOp.kernel_shapeAttr(),
         rewriter.getI64ArrayAttr(newConvPads),
@@ -173,8 +174,8 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
 }
-/// on the ONNXConvNoBiasOp.
-void ONNXConvNoBiasOp::getCanonicalizationPatterns(
+/// on the ONNXConvOp.
+void ONNXConvOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SplitConvOpPattern>(context);
 }
diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp
index 4c59625..e16fe89 100644
--- a/src/Transform/ONNX/ShapeInferencePass.cpp
+++ b/src/Transform/ONNX/ShapeInferencePass.cpp
@@ -113,7 +113,7 @@ public:
            op->getName().getStringRef() != "onnx.ReduceSum" &&
            op->getName().getStringRef() != "onnx.Softmax" &&
            op->getName().getStringRef() != "onnx.Sqrt" &&
-           op->getName().getStringRef() != "onnx.ConvNoBias" &&
+           op->getName().getStringRef() != "onnx.Conv" &&
           op->getName().getStringRef() != "onnx.PadConstantPad" &&
            op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
            op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 1308567..a839253 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -48,10 +48,12 @@ func @test_constant_pad(%arg0 : tensor) -> tensor<*xf32> {
 // CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> {
 func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-NEXT: %cst = constant unit
   // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32>
-  // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %1 = "onnx.Conv"(%0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
   // CHECK-NEXT: return %1 : tensor<*xf32>
 }
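The pads vector in the CHECK lines above is derived mechanically: each half of Conv's spatial pads [b0, b1, e0, e1] gains two leading zeros so the batch and channel dimensions of the N x C x D1 x ... x Dn input stay unpadded. For pads = [2, 3, 4, 5] on a 1x9x32x64 input this gives [0, 0, 2, 3, 0, 0, 4, 5] and a 1x9x(32+2+4)x(64+3+5) = 1x9x38x72 padded tensor. A sketch of that index arithmetic (it parallels insertZerosForNonPaddedDims above but is not the actual implementation):

    // Build the PadConstantValuePad pads from a Conv's spatial-only pads.
    static llvm::SmallVector<int64_t, 8>
    padOpPadsFromConvPads(llvm::ArrayRef<int64_t> convPads) {
      size_t spatialRank = convPads.size() / 2; // begin values, then end values
      llvm::SmallVector<int64_t, 8> padOpPads;
      for (int half = 0; half < 2; ++half) {
        padOpPads.append({0, 0}); // batch and channel are never padded
        for (size_t i = 0; i < spatialRank; ++i)
          padOpPads.push_back(convPads[half * spatialRank + i]);
      }
      return padOpPads; // e.g. [2,3,4,5] -> [0,0,2,3,0,0,4,5]
    }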
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index df5239c..0b7451f 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -1149,7 +1149,8 @@ func @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32
 }
 
 func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_no_pad
@@ -1191,8 +1192,55 @@ func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2
   // CHECK: return [[RES]] : memref<1x5x27x58xf32>
 }
 
+func @test_conv_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, tensor<5xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_conv_bias_no_pad
+  // CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
+  // CHECK: [[CONST0:%.+]] = constant 5 : index
+  // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CONST2:%.+]] = constant 2 : index
+  // CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg3 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg4 = 0 to 5) {
+  // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {
+  // CHECK: store [[CONST1]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg7 = 0 to 2, [[INNER_LOOPS]]#1 -> %arg8 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg9 = 0 to 7) {
+  // CHECK: [[R1PLUSK1:%.+]] = addi %arg5, %arg8 : index
+  // CHECK: [[R2PLUSK2:%.+]] = addi %arg6, %arg9 : index
+  // CHECK: [[DATA:%.+]] = load %arg0[%arg3, %arg7, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x2x32x64xf32>
+  // CHECK: [[KERNEL:%.+]] = load %arg1[%arg4, %arg7, %arg8, %arg9] : memref<5x2x6x7xf32>
+  // CHECK: [[ACC_RES:%.+]] = load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32
+  // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32
+  // CHECK: store [[ADD]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: }
+  // CHECK: [[BIAS1:%.+]] = load [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[BIAS2:%.+]] = load %arg2[%arg4] : memref<5xf32>
+  // CHECK: [[BIAS3:%.+]] = mulf [[BIAS1]], [[BIAS2]] : f32
+  // CHECK: store [[BIAS3]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<1x5x27x58xf32>
+}
+
 func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x3x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_no_pad_w_group
@@ -1239,7 +1287,8 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te
 }
 
 func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_no_pad_w_strides
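The SAME_UPPER/SAME_LOWER expectations in the shape-inference tests that follow can be reproduced with a short calculation: the output size is ceil(in/stride), and the total padding needed to reach it is split between begin and end, with the odd unit going to the end for SAME_UPPER and to the beginning for SAME_LOWER. A reference sketch of that arithmetic (a summary of the ONNX auto_pad rules, not the actual inferShapes code):

    #include <algorithm>
    #include <cstdint>
    #include <utility>

    // Returns {padBegin, padEnd} for one spatial dimension.
    // E.g. in=32, kernel=6, stride=1, dilation=1: total = 31 + 6 - 32 = 5,
    // so SAME_UPPER yields (2, 3) and SAME_LOWER yields (3, 2), matching the
    // pads = [2, 4, 3, 5] and [3, 5, 2, 4] CHECK lines below.
    static std::pair<int64_t, int64_t> samePads(int64_t in, int64_t kernel,
                                                int64_t stride, int64_t dilation,
                                                bool sameUpper) {
      int64_t effKernel = dilation * (kernel - 1) + 1;
      int64_t out = (in + stride - 1) / stride; // ceil(in / stride)
      int64_t total =
          std::max<int64_t>(0, (out - 1) * stride + effKernel - in);
      int64_t small = total / 2, big = total - small;
      return sameUpper ? std::make_pair(small, big) : std::make_pair(big, small);
    }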
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index e9ece82..02cf415 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -140,39 +140,42 @@ func @test_matmul_10(%arg0 : tensor, %arg1 : tensor<32xf32>) -> ten
 
 //===----------------------------------------------------------------------===//
-/// Test shape inference for ConvNoBias operation and all its attributes.
+/// Test shape inference for Conv (first with no bias) operation and all its attributes.
 //===----------------------------------------------------------------------===//
 
 /// Default and required attributes for 1-D convolution.
 
 func @test_conv_no_bias_0(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_0
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<1x5x27xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<1x5x27xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32>
 }
 
 /// Default and required attributes.
 
 func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_1
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x27x58xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32>
 }
 
 /// kernel_shape attribute.
 
 func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_2
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x25x56xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32>
 }
 
@@ -180,53 +183,58 @@ func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x
 /// Use pads to make output size equal to input size by adding K - 1 to the result.
 
 func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_3
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
 }
 
 /// auto_pad set to SAME_UPPER and SAME_LOWER.
 
 func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_4
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
 }
 
 func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_5
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
 }
 
 /// auto_pad set to VALID.
 
 func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_6
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x27x55xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32>
 }
 
 /// With strides attribute.
 
 func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_7
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x14x20xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32>
 }
 
@@ -234,45 +242,61 @@ func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x
 /// The auto_pad will pas as if stride is equal to 1.
 
 func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_8
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x16x22xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32>
 }
 
 /// dilations attribute.
 
 func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_9
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x22x46xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x22x46xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xf32>
 }
 
 /// dilations attribute with stride.
 
 func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_conv_no_bias_10
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x11x23xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x11x23xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xf32>
 }
 
 /// dilations attribute with auto_pad set to SAME_UPPER.
 
 func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
-  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  %cst = constant unit
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
-}
+
   // CHECK-LABEL: test_conv_no_bias_11
-  // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32>
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x32x64xf32>
   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
+}
+
+// Test convolution with bias input.
+
+func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_conv_12
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<1x5x27xf32>
+  // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32>
+}
 
 //===----------------------------------------------------------------------===//
 /// Test shape inference for PadConstantValuePad.
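For completeness, here is a self-contained reference of what the unified op computes on NCHW data in the configuration exercised by test_conv_bias_no_pad (group = 1, unit strides and dilations, no padding). Per the ONNX specification the optional bias contributes one additive term per output channel; note that the lowering's CHECK lines above combine the bias with mulf, so this sketch reflects the spec rather than the emitted code.

    #include <cstddef>
    #include <vector>

    // Naive Conv reference: for each (n, m, r1, r2) accumulate
    // X[n][c][r1+k1][r2+k2] * K[m][c][k1][k2] over c, k1, k2, then apply the
    // optional per-channel bias B[m].
    std::vector<float> referenceConv(const std::vector<float> &X, size_t N,
                                     size_t C, size_t H, size_t W,
                                     const std::vector<float> &K, size_t M,
                                     size_t KH, size_t KW,
                                     const float *B /* nullptr if absent */) {
      size_t OH = H - KH + 1, OW = W - KW + 1;
      std::vector<float> Y(N * M * OH * OW, 0.0f);
      for (size_t n = 0; n < N; ++n)
        for (size_t m = 0; m < M; ++m)
          for (size_t r1 = 0; r1 < OH; ++r1)
            for (size_t r2 = 0; r2 < OW; ++r2) {
              float acc = 0.0f;
              for (size_t c = 0; c < C; ++c)
                for (size_t k1 = 0; k1 < KH; ++k1)
                  for (size_t k2 = 0; k2 < KW; ++k2)
                    acc += X[((n * C + c) * H + r1 + k1) * W + r2 + k2] *
                           K[((m * C + c) * KH + k1) * KW + k2];
              if (B)
                acc += B[m]; // ONNX spec: bias is added per output channel
              Y[((n * M + m) * OH + r1) * OW + r2] = acc;
            }
      return Y;
    }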