Create constant pad (#75)
* handle pad op which does not have the optional third argment * rewrite PadConstantValue with constant pad into PadConstantValuePad * add test for PadConstantValuePad * update onnx.md
This commit is contained in:
parent
094be4f37a
commit
49dae74eab
|
@ -2859,6 +2859,60 @@ ONNX PRelu operation
|
|||
|
||||
1. `Y`: memref of any type values or tensor of any type values
|
||||
|
||||
### onnx.PadConstantValue (ONNXPadConstantValueOp)
|
||||
ONNX Pad operation with constant padding value
|
||||
|
||||
#### Description:
|
||||
|
||||
|
||||
"this operation is introduced to handle situation"
|
||||
" in which the padding value is a constant.
|
||||
" The value will become an attribute."
|
||||
"This operation is also used to handle the optional value input is missing and the default value 0."
|
||||
"is used."
|
||||
|
||||
#### Operands:
|
||||
|
||||
1. `data`: memref of any type values or tensor of any type values
|
||||
1. `pads`: memref of any type values or tensor of any type values
|
||||
|
||||
#### Attributes:
|
||||
|
||||
| Attribute | MLIR Type | Description |
|
||||
| :-------: | :-------: | ----------- |
|
||||
| `constant_value` | `FloatAttr` | 32-bit float attribute attribute |
|
||||
| `mode` | `StringAttr` | string attribute attribute |
|
||||
|
||||
#### Results:
|
||||
|
||||
1. `output`: memref of any type values or tensor of any type values
|
||||
|
||||
### onnx.PadConstatValuePad (ONNXPadConstantValuePadOp)
|
||||
ONNX Pad operation with constant padding value
|
||||
|
||||
#### Description:
|
||||
|
||||
|
||||
"this operation is introduced to handle situation"
|
||||
" in which the padding value and padding are constants"
|
||||
"They will become attributes."
|
||||
|
||||
#### Operands:
|
||||
|
||||
1. `data`: memref of any type values or tensor of any type values
|
||||
|
||||
#### Attributes:
|
||||
|
||||
| Attribute | MLIR Type | Description |
|
||||
| :-------: | :-------: | ----------- |
|
||||
| `pads` | `ArrayAttr` | 64-bit integer array attribute attribute |
|
||||
| `constant_value` | `FloatAttr` | 32-bit float attribute attribute |
|
||||
| `mode` | `StringAttr` | string attribute attribute |
|
||||
|
||||
#### Results:
|
||||
|
||||
1. `output`: memref of any type values or tensor of any type values
|
||||
|
||||
### onnx.Pad (ONNXPadOp)
|
||||
ONNX Pad operation
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ special_op_handler = dict([
|
|||
("Conv", "ImportNodeConv"),
|
||||
("MaxPool", "ImportNodeMaxPool"),
|
||||
("Gemm", "ImportNodeGemm"),
|
||||
("Pad", "ImportNodePad"),
|
||||
#("Transpose", "ImportNodeTranspose")
|
||||
])
|
||||
|
||||
|
|
|
@ -446,6 +446,18 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Special handle for Pad operations.
|
||||
*/
|
||||
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
|
||||
int nOps = node.input().size();
|
||||
if (nOps == 2) {
|
||||
ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
|
||||
} else {
|
||||
ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
|
||||
}
|
||||
}
|
||||
|
||||
void ImportNode(const onnx::NodeProto &node) {
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (const auto &item : node.input()) {
|
||||
|
|
|
@ -178,7 +178,7 @@
|
|||
}else if (OpName == "PRelu") {
|
||||
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
|
||||
}else if (OpName == "Pad") {
|
||||
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1);
|
||||
ImportNodePad(node, 3, 1);
|
||||
}else if (OpName == "Pow") {
|
||||
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
|
||||
}else if (OpName == "QLinearConv") {
|
||||
|
|
|
@ -145,4 +145,36 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
|||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||
}
|
||||
|
||||
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
|
||||
[NoSideEffect ]> {
|
||||
let summary = "ONNX Pad operation with constant padding value";
|
||||
let hasCanonicalizer = 1;
|
||||
let description = [{ "this operation is introduced to handle situation"
|
||||
" in which the padding value is a constant.
|
||||
" The value will become an attribute."
|
||||
"This operation is also used to handle the optional value input is missing and the default value 0."
|
||||
"is used."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads,
|
||||
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
|
||||
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||
}
|
||||
|
||||
def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstatValuePad",
|
||||
[NoSideEffect ]> {
|
||||
let summary = "ONNX Pad operation with constant padding value";
|
||||
let description = [{ "this operation is introduced to handle situation"
|
||||
" in which the padding value and padding are constants"
|
||||
"They will become attributes."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||
I64ArrayAttr:$pads,
|
||||
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
|
||||
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||
}
|
||||
|
||||
|
||||
#endif // ONNX_OPS
|
||||
|
|
|
@ -33,3 +33,9 @@ void ONNXIdentityOp::getCanonicalizationPatterns(
|
|||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<IdentityEliminationPattern>(context);
|
||||
}
|
||||
|
||||
///on the ONNXPadConstantValueOp.
|
||||
void ONNXPadConstantValueOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& result, MLIRContext* context) {
|
||||
result.insert<ConstantPadPattern>(context);
|
||||
}
|
||||
|
|
|
@ -45,4 +45,8 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
|||
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
|
||||
(replaceWithValue $arg)>;
|
||||
|
||||
def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
|
||||
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
|
||||
[(HasOneUse $res)]>;
|
||||
|
||||
#endif // ONNX_COMBINE
|
||||
|
|
|
@ -84,3 +84,11 @@ func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
|
|||
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
|
||||
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstatValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
|
||||
%0 ="onnx.Constant"() {value=[0, 2, 0, 0]} : ()-> tensor<?xi64>
|
||||
%2 = "onnx.PadConstantValue"(%arg0, %0) {constant_value=0. : f32, mode = "constant"} : (tensor<?x?xf32>, tensor<?xi64>)-> tensor<*xf32>
|
||||
"std.return"(%2) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue