From a611c145f4080b3d915151eaca0a7eebf0968f66 Mon Sep 17 00:00:00 2001 From: Anh Leu Date: Mon, 27 Jul 2020 11:49:14 -0500 Subject: [PATCH] Add Rewrite rule to eliminate CastOp when input element type is the same as expected output element type (#237) * move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * add benefit for scaler decompose and simplify scaler shape inference * cast rewrite only for float * add cast op same type rewrite rule * fix format Co-authored-by: chentong319 --- src/Dialect/ONNX/ONNXOps.td.inc | 1 + src/Transform/ONNX/Combine.cpp | 7 +++++++ src/Transform/ONNX/Combine.td | 12 ++++++++++++ test/mlir/onnx/onnx_canonicalization.mlir | 10 ++++++++++ utils/gen_onnx_mlir.py | 2 +- 5 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index e234b9a..43a02f5 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -450,6 +450,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift", def ONNXCastOp:ONNX_Op<"Cast", [NoSideEffect, DeclareOpInterfaceMethods, OpInterface<"ResultTypeInferenceOpInterface">]> { + let hasCanonicalizer = 1; let summary = "ONNX Cast operation"; let description = [{ "The operator casts the elements of a given input tensor to a data type" diff --git a/src/Transform/ONNX/Combine.cpp b/src/Transform/ONNX/Combine.cpp index 60696e0..08f8abb 100644 --- a/src/Transform/ONNX/Combine.cpp +++ b/src/Transform/ONNX/Combine.cpp @@ -18,6 +18,7 @@ using namespace mlir; namespace { + /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Transform/ONNX/ONNXCombine.inc" } // end anonymous namespace @@ -44,3 +45,9 @@ void ONNXPadConstantValueOp::getCanonicalizationPatterns( OwningRewritePatternList &result, MLIRContext *context) { result.insert(context); } + +/// on the ONNXCastOp. +void ONNXCastOp::getCanonicalizationPatterns( + OwningRewritePatternList &result, MLIRContext *context) { + result.insert(context); +} diff --git a/src/Transform/ONNX/Combine.td b/src/Transform/ONNX/Combine.td index 610257c..04067e0 100644 --- a/src/Transform/ONNX/Combine.td +++ b/src/Transform/ONNX/Combine.td @@ -27,6 +27,10 @@ include "src/Dialect/ONNX/ONNXOps.td" def HasOneUse : Constraint>; class HasRankOf : Constraint() && $0.getType().cast().getRank() == " # rank>>; def HasNoneType : Constraint()">>; +def HasSameElementType : Constraint< + CPred<"($0.getType().dyn_cast().getElementType() == " + "convertONNXTypeToMLIRType($_builder, static_cast($1.cast<::mlir::IntegerAttr>().getInt())))">, + "has same element type">; //===----------------------------------------------------------------------===// // Pattern-Match and Rewrite @@ -55,4 +59,12 @@ def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $ (ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3), [(HasOneUse $res)]>; + +// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X) +def CastEliminationPattern : Pat< + (ONNXCastOp $arg, $type), + (replaceWithValue $arg), + [(HasSameElementType $arg, $type)]>; + + #endif // ONNX_COMBINE diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index 3d37f5d..954fe75 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -96,3 +96,13 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1 // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32> // return [[GEMM]] : tensor<*xf32> } + +// ----- + +//CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> { +func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> + + // CHECK-NEXT: return %arg0 : tensor<2xf32> +} diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index d2d4360..74b232c 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -256,7 +256,7 @@ OpsWithShapeInference = [ ] # Operations supporting canonicalization. -OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv'] +OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast'] # Operations who have operands that, if produced by constant operations, should # be promoted to become an attribute (via attribute promotion).