Add Rewrite rule to eliminate CastOp when input element type is the same as expected output element type (#237)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * add benefit for scaler decompose and simplify scaler shape inference * cast rewrite only for float * add cast op same type rewrite rule * fix format Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
		
							parent
							
								
									32ceb6968a
								
							
						
					
					
						commit
						a611c145f4
					
				| 
						 | 
				
			
			@ -450,6 +450,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
 | 
			
		|||
 | 
			
		||||
def ONNXCastOp:ONNX_Op<"Cast",
 | 
			
		||||
  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"ResultTypeInferenceOpInterface">]> {
 | 
			
		||||
  let hasCanonicalizer = 1;
 | 
			
		||||
  let summary = "ONNX Cast operation";
 | 
			
		||||
  let description = [{
 | 
			
		||||
  "The operator casts the elements of a given input tensor to a data type"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,6 +18,7 @@
 | 
			
		|||
using namespace mlir;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
/// Include the patterns defined in the Declarative Rewrite framework.
 | 
			
		||||
#include "src/Transform/ONNX/ONNXCombine.inc"
 | 
			
		||||
} // end anonymous namespace
 | 
			
		||||
| 
						 | 
				
			
			@ -44,3 +45,9 @@ void ONNXPadConstantValueOp::getCanonicalizationPatterns(
 | 
			
		|||
    OwningRewritePatternList &result, MLIRContext *context) {
 | 
			
		||||
  result.insert<ConstantPadPattern>(context);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// on the ONNXCastOp.
 | 
			
		||||
void ONNXCastOp::getCanonicalizationPatterns(
 | 
			
		||||
    OwningRewritePatternList &result, MLIRContext *context) {
 | 
			
		||||
  result.insert<CastEliminationPattern>(context);
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,6 +27,10 @@ include "src/Dialect/ONNX/ONNXOps.td"
 | 
			
		|||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 | 
			
		||||
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
 | 
			
		||||
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
 | 
			
		||||
def HasSameElementType : Constraint<
 | 
			
		||||
    CPred<"($0.getType().dyn_cast<ShapedType>().getElementType() == "
 | 
			
		||||
          "convertONNXTypeToMLIRType($_builder, static_cast<onnx::TensorProto_DataType>($1.cast<::mlir::IntegerAttr>().getInt())))">,
 | 
			
		||||
    "has same element type">;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Pattern-Match and Rewrite
 | 
			
		||||
| 
						 | 
				
			
			@ -55,4 +59,12 @@ def ConstantPadPattern : Pat<(ONNXPadConstantValueOp $m1, (ONNXConstantOp:$res $
 | 
			
		|||
                             (ONNXPadConstantValuePadOp $m1, $v2, $m2, $m3),
 | 
			
		||||
                             [(HasOneUse $res)]>;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// ONNX_Op (onnx.Cast (%X, $type)) = ONNX_Op (%X)
 | 
			
		||||
def CastEliminationPattern : Pat<
 | 
			
		||||
	(ONNXCastOp $arg, $type),
 | 
			
		||||
	(replaceWithValue $arg),
 | 
			
		||||
  [(HasSameElementType $arg, $type)]>;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#endif // ONNX_COMBINE
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -96,3 +96,13 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
 | 
			
		|||
  // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
 | 
			
		||||
  // return [[GEMM]] : tensor<*xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
//CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> {
 | 
			
		||||
func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 | 
			
		||||
  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2xf32>) -> tensor<2xf32>
 | 
			
		||||
  return %0 : tensor<2xf32>
 | 
			
		||||
 | 
			
		||||
  // CHECK-NEXT: return %arg0 : tensor<2xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -256,7 +256,7 @@ OpsWithShapeInference = [
 | 
			
		|||
]
 | 
			
		||||
 | 
			
		||||
# Operations supporting canonicalization.
 | 
			
		||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
 | 
			
		||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast']
 | 
			
		||||
 | 
			
		||||
# Operations who have operands that, if produced by constant operations, should
 | 
			
		||||
# be promoted to become an attribute (via attribute promotion).
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue