From c3041bfb43148de3ffe4113c9308aa2b4d8187a1 Mon Sep 17 00:00:00 2001 From: chentong Date: Thu, 13 Feb 2020 12:08:29 -0500 Subject: [PATCH] shape inference for pad with constant pads --- doc/Dialects/onnx.md | 28 ++++++++- src/dialect/onnx/onnx.td | 18 +++++- src/dialect/onnx/onnx_ops.cpp | 74 +++++++++++++++++++++++ src/pass/shape_inference_pass.cpp | 2 + test/mlir/onnx/onnx_canonicalization.mlir | 2 +- test/mlir/onnx/onnx_shape_inference.mlir | 20 ++++++ 6 files changed, 140 insertions(+), 4 deletions(-) diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md index ba5de60..f67b1ae 100644 --- a/doc/Dialects/onnx.md +++ b/doc/Dialects/onnx.md @@ -2859,6 +2859,32 @@ ONNX PRelu operation 1. `Y`: memref of any type values or tensor of any type values +### onnx.PadConstantPad (ONNXPadConstantPadOp) +ONNX Pad operation with constant padding value + +#### Description: + + +"this operation is introduced to handle situation" + " in which the padding value and padding are constants" + "They will become attributes." + +#### Operands: + +1. `data`: memref of any type values or tensor of any type values +1. `constant_value`: memref of any type values or tensor of any type values + +#### Attributes: + +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `pads` | `ArrayAttr` | 64-bit integer array attribute attribute | +| `mode` | `StringAttr` | string attribute attribute | + +#### Results: + +1. `output`: memref of any type values or tensor of any type values + ### onnx.PadConstantValue (ONNXPadConstantValueOp) ONNX Pad operation with constant padding value @@ -2887,7 +2913,7 @@ ONNX Pad operation with constant padding value 1. `output`: memref of any type values or tensor of any type values -### onnx.PadConstatValuePad (ONNXPadConstantValuePadOp) +### onnx.PadConstantValuePad (ONNXPadConstantValuePadOp) ONNX Pad operation with constant padding value #### Description: diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 9ad9cde..41c644d 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -162,8 +162,22 @@ def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue", let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); } -def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstatValuePad", - [NoSideEffect ]> { +def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad", + [NoSideEffect, DeclareOpInterfaceMethods ]> { + let summary = "ONNX Pad operation with constant padding value"; + let description = [{ "this operation is introduced to handle situation" + " in which the padding value and padding are constants" + "They will become attributes." + }]; + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + AnyTypeOf<[AnyMemRef, AnyTensor]>:$constant_value, + I64ArrayAttr:$pads, + DefaultValuedAttr:$mode); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); +} + +def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad", + [NoSideEffect, DeclareOpInterfaceMethods ]> { let summary = "ONNX Pad operation with constant padding value"; let description = [{ "this operation is introduced to handle situation" " in which the padding value and padding are constants" diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index fafc834..4c30c86 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -1015,6 +1015,80 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { //===----------------------------------------------------------------------===// +// PadConstantPad + +void ONNXPadConstantPadOp::inferShapes(){ + // Cannot infer shape if no shape exists. + if (!data().getType().isa()) + return; + + // 1) get shape of input "data" + auto dataTy = data().getType().cast(); + auto dataShape = dataTy.getShape(); + auto dataRank = dataShape.size(); + + SmallVector outputShape(dataShape.begin(), dataShape.end()); + auto padsOpt = pads(); + if (padsOpt) { + auto padsArray = padsOpt.getValue(); + // pads consists of two entries for each spatial axis. + if (padsArray.size() != 2 * dataRank) + emitError("pads rank is not twice the spatial rank."); + // fill in the actual values + for (int i = 0; i < dataRank; ++i) { + int64_t p1 = (padsArray[2*i]).cast().getInt(); + if (p1 < 0) + emitError("pads value must be nonnegative."); + int64_t p2 = (padsArray[2*i+1]).cast().getInt(); + if (p2 < 0) + emitError("pads value must be nonnegative."); + outputShape[i] += p1+p2; + } + getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType())); + } else { + emitError("pads attribute is not available."); + } +} + +//===----------------------------------------------------------------------===// + +// PadConstantValuePad + +void ONNXPadConstantValuePadOp::inferShapes(){ + // Cannot infer shape if no shape exists. + if (!data().getType().isa()) + return; + + // 1) get shape of input "data" + auto dataTy = data().getType().cast(); + auto dataShape = dataTy.getShape(); + auto dataRank = dataShape.size(); + + SmallVector outputShape(dataShape.begin(), dataShape.end()); + auto padsOpt = pads(); + if (padsOpt) { + auto padsArray = padsOpt.getValue(); + // pads consists of two entries for each spatial axis. + if (padsArray.size() != 2 * dataRank) + emitError("pads rank is not twice the spatial rank."); + // fill in the actual values + for (int i = 0; i < dataRank; ++i) { + int64_t p1 = (padsArray[2*i]).cast().getInt(); + if (p1 < 0) + emitError("pads value must be nonnegative."); + int64_t p2 = (padsArray[2*i+1]).cast().getInt(); + if (p2 < 0) + emitError("pads value must be nonnegative."); + outputShape[i] += p1+p2; + } + getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType())); + } else { + emitError("pads attribute is not available."); + } +} + +//===----------------------------------------------------------------------===// + // Unsqueeze void ONNXUnsqueezeOp::inferShapes() { diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index d62069a..0d4ae18 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -128,6 +128,8 @@ public: op->getName().getStringRef() != "onnx.Softmax" && op->getName().getStringRef() != "onnx.Sqrt" && op->getName().getStringRef() != "onnx.ConvNoBias" && + op->getName().getStringRef() != "onnx.PadConstantPad" && + op->getName().getStringRef() != "onnx.PadConstantValuePad" && op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index 0233a28..78825c8 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -87,7 +87,7 @@ func @test_reducesumsquare(%arg0 : tensor) -> tensor<*xf32> { // CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor) -> tensor<*xf32> { func @test_constant_pad(%arg0 : tensor) -> tensor<*xf32> { - // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstatValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor) -> tensor<*xf32> + // CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor) -> tensor<*xf32> %0 ="onnx.Constant"() {value=[0, 2, 0, 0]} : ()-> tensor %2 = "onnx.PadConstantValue"(%arg0, %0) {constant_value=0. : f32, mode = "constant"} : (tensor, tensor)-> tensor<*xf32> "std.return"(%2) : (tensor<*xf32>) -> () diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 14c575d..3e52625 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -263,3 +263,23 @@ func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7 // CHECK-LABEL: test_conv_no_bias_11 // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", dilations = [2, 3], group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32> // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> + + +/// Test PadConstantValuePad_1 +func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> { + %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<16x13xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} +// CHECK-LABEL: test_PadConstantValuePad_1 +// CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32> +// CHECK: return [[RES]] : tensor<18x13xf32> + +/// Test PadConstantPad_1 +func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} +// CHECK-LABEL: test_PadConstantPad_1 +// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 2, 3, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32> +// CHECK: return [[RES]] : tensor<18x17xf32> +