Rewrite shape and size OP (#285)
* add shape inference * Revert "add shape inference" This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16. * add rewrite rules * test cases * format * add constraint * response to review * response to review
This commit is contained in:
		
							parent
							
								
									5e11429d77
								
							
						
					
					
						commit
						ac67900baf
					
				|  | @ -4729,6 +4729,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", | |||
| 
 | ||||
| def ONNXShapeOp:ONNX_Op<"Shape", | ||||
|   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||
|   let hasCanonicalizer = 1; | ||||
|   let summary = "ONNX Shape operation"; | ||||
|   let description = [{ | ||||
|   "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor." | ||||
|  | @ -4863,6 +4864,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh", | |||
| 
 | ||||
| def ONNXSizeOp:ONNX_Op<"Size", | ||||
|   [NoSideEffect]> { | ||||
|   let hasCanonicalizer = 1; | ||||
|   let summary = "ONNX Size operation"; | ||||
|   let description = [{ | ||||
|   "Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor." | ||||
|  |  | |||
|  | @ -27,6 +27,29 @@ DenseElementsAttr createDenseElementsAttrFromFloatAttr( | |||
|   return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)); | ||||
| } | ||||
| 
 | ||||
| // Create a DenseElementsAttr based on the shape of type.
 | ||||
| DenseElementsAttr createDenseElementsAttrFromShape( | ||||
|     PatternRewriter &rewriter, Value value) { | ||||
|   auto inType = value.getType().cast<ShapedType>(); | ||||
|   auto shape = inType.getShape(); | ||||
|   SmallVector<int64_t, 1> dims = {inType.getRank()}; | ||||
|   SmallVector<int64_t, 4> values(shape.begin(), shape.end()); | ||||
|   auto tensorType = | ||||
|       mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64)); | ||||
|   return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)); | ||||
| } | ||||
| 
 | ||||
| // Create a DenseElementsAttr based on the size of type.
 | ||||
| DenseElementsAttr createDenseElementsAttrFromSize( | ||||
|     PatternRewriter &rewriter, Value value) { | ||||
|   auto inType = value.getType().cast<ShapedType>(); | ||||
|   SmallVector<int64_t, 1> dims(1, 1); | ||||
|   SmallVector<int64_t, 1> values = {inType.getNumElements()}; | ||||
|   auto tensorType = | ||||
|       mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64)); | ||||
|   return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)); | ||||
| } | ||||
| 
 | ||||
| // If 'lhs' is not NoneType, return 'lhs - rhs'.
 | ||||
| // Otherwise, return '-rhs'.
 | ||||
| Value subtractOrNeg( | ||||
|  | @ -128,3 +151,15 @@ void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns( | |||
|     OwningRewritePatternList &results, MLIRContext *context) { | ||||
|   results.insert<FuseBatchNormTestModeConvPattern>(context); | ||||
| } | ||||
| 
 | ||||
| /// on the ONNXShapeOp.
 | ||||
| void ONNXShapeOp::getCanonicalizationPatterns( | ||||
|     OwningRewritePatternList &results, MLIRContext *context) { | ||||
|   results.insert<ShapeToConstantPattern>(context); | ||||
| } | ||||
| 
 | ||||
| /// on the ONNXSizeOp.
 | ||||
| void ONNXSizeOp::getCanonicalizationPatterns( | ||||
|     OwningRewritePatternList &results, MLIRContext *context) { | ||||
|   results.insert<SizeToConstantPattern>(context); | ||||
| } | ||||
|  |  | |||
|  | @ -28,6 +28,14 @@ include "src/Dialect/ONNX/ONNXOps.td" | |||
| def createDenseElementsAttrFromFloatAttr : NativeCodeCall< | ||||
|   "createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<ShapedType>().getElementType(), $1)">; | ||||
| 
 | ||||
| // Create a DenseElementsAttr from the shape of the type of a value. | ||||
| def createDenseElementsAttrFromShape : NativeCodeCall< | ||||
|   "createDenseElementsAttrFromShape($_builder, $0)">; | ||||
| 
 | ||||
| // Create a DenseElementsAttr from the size of the type of a value. | ||||
| def createDenseElementsAttrFromSize : NativeCodeCall< | ||||
|   "createDenseElementsAttrFromSize($_builder, $0)">; | ||||
| 
 | ||||
| // If '$1' is not NoneType, do subtraction '$1 - $2'. | ||||
| // Otherwise, take the negative of '$2'. | ||||
| def subtractOrNeg: NativeCodeCall< | ||||
|  | @ -172,4 +180,25 @@ def FuseBatchNormTestModeConvPattern: Pat< | |||
|      $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides) | ||||
| >; | ||||
| 
 | ||||
| def IsStaticShapeTensor: | ||||
|    Constraint< | ||||
|      CPred< | ||||
|        "$_self.getType().cast<::mlir::ShapedType>().hasStaticShape()">, | ||||
|      "hasStaticShape">; | ||||
| 
 | ||||
| def ShapeToConstantPattern: Pat< | ||||
|      (ONNXShapeOp $A), | ||||
|      (ONNXConstantOp | ||||
|         (GetNullAttr), | ||||
|         (createDenseElementsAttrFromShape $A)), | ||||
|      [(IsStaticShapeTensor:$A)] | ||||
| >; | ||||
| 
 | ||||
| def SizeToConstantPattern: Pat< | ||||
|      (ONNXSizeOp $A), | ||||
|      (ONNXConstantOp | ||||
|         (GetNullAttr), | ||||
|         (createDenseElementsAttrFromSize $A)), | ||||
|      [(IsStaticShapeTensor:$A)] | ||||
| >; | ||||
| #endif // ONNX_REWRITE | ||||
|  |  | |||
|  | @ -222,3 +222,49 @@ func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10 | |||
|   // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32> | ||||
|   "std.return"(%1) : (tensor<10x11x12x13xf32>) -> () | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_shape1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> { | ||||
|   %0 = "onnx.Shape"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64> | ||||
|   return %0 : tensor<*xi64> | ||||
| 
 | ||||
|   // CHECK-LABEL: @test_shape1 | ||||
|   // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[2, 4, 8, 16]> : tensor<4xi64>} : () -> tensor<*xi64> | ||||
|   // CHECK-NEXT: %0 : tensor<*xi64> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_shape2(%arg0 : tensor<?x4x8x16xf32>) -> tensor<*xi64> { | ||||
|   %0 = "onnx.Shape"(%arg0) : (tensor<?x4x8x16xf32>) -> tensor<*xi64> | ||||
|   return %0 : tensor<*xi64> | ||||
| 
 | ||||
|   // CHECK-LABEL: @test_shape2 | ||||
|   // CHECK-NEXT: %0 = "onnx.Shape"(%arg0) : (tensor<?x4x8x16xf32>) -> tensor<*xi64> | ||||
|   // CHECK-NEXT: return %0 : tensor<*xi64> | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_size1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> { | ||||
|   %0 = "onnx.Size"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64> | ||||
|   return %0 : tensor<*xi64> | ||||
| 
 | ||||
|   // CHECK-LABEL: @test_size1 | ||||
|   // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1024> : tensor<1xi64>} : () -> tensor<*xi64> | ||||
|   // CHECK-NEXT: %0 : tensor<*xi64> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_size2(%arg0 : tensor<*xf32>) -> tensor<*xi64> { | ||||
|   %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64> | ||||
|   return %0 : tensor<*xi64> | ||||
| 
 | ||||
|   // CHECK-LABEL: @test_size2 | ||||
|   // CHECK-NEXT: %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64> | ||||
|   // CHECK-NEXT: return %0 : tensor<*xi64> | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -321,7 +321,7 @@ OpsWithShapeInference=[ | |||
| ] | ||||
| 
 | ||||
| # Operations supporting canonicalization. | ||||
| OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout'] | ||||
| OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout', 'Shape', 'Size'] | ||||
| 
 | ||||
| # Operations who have operands that, if produced by constant operations, should | ||||
| # be promoted to become an attribute (via attribute promotion). | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue