Add Rewrite rule to eliminate CastOp when input element type is the same as expected output element type (#237)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * add benefit for scaler decompose and simplify scaler shape inference * cast rewrite only for float * add cast op same type rewrite rule * fix format Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
parent
32ceb6968a
commit
a611c145f4
|
@ -450,6 +450,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
|
|||
|
||||
def ONNXCastOp:ONNX_Op<"Cast",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX Cast operation";
|
||||
let description = [{
|
||||
"The operator casts the elements of a given input tensor to a data type"
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "src/Transform/ONNX/ONNXCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
@ -44,3 +45,9 @@ void ONNXPadConstantValueOp::getCanonicalizationPatterns(
|
|||
OwningRewritePatternList &result, MLIRContext *context) {
|
||||
result.insert<ConstantPadPattern>(context);
|
||||
}
|
||||
|
||||
/// on the ONNXCastOp.
|
||||
void ONNXCastOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &result, MLIRContext *context) {
|
||||
result.insert<CastEliminationPattern>(context);
|
||||
}
|
||||
|
|
|
@ -27,6 +27,10 @@ include "src/Dialect/ONNX/ONNXOps.td"
|
|||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
|
||||
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
|
||||
def HasSameElementType : Constraint<
|
||||
CPred<"($0.getType().dyn_cast<ShapedType>().getElementType() == "
|
||||
"convertONNXTypeToMLIRType($_builder, static_cast<onnx::TensorProto_DataType>($1.cast<::mlir::IntegerAttr>().getInt())))">,
|
||||
"has same element type">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-Match and Rewrite
|
||||
|
@ -55,4 +59,12 @@ def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $
|
|||
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
|
||||
[(HasOneUse $res)]>;
|
||||
|
||||
|
||||
// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
|
||||
def CastEliminationPattern : Pat<
|
||||
(ONNXCastOp $arg, $type),
|
||||
(replaceWithValue $arg),
|
||||
[(HasSameElementType $arg, $type)]>;
|
||||
|
||||
|
||||
#endif // ONNX_COMBINE
|
||||
|
|
|
@ -96,3 +96,13 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
|
|||
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
|
||||
// return [[GEMM]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> {
|
||||
func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
|
||||
// CHECK-NEXT: return %arg0 : tensor<2xf32>
|
||||
}
|
||||
|
|
|
@ -256,7 +256,7 @@ OpsWithShapeInference = [
|
|||
]
|
||||
|
||||
# Operations supporting canonicalization.
|
||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
|
||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast']
|
||||
|
||||
# Operations who have operands that, if produced by constant operations, should
|
||||
# be promoted to become an attribute (via attribute promotion).
|
||||
|
|
Loading…
Reference in New Issue