Create constant pad (#75)

* handle the Pad op when it does not have the optional third argument

* rewrite PadConstantValue with a constant pads input into PadConstantValuePad

* add test for PadConstantValuePad

* update onnx.md
chentong319 2020-02-11 15:32:01 -05:00 committed by GitHub
parent 094be4f37a
commit 49dae74eab
8 changed files with 118 additions and 1 deletion
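
In short: when the pads input of onnx.PadConstantValue is produced by a single-use onnx.Constant, the new canonicalization pattern folds it into onnx.PadConstantValuePad, which carries the pads as an attribute. A minimal before/after sketch in MLIR, mirroring the new test case at the end of this change (shapes and pad amounts are illustrative only):

// before canonicalization: the pads arrive as a constant tensor operand
%pads = "onnx.Constant"() {value = [0, 2, 0, 0]} : () -> tensor<?xi64>
%res = "onnx.PadConstantValue"(%arg0, %pads) {constant_value = 0. : f32, mode = "constant"} : (tensor<?x?xf32>, tensor<?xi64>) -> tensor<*xf32>

// after canonicalization: the constant has been folded into the pads attribute
%res = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>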


@@ -2859,6 +2859,60 @@ ONNX PRelu operation
1. `Y`: memref of any type values or tensor of any type values
### onnx.PadConstantValue (ONNXPadConstantValueOp)
ONNX Pad operation with constant padding value
#### Description:
"this operation is introduced to handle situation"
" in which the padding value is a constant.
" The value will become an attribute."
"This operation is also used to handle the optional value input is missing and the default value 0."
"is used."
#### Operands:
1. `data`: memref of any type values or tensor of any type values
1. `pads`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `constant_value` | `FloatAttr` | 32-bit float attribute |
| `mode` | `StringAttr` | string attribute |
#### Results:
1. `output`: memref of any type values or tensor of any type values
### onnx.PadConstantValuePad (ONNXPadConstantValuePadOp)
ONNX Pad operation with constant padding value
#### Description:
"this operation is introduced to handle situation"
" in which the padding value and padding are constants"
"They will become attributes."
#### Operands:
1. `data`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `pads` | `ArrayAttr` | 64-bit integer array attribute |
| `constant_value` | `FloatAttr` | 32-bit float attribute |
| `mode` | `StringAttr` | string attribute |
#### Results:
1. `output`: memref of any type values or tensor of any type values
### onnx.Pad (ONNXPadOp)
ONNX Pad operation


@@ -37,6 +37,7 @@ special_op_handler = dict([
("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"),
("Gemm", "ImportNodeGemm"),
("Pad", "ImportNodePad"),
#("Transpose", "ImportNodeTranspose")
])


@@ -446,6 +446,18 @@ private:
}
}
/*!
 * Special handler for the Pad operation.
 */
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
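// Only data and pads are given: the optional constant value input is
// missing, so import as ONNXPadConstantValueOp, whose constant_value
// attribute defaults to 0.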
ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
}
}
void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {


@@ -178,7 +178,7 @@
}else if (OpName == "PRelu") {
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
}else if (OpName == "Pad") {
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1);
ImportNodePad(node, 3, 1);
}else if (OpName == "Pow") {
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
}else if (OpName == "QLinearConv") {


@@ -145,4 +145,36 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
[NoSideEffect ]> {
let summary = "ONNX Pad operation with constant padding value";
let hasCanonicalizer = 1;
let description = [{
This operation is introduced to handle the situation in which the padding
value is a constant; the value becomes an attribute. It is also used when
the optional constant value input is missing and the default value 0 is used.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads,
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
[NoSideEffect ]> {
let summary = "ONNX Pad operation with constant padding value";
let description = [{
This operation is introduced to handle the situation in which both the
padding value and the pads are constants; they become attributes.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
I64ArrayAttr:$pads,
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
#endif // ONNX_OPS


@@ -33,3 +33,9 @@ void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<IdentityEliminationPattern>(context);
}
/// Register the ConstantPadPattern canonicalization pattern on the ONNXPadConstantValueOp.
void ONNXPadConstantValueOp::getCanonicalizationPatterns(
OwningRewritePatternList& result, MLIRContext* context) {
result.insert<ConstantPadPattern>(context);
}


@@ -45,4 +45,8 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>;
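// Fold a PadConstantValue whose pads operand is a single-use ONNXConstantOp
// into PadConstantValuePad, carrying the constant's value as the pads attribute.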
def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $v1, $v2), $m2, $m3),
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
[(HasOneUse $res)]>;
#endif // ONNX_COMBINE


@@ -84,3 +84,11 @@ func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}
// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
// CHECK-NEXT: [[PAD:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
%0 ="onnx.Constant"() {value=[0, 2, 0, 0]} : ()-> tensor<?xi64>
%2 = "onnx.PadConstantValue"(%arg0, %0) {constant_value=0. : f32, mode = "constant"} : (tensor<?x?xf32>, tensor<?xi64>)-> tensor<*xf32>
"std.return"(%2) : (tensor<*xf32>) -> ()
}