diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 41c644d..c09d910 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -111,6 +111,7 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", [NoSideEffect, DeclareOpInterfaceMethods]> { + let hasCanonicalizer = 1; let summary = "ONNX Conv operation with no Bias operand."; let description = [{ "The convolution operator consumes an input tensor and a filter, and" diff --git a/src/pass/onnx_combine.td b/src/pass/onnx_combine.td index a3b6a67..efcc34b 100644 --- a/src/pass/onnx_combine.td +++ b/src/pass/onnx_combine.td @@ -39,14 +39,14 @@ def GemmTransB : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">; // onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z) def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)), - [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>; + [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>; // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg), - (replaceWithValue $arg)>; + (replaceWithValue $arg)>; def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3), (ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3), - [(HasOneUse $res)]>; + [(HasOneUse $res)]>; #endif // ONNX_COMBINE diff --git a/src/pass/onnx_rewrite.cpp b/src/pass/onnx_rewrite.cpp index 2f172de..bf2527b 100644 --- a/src/pass/onnx_rewrite.cpp +++ b/src/pass/onnx_rewrite.cpp @@ -263,6 +263,103 @@ struct ReduceSumSquareOpPattern : public RewritePattern { return matchSuccess(); }; }; + +//===----------------------------------------------------------------------===// +// Rewrite: +// %0 = onnx.ConvNoBiasOp(%D : tensor, %K) +// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} -> +// tensor +// +// as: +// %0 = onnx.PadConstantValuePasOp(%D) +// {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} -> +// tensor +// %1 = onnx.ConvNoBias(%0 : tensor, %K) {pads = [0, ..., 0]} -> +// tensor +//===----------------------------------------------------------------------===// +struct SplitConvOpPattern : public RewritePattern { + SplitConvOpPattern(MLIRContext *context) + : RewritePattern(ONNXConvNoBiasOp::getOperationName(), + {ONNXPadConstantValuePadOp::getOperationName(), + ONNXConvNoBiasOp::getOperationName()}, + 1, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // If convolution does not use padding then no rewrite is required. + ONNXConvNoBiasOp convOp = llvm::dyn_cast(op); + auto padsAttribute = convOp.padsAttr(); + if (!padsAttribute) + return matchFailure(); + + // If auto_pad is VALID then no padding happens and no rewrite isrequired. + auto autoPad = convOp.auto_pad(); + if (autoPad == "VALID") + return matchFailure(); + + auto data = op->getOperands()[0]; + auto inputShape = data.getType().cast().getShape(); + + // Dimensionality of the input: + // inputRank + // |----------------------| + // D : (N x C x D1 x D2 x ... DK) + // |______________| + // inputDims + // + int64_t inputRank = inputShape.size(); + int64_t inputDims = inputRank - 2; + + // If all pads values are equal to zero then no rewrite is required. + bool allZeros = true; + for (auto padsValue : padsAttribute.getValue()) { + if (padsValue.cast().getInt() > 0) { + allZeros = false; + break; + } + } + + if (allZeros) + return matchFailure(); + + // Create padding vector for the explicit padding op attribute. + SmallVector pads(2 * inputRank, 0); + SmallVector outPaddedShape(inputRank, 0); + outPaddedShape[0] = inputShape[0]; + outPaddedShape[1] = inputShape[1]; + for (int i = 0; i < inputDims; ++i) { + int64_t beginPad = + padsAttribute.getValue()[i].cast().getInt(); + int64_t endPad = + padsAttribute.getValue()[inputDims + i].cast().getInt(); + pads[i + 2] = beginPad; + pads[inputRank + i + 2] = endPad; + outPaddedShape[i + 2] += beginPad + inputShape[i + 2] + endPad; + } + + // Create padding operation. + auto inputElemType = data.getType().cast().getElementType(); + ONNXPadConstantValuePadOp paddingOp = + rewriter.create( + loc, RankedTensorType::get(outPaddedShape, inputElemType), data, + rewriter.getI64ArrayAttr(pads), FloatAttr::get(inputElemType, 0), + StringAttr::get("constant", loc->getContext())); + + SmallVector newConvPads(2 * inputDims, 0); + auto tensorType = (*op->result_type_begin()).cast(); + ONNXConvNoBiasOp newConvOp = rewriter.create( + loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1], + convOp.auto_padAttr(), convOp.dilationsAttr(), + convOp.groupAttr(), convOp.kernel_shapeAttr(), + rewriter.getI64ArrayAttr(newConvPads), + convOp.stridesAttr()); + + rewriter.replaceOp(op, newConvOp.getResult()); + return matchSuccess(); + }; +}; } // end anonymous namespace /// on the ONNXReduceL1Op. @@ -293,3 +390,9 @@ void ONNXReduceSumSquareOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } + +/// on the ONNXReduceSumSquareOp. +void ONNXConvNoBiasOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index 78825c8..bf07fdc 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -92,3 +92,12 @@ func @test_constant_pad(%arg0 : tensor) -> tensor<*xf32> { %2 = "onnx.PadConstantValue"(%arg0, %0) {constant_value=0. : f32, mode = "constant"} : (tensor, tensor)-> tensor<*xf32> "std.return"(%2) : (tensor<*xf32>) -> () } + +// CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> { +func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { + %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + // CHECK-NEXT: %0 = "onnx.PadConstatValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32> + // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> + // CHECK-NEXT: return %1 : tensor<*xf32> +}