diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md
index d43ef7d..ba5de60 100644
--- a/doc/Dialects/onnx.md
+++ b/doc/Dialects/onnx.md
@@ -2859,6 +2859,60 @@ ONNX PRelu operation
 
 1. `Y`: memref of any type values or tensor of any type values
 
+### onnx.PadConstantValue (ONNXPadConstantValueOp)
+ONNX Pad operation with constant padding value
+
+#### Description:
+
+
+This operation is introduced to handle the situation
+in which the padding value is a constant.
+The value will become an attribute.
+This operation is also used when the optional constant_value input is
+missing, in which case the default value 0.0 is used.
+
+#### Operands:
+
+1. `data`: memref of any type values or tensor of any type values
+1. `pads`: memref of any type values or tensor of any type values
+
+#### Attributes:
+
+| Attribute | MLIR Type | Description |
+| :-------: | :-------: | ----------- |
+| `constant_value` | `FloatAttr` | 32-bit float attribute attribute |
+| `mode` | `StringAttr` | string attribute attribute |
+
+#### Results:
+
+1. `output`: memref of any type values or tensor of any type values
+
+### onnx.PadConstatValuePad (ONNXPadConstantValuePadOp)
+ONNX Pad operation with constant padding value and pads
+
+#### Description:
+
+
+This operation is introduced to handle the situation
+in which both the padding value and the pads are constants.
+They will become attributes of the operation.
+
+#### Operands:
+
+1. `data`: memref of any type values or tensor of any type values
+
+#### Attributes:
+
+| Attribute | MLIR Type | Description |
+| :-------: | :-------: | ----------- |
+| `pads` | `ArrayAttr` | 64-bit integer array attribute attribute |
+| `constant_value` | `FloatAttr` | 32-bit float attribute attribute |
+| `mode` | `StringAttr` | string attribute attribute |
+
+#### Results:
+
+1. `output`: memref of any type values or tensor of any type values
+
 ### onnx.Pad (ONNXPadOp)
 ONNX Pad operation
 
diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index 6bb0d7e..04d456b 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -37,6 +37,7 @@ special_op_handler = dict([
     ("Conv", "ImportNodeConv"),
     ("MaxPool", "ImportNodeMaxPool"),
     ("Gemm", "ImportNodeGemm"),
+    ("Pad", "ImportNodePad"),
     #("Transpose", "ImportNodeTranspose")
 ])
 
diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index 10fc2c9..29b7798 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -446,6 +446,18 @@ private:
     }
   }
 
+  /*!
+   * Special handler for Pad operations.
+   */
+  void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
+    int nOps = node.input().size();
+    if (nOps == 2) {
+      ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
+    } else {
+      ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
+    }
+  }
+
   void ImportNode(const onnx::NodeProto &node) {
     std::vector<mlir::Value> inputs;
     for (const auto &item : node.input()) {
diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc
index d7dea0f..e6d97c5 100644
--- a/src/builder/op_build_table.inc
+++ b/src/builder/op_build_table.inc
@@ -178,7 +178,7 @@
     }else if (OpName == "PRelu") {
       ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
     }else if (OpName == "Pad") {
-      ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1);
+      ImportNodePad(node, 3, 1);
     }else if (OpName == "Pow") {
       ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
     }else if (OpName == "QLinearConv") {
diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td
index 340c910..9ad9cde 100644
--- a/src/dialect/onnx/onnx.td
+++ b/src/dialect/onnx/onnx.td
@@ -145,4 +145,36 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
 }
 
+def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
+    [NoSideEffect]> {
+  let summary = "ONNX Pad operation with constant padding value";
+  let hasCanonicalizer = 1;
+  let description = [{ This operation is introduced to handle the situation
+    in which the padding value is a constant.
+    The value will become an attribute.
+    This operation is also used when the optional constant_value input is
+    missing, in which case the default value 0.0 is used.
+  }];
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
+                   AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads,
+                   DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
+                   DefaultValuedAttr<StrAttr, "constant">:$mode);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
+}
+
+def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstatValuePad",
+    [NoSideEffect]> {
+  let summary = "ONNX Pad operation with constant padding value and pads";
+  let description = [{ This operation is introduced to handle the situation
+    in which both the padding value and the pads are constants.
+    They will become attributes of the operation.
+  }];
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
+                   I64ArrayAttr:$pads,
+                   DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
+                   DefaultValuedAttr<StrAttr, "constant">:$mode);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
+}
+
+
 #endif // ONNX_OPS
diff --git a/src/pass/onnx_combine.cpp b/src/pass/onnx_combine.cpp
index 22960a0..31eb2d6 100644
--- a/src/pass/onnx_combine.cpp
+++ b/src/pass/onnx_combine.cpp
@@ -33,3 +33,9 @@ void ONNXIdentityOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
   results.insert<IdentityEliminationPattern>(context);
 }
+
+/// Register canonicalization patterns on the ONNXPadConstantValueOp.
+void ONNXPadConstantValueOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<ConstantPadPattern>(context);
+}
diff --git a/src/pass/onnx_combine.td b/src/pass/onnx_combine.td
index cc3abfa..a3b6a67 100644
--- a/src/pass/onnx_combine.td
+++ b/src/pass/onnx_combine.td
@@ -45,4 +45,8 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
 def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
                                      (replaceWithValue $arg)>;
 
+def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
+                             (ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
+                             [(HasOneUse $res)]>;
+
 #endif // ONNX_COMBINE
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 79c3013..0233a28 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -84,3 +84,11 @@ func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
   // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
   // CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
 }
+
+// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
+func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  // CHECK-NEXT: [[PAD:%.+]] = "onnx.PadConstatValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
+  %0 = "onnx.Constant"() {value = [0, 2, 0, 0]} : () -> tensor<?xi64>
+  %2 = "onnx.PadConstantValue"(%arg0, %0) {constant_value = 0. : f32, mode = "constant"} : (tensor<?x?xf32>, tensor<?xi64>) -> tensor<*xf32>
+  "std.return"(%2) : (tensor<*xf32>) -> ()
+}
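
In summary, the ConstantPadPattern added in src/pass/onnx_combine.td rewrites an
onnx.PadConstantValue whose pads operand comes from a single-use onnx.Constant into
an onnx.PadConstatValuePad that carries the pads as an attribute. A minimal
before/after sketch in MLIR, mirroring the test_constant_pad case above (the %data
name and the tensor shapes are illustrative assumptions, not part of the patch):

    // Before canonicalization: pads arrive as a single-use onnx.Constant operand.
    %pads = "onnx.Constant"() {value = [0, 2, 0, 0]} : () -> tensor<?xi64>
    %before = "onnx.PadConstantValue"(%data, %pads)
        {constant_value = 0.0 : f32, mode = "constant"}
        : (tensor<?x?xf32>, tensor<?xi64>) -> tensor<*xf32>

    // After canonicalization: the constant is folded into the pads attribute,
    // and the now-unused onnx.Constant can be erased as dead code.
    %after = "onnx.PadConstatValuePad"(%data)
        {constant_value = 0.0 : f32, mode = "constant", pads = [0, 2, 0, 0]}
        : (tensor<?x?xf32>) -> tensor<*xf32>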