Add Rewrite rule to eliminate CastOp when input element type is the same as expected output element type (#237)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * add benefit for scaler decompose and simplify scaler shape inference * cast rewrite only for float * add cast op same type rewrite rule * fix format Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
parent
32ceb6968a
commit
a611c145f4
|
@ -450,6 +450,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
|
||||||
|
|
||||||
def ONNXCastOp:ONNX_Op<"Cast",
|
def ONNXCastOp:ONNX_Op<"Cast",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX Cast operation";
|
let summary = "ONNX Cast operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The operator casts the elements of a given input tensor to a data type"
|
"The operator casts the elements of a given input tensor to a data type"
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
#include "src/Transform/ONNX/ONNXCombine.inc"
|
#include "src/Transform/ONNX/ONNXCombine.inc"
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
@ -44,3 +45,9 @@ void ONNXPadConstantValueOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &result, MLIRContext *context) {
|
OwningRewritePatternList &result, MLIRContext *context) {
|
||||||
result.insert<ConstantPadPattern>(context);
|
result.insert<ConstantPadPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// on the ONNXCastOp.
|
||||||
|
void ONNXCastOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &result, MLIRContext *context) {
|
||||||
|
result.insert<CastEliminationPattern>(context);
|
||||||
|
}
|
||||||
|
|
|
@ -27,6 +27,10 @@ include "src/Dialect/ONNX/ONNXOps.td"
|
||||||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||||
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
|
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
|
||||||
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
|
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
|
||||||
|
def HasSameElementType : Constraint<
|
||||||
|
CPred<"($0.getType().dyn_cast<ShapedType>().getElementType() == "
|
||||||
|
"convertONNXTypeToMLIRType($_builder, static_cast<onnx::TensorProto_DataType>($1.cast<::mlir::IntegerAttr>().getInt())))">,
|
||||||
|
"has same element type">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pattern-Match and Rewrite
|
// Pattern-Match and Rewrite
|
||||||
|
@ -55,4 +59,12 @@ def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $
|
||||||
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
|
(ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
|
||||||
[(HasOneUse $res)]>;
|
[(HasOneUse $res)]>;
|
||||||
|
|
||||||
|
|
||||||
|
// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
|
||||||
|
def CastEliminationPattern : Pat<
|
||||||
|
(ONNXCastOp $arg, $type),
|
||||||
|
(replaceWithValue $arg),
|
||||||
|
[(HasSameElementType $arg, $type)]>;
|
||||||
|
|
||||||
|
|
||||||
#endif // ONNX_COMBINE
|
#endif // ONNX_COMBINE
|
||||||
|
|
|
@ -96,3 +96,13 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
|
||||||
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
|
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
|
||||||
// return [[GEMM]] : tensor<*xf32>
|
// return [[GEMM]] : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
//CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -256,7 +256,7 @@ OpsWithShapeInference = [
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
|
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast']
|
||||||
|
|
||||||
# Operations who have operands that, if produced by constant operations, should
|
# Operations who have operands that, if produced by constant operations, should
|
||||||
# be promoted to become an attribute (via attribute promotion).
|
# be promoted to become an attribute (via attribute promotion).
|
||||||
|
|
Loading…
Reference in New Issue