diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td
index 1dde6cc..fae2806 100644
--- a/src/dialect/onnx/onnx.td
+++ b/src/dialect/onnx/onnx.td
@@ -111,6 +111,7 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
 def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX MaxPool operation with a single output.";
   let description = [{
     "ONNX MaxPool operation with a single output."
@@ -195,6 +196,10 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
                        DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
                        DefaultValuedAttr<StrAttr, "constant">:$mode);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
+  // A build method with result type deduction.
+  let builders = [OpBuilder<"Builder *builder, OperationState &state, "
+                            "Value data, ArrayAttr pads, "
+                            "FloatAttr constant_value, StringAttr mode">];
 }
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 78ba92b..d6b86b0 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -1095,6 +1095,16 @@ void ONNXPadConstantValuePadOp::inferShapes(){
   return;
 }
 
+void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
+    Value data, ArrayAttr pads, FloatAttr constant_value, StringAttr mode) {
+  Type outputType = padShapeInferenceHelper(data, pads);
+  if (!outputType) {
+    auto elementType = data.getType().cast<TensorType>().getElementType();
+    outputType = UnrankedTensorType::get(elementType);
+  }
+  build(builder, state, outputType, data, pads, constant_value, mode);
+}
+
 //===----------------------------------------------------------------------===//
 // Unsqueeze
diff --git a/src/main.cpp b/src/main.cpp
index 8893382..9500d31 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -125,6 +125,7 @@ int main(int argc, char *argv[]) {
   pm.addPass(mlir::createDecomposeONNXToONNXPass());
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createShapeInferencePass());
 
   if (emissionTarget >= EmitMLIR) {
     pm.addPass(mlir::createLowerToKrnlPass());
diff --git a/src/pass/onnx_rewrite.cpp b/src/pass/onnx_rewrite.cpp
index dcc4dc1..79bf712 100644
--- a/src/pass/onnx_rewrite.cpp
+++ b/src/pass/onnx_rewrite.cpp
@@ -17,6 +17,56 @@ using namespace mlir;
 namespace {
+
+// Check whether an ArrayAttr contains non-zero values or not.
+bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
+  bool allZeros = true;
+  if (attrs) {
+    for (auto attr : attrs.getValue()) {
+      if (attr.cast<IntegerAttr>().getInt() > 0) {
+        allZeros = false;
+        break;
+      }
+    }
+  }
+  return !allZeros;
+}
+
+// Create an ArrayAttr of IntegerAttr(s) of zero values.
+// This function is used for the padding attribute in MaxPoolSingleOut.
+ArrayAttr createArrayAttrOfZeros(
+    PatternRewriter &rewriter, ArrayAttr origAttrs) {
+  int nElements = origAttrs.getValue().size();
+  SmallVector<int64_t, 4> vals(nElements, 0);
+  return rewriter.getI64ArrayAttr(vals);
+}
+
+// Pad an ArrayAttr with zeros.
+//
+// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
+//
+// becomes:
+//
+// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
+//         |_____|                  |_____|
+//          nZeros                   nZeros
+//
+// This function is used for the padding attribute in MaxPoolSingleOut.
+ArrayAttr insertZerosForNonPaddedDims(
+    PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
+  int nDims = (int) origAttrs.getValue().size() / 2;
+  int nElements = (nDims + extensionLength) * 2;
+  SmallVector<int64_t, 4> pads(nElements, 0);
+  for (int i = 0; i < nDims; ++i) {
+    int64_t beginPad = origAttrs.getValue()[i].cast<IntegerAttr>().getInt();
+    int64_t endPad =
+        origAttrs.getValue()[nDims + i].cast<IntegerAttr>().getInt();
+    pads[i + extensionLength] = beginPad;
+    pads[nDims + extensionLength + i + extensionLength] = endPad;
+  }
+  return rewriter.getI64ArrayAttr(pads);
+}
+
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/onnx_rewrite.inc"
@@ -118,6 +168,11 @@ struct SplitConvOpPattern : public RewritePattern {
 };
 } // end anonymous namespace
 
+/// on the ONNXMaxPoolSingleOutOp.
+void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
+}
 /// on the ONNXReduceSumSquareOp.
 void ONNXConvNoBiasOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
diff --git a/src/pass/onnx_rewrite.td b/src/pass/onnx_rewrite.td
index ab73989..c3b5490 100644
--- a/src/pass/onnx_rewrite.td
+++ b/src/pass/onnx_rewrite.td
@@ -24,4 +24,68 @@ include "dialect/onnx/onnx.td"
 /// dag benefitsAdded = (addBenefit 0)
 /// >;
 
+// Create a StringAttr from a string.
+class StringAttrOfValue<string val>:
+  NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
+
+// Create a FloatAttr from an integer value.
+// TableGen does not seem to support a `float` type, so we cannot pass a float value.
+class FloatAttrOfValue<int val>:
+  NativeCodeCall<"FloatAttr::get($0.getType().cast<TensorType>().getElementType(), " # val # ")">;
+
+// Create an ArrayAttr of IntegerAttr(s) of zero values.
+// This function is used for the padding attribute in MaxPoolSingleOut.
+def createArrayAttrOfZerosFrom:
+  NativeCodeCall<"createArrayAttrOfZeros($_builder, $0)">;
+
+// Pad an ArrayAttr with zeros.
+//
+// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
+//
+// becomes:
+//
+// pads = [0,... 0, B1, B2, ... Bk, 0,... 0, E1, E2, ..., Ek]
+//         |_____|                  |_____|
+//          nZeros                   nZeros
+//
+// This function is used for the padding attribute in MaxPoolSingleOut.
+class insertZerosForNonPaddedDims<int extensionLength>:
+  NativeCodeCall<"insertZerosForNonPaddedDims($_builder, $0,"
+                 # extensionLength # ")">;
+
+// Check whether an ArrayAttr contains non-zero values or not.
+def HasNonZeroInArrayAttr: Constraint<CPred<"hasNonZeroInArrayAttr($_self)">,
+  "has non-zero elements">;
+
+//===----------------------------------------------------------------------===//
+// Rewrite:
+// %0 = onnx.MaxPoolSingleOutOp(%D : tensor)
+//          {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
+//      tensor
+//
+// as:
+// %0 = onnx.PadConstantValuePadOp(%D)
+//          {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
+//      tensor
+// %1 = onnx.MaxPoolSingleOut(%0 : tensor) {pads = [0, ..., 0]} ->
+//      tensor
+//===----------------------------------------------------------------------===//
+
+def MaxPoolSingleOutOpPaddingPattern: Pat<
+  (ONNXMaxPoolSingleOutOp:$res
+     $x,
+     $auto_pad, $ceil_mode, $dilation, $kernel_shape,
+     $pads,
+     $storage_order, $strides),
+  (ONNXMaxPoolSingleOutOp
+     (ONNXPadConstantValuePadOp $x,
+        (insertZerosForNonPaddedDims<2> $pads),
+        (FloatAttrOfValue<0> $res),
+        (StringAttrOfValue<"constant">)),
+     $auto_pad, $ceil_mode, $dilation, $kernel_shape,
+     (createArrayAttrOfZerosFrom $pads),
+     $storage_order, $strides),
+  [(HasNonZeroInArrayAttr:$pads)]
+>;
+
 #endif // ONNX_REWRITE
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index bbc2686..78193cb 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -77,3 +77,23 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
   // return [[GEMM]] : tensor<*xf32>
 }
 
+//CHECK-LABEL: @test_maxpoolsingleout_split(%{{.*}}: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
+func @test_maxpoolsingleout_split(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32> {
+  %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
+  "std.return"(%0) : (tensor<5x8x32x39xf32>) -> ()
+
+  // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<5x5x32x32xf32>) -> tensor<5x8x32x39xf32>
+  // CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<5x8x32x39xf32>) -> tensor<5x8x32x39xf32>
+  // CHECK-NEXT: return %1 : tensor<5x8x32x39xf32>
+}
+
+//CHECK-LABEL: @test_maxpoolsingleout_split_unknown_dims(%{{.*}}: tensor<*xf32>) -> tensor<*xf32> {
+func @test_maxpoolsingleout_split_unknown_dims(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: return %1 : tensor<*xf32>
+}
+
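
For reference, below is a minimal, MLIR-free C++ sketch of the pads arithmetic that insertZerosForNonPaddedDims and the rewrite pattern rely on. The function extendPadsWithZeros and the standalone main are illustrative only and are not part of this patch; the sketch assumes the same pads layout [B1, ..., Bk, E1, ..., Ek] and an extension length of 2 for the batch and channel dimensions.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the logic of insertZerosForNonPaddedDims: given spatial pads
// [B1, ..., Bk, E1, ..., Ek], prepend `extensionLength` zeros to both the
// "begin" half and the "end" half so the pads cover every dimension.
std::vector<int64_t> extendPadsWithZeros(const std::vector<int64_t> &pads,
                                         int extensionLength) {
  int nDims = static_cast<int>(pads.size()) / 2;
  std::vector<int64_t> result((nDims + extensionLength) * 2, 0);
  for (int i = 0; i < nDims; ++i) {
    result[extensionLength + i] = pads[i];                     // begin pads
    result[nDims + 2 * extensionLength + i] = pads[nDims + i]; // end pads
  }
  return result;
}

int main() {
  // Spatial pads from the test case above: [b0, b1, e0, e1] = [1, 2, 3, 4].
  std::vector<int64_t> spatialPads = {1, 2, 3, 4};
  // For a 4-D input (extensionLength = 2) this prints 0 0 1 2 0 0 3 4,
  // the pads attribute expected on onnx.PadConstantValuePad in the test.
  for (int64_t p : extendPadsWithZeros(spatialPads, 2))
    std::cout << p << ' ';
  std::cout << '\n';
  return 0;
}

The zero-filled pads placed on the rewritten onnx.MaxPoolSingleOut correspond to createArrayAttrOfZeros, which simply emits an all-zero array of the same length as the original pads attribute.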