Add Rewrite rule to eliminate CastOp when input element type is the same as expected output element type (#237)

* move ScalerOp to decompose

* change clang format

* change clang format

* add shape inference for ScalerOp

* fix generated onnxop

* generate onnx.md

* add benefit for Scaler decompose and simplify Scaler shape inference

* cast rewrite only for float

* add CastOp same-type rewrite rule

* fix format

Co-authored-by: chentong319 <chentong@us.ibm.com>
Anh Leu 2020-07-27 11:49:14 -05:00 committed by GitHub
parent 32ceb6968a
commit a611c145f4
5 changed files with 31 additions and 1 deletion
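One detail worth keeping in mind while reading the diff: the Cast attribute "to" carries an ONNX TensorProto_DataType code, and FLOAT is encoded as 1, so casting a tensor<2xf32> with to = 1 changes nothing and is exactly the case the new rule removes (see the test added at the bottom). A minimal compile-time check of that encoding, assuming the usual ONNX protobuf header:

#include "onnx/onnx_pb.h"

// ONNX encodes FLOAT as 1 in TensorProto_DataType, which is why the test
// below uses `to = 1` on an f32 tensor to exercise the redundant cast.
static_assert(onnx::TensorProto_DataType_FLOAT == 1,
              "to = 1 selects FLOAT in the ONNX type enum");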


@@ -450,6 +450,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
def ONNXCastOp:ONNX_Op<"Cast",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
let hasCanonicalizer = 1;
let summary = "ONNX Cast operation";
let description = [{
"The operator casts the elements of a given input tensor to a data type"


@@ -18,6 +18,7 @@
using namespace mlir;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXCombine.inc"
} // end anonymous namespace
@@ -44,3 +45,9 @@ void ONNXPadConstantValueOp::getCanonicalizationPatterns(
OwningRewritePatternList &result, MLIRContext *context) {
result.insert<ConstantPadPattern>(context);
}
/// Register optimization patterns as "canonicalization" patterns
/// on the ONNXCastOp.
void ONNXCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &result, MLIRContext *context) {
result.insert<CastEliminationPattern>(context);
}
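No further wiring is needed for the new pattern to run: MLIR's generic canonicalizer pass asks every registered op for its canonicalization patterns through this hook, which is what the --canonicalize flag of onnx-mlir-opt (and mlir-opt) triggers. A minimal driver sketch, assuming a loaded ModuleOp; the function name and includes reflect the MLIR layout of mid-2020 and are illustrative, not code from this repository:

#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Run the generic canonicalizer; it collects patterns from every op's
// getCanonicalizationPatterns hook, including the Cast one registered above.
mlir::LogicalResult runCanonicalization(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}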


@@ -27,6 +27,10 @@ include "src/Dialect/ONNX/ONNXOps.td"
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
def HasSameElementType : Constraint<
CPred<"($0.getType().dyn_cast<ShapedType>().getElementType() == "
"convertONNXTypeToMLIRType($_builder, static_cast<onnx::TensorProto_DataType>($1.cast<::mlir::IntegerAttr>().getInt())))">,
"has same element type">;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite
@@ -55,4 +59,12 @@ def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
[(HasOneUse $res)]>;
// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
def CastEliminationPattern : Pat<
(ONNXCastOp $arg, $type),
(replaceWithValue $arg),
[(HasSameElementType $arg, $type)]>;
#endif // ONNX_COMBINE
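For readers less used to DRR, a hand-written C++ equivalent of CastEliminationPattern and its HasSameElementType constraint is sketched below. The ODS accessor names (input(), toAttr()) and the return type of convertONNXTypeToMLIRType are assumptions inferred from the constraint above, not copied from the generated sources:

// Hand-written sketch of the DRR pattern above; accessor names are assumed.
struct CastEliminationPatternSketch
    : public mlir::OpRewritePattern<ONNXCastOp> {
  using OpRewritePattern<ONNXCastOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      ONNXCastOp castOp, mlir::PatternRewriter &rewriter) const override {
    mlir::Value input = castOp.input();
    auto shapedTy = input.getType().dyn_cast<mlir::ShapedType>();
    if (!shapedTy)
      return mlir::failure();
    // HasSameElementType: the input element type already matches the type
    // requested by the "to" attribute, so the cast is a no-op.
    mlir::Type target = convertONNXTypeToMLIRType(rewriter,
        static_cast<onnx::TensorProto_DataType>(castOp.toAttr().getInt()));
    if (shapedTy.getElementType() != target)
      return mlir::failure();
    // replaceWithValue $arg: forward the operand and erase the cast.
    rewriter.replaceOp(castOp, input);
    return mlir::success();
  }
};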


@@ -96,3 +96,13 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
// return [[GEMM]] : tensor<*xf32>
}
// -----
//CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> {
func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
// CHECK-NEXT: return %arg0 : tensor<2xf32>
}


@@ -256,7 +256,7 @@ OpsWithShapeInference = [
]
# Operations supporting canonicalization.
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast']
# Operations who have operands that, if produced by constant operations, should
# be promoted to become an attribute (via attribute promotion).
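Listing Cast in OpsWithCanonicalizer is what makes the op-definition generator emit the let hasCanonicalizer = 1; line shown in the first hunk; without this entry, the ONNXCastOp::getCanonicalizationPatterns definition added to ONNXCombine.cpp would have no declaration to pair with.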