Rewrite shape and size OP (#285)
* add shape inference * Revert "add shape inference" This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16. * add rewrite rules * test cases * format * add constraint * response to review * response to review
This commit is contained in:
		
							parent
							
								
									5e11429d77
								
							
						
					
					
						commit
						ac67900baf
					
				| 
						 | 
					@ -4729,6 +4729,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXShapeOp:ONNX_Op<"Shape",
 | 
					def ONNXShapeOp:ONNX_Op<"Shape",
 | 
				
			||||||
  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
					  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
				
			||||||
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
  let summary = "ONNX Shape operation";
 | 
					  let summary = "ONNX Shape operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
  "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
 | 
					  "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
 | 
				
			||||||
| 
						 | 
					@ -4863,6 +4864,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXSizeOp:ONNX_Op<"Size",
 | 
					def ONNXSizeOp:ONNX_Op<"Size",
 | 
				
			||||||
  [NoSideEffect]> {
 | 
					  [NoSideEffect]> {
 | 
				
			||||||
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
  let summary = "ONNX Size operation";
 | 
					  let summary = "ONNX Size operation";
 | 
				
			||||||
  let description = [{
 | 
					  let description = [{
 | 
				
			||||||
  "Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor."
 | 
					  "Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor."
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,6 +27,29 @@ DenseElementsAttr createDenseElementsAttrFromFloatAttr(
 | 
				
			||||||
  return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
 | 
					  return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a DenseElementsAttr based on the shape of type.
 | 
				
			||||||
 | 
					DenseElementsAttr createDenseElementsAttrFromShape(
 | 
				
			||||||
 | 
					    PatternRewriter &rewriter, Value value) {
 | 
				
			||||||
 | 
					  auto inType = value.getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					  auto shape = inType.getShape();
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 1> dims = {inType.getRank()};
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 4> values(shape.begin(), shape.end());
 | 
				
			||||||
 | 
					  auto tensorType =
 | 
				
			||||||
 | 
					      mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64));
 | 
				
			||||||
 | 
					  return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a DenseElementsAttr based on the size of type.
 | 
				
			||||||
 | 
					DenseElementsAttr createDenseElementsAttrFromSize(
 | 
				
			||||||
 | 
					    PatternRewriter &rewriter, Value value) {
 | 
				
			||||||
 | 
					  auto inType = value.getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 1> dims(1, 1);
 | 
				
			||||||
 | 
					  SmallVector<int64_t, 1> values = {inType.getNumElements()};
 | 
				
			||||||
 | 
					  auto tensorType =
 | 
				
			||||||
 | 
					      mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64));
 | 
				
			||||||
 | 
					  return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// If 'lhs' is not NoneType, return 'lhs - rhs'.
 | 
					// If 'lhs' is not NoneType, return 'lhs - rhs'.
 | 
				
			||||||
// Otherwise, return '-rhs'.
 | 
					// Otherwise, return '-rhs'.
 | 
				
			||||||
Value subtractOrNeg(
 | 
					Value subtractOrNeg(
 | 
				
			||||||
| 
						 | 
					@ -128,3 +151,15 @@ void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns(
 | 
				
			||||||
    OwningRewritePatternList &results, MLIRContext *context) {
 | 
					    OwningRewritePatternList &results, MLIRContext *context) {
 | 
				
			||||||
  results.insert<FuseBatchNormTestModeConvPattern>(context);
 | 
					  results.insert<FuseBatchNormTestModeConvPattern>(context);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// on the ONNXShapeOp.
 | 
				
			||||||
 | 
					void ONNXShapeOp::getCanonicalizationPatterns(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &results, MLIRContext *context) {
 | 
				
			||||||
 | 
					  results.insert<ShapeToConstantPattern>(context);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// on the ONNXSizeOp.
 | 
				
			||||||
 | 
					void ONNXSizeOp::getCanonicalizationPatterns(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &results, MLIRContext *context) {
 | 
				
			||||||
 | 
					  results.insert<SizeToConstantPattern>(context);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -28,6 +28,14 @@ include "src/Dialect/ONNX/ONNXOps.td"
 | 
				
			||||||
def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
 | 
					def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
 | 
				
			||||||
  "createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<ShapedType>().getElementType(), $1)">;
 | 
					  "createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<ShapedType>().getElementType(), $1)">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a DenseElementsAttr from the shape of the type of a value.
 | 
				
			||||||
 | 
					def createDenseElementsAttrFromShape : NativeCodeCall<
 | 
				
			||||||
 | 
					  "createDenseElementsAttrFromShape($_builder, $0)">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a DenseElementsAttr from the size of the type of a value.
 | 
				
			||||||
 | 
					def createDenseElementsAttrFromSize : NativeCodeCall<
 | 
				
			||||||
 | 
					  "createDenseElementsAttrFromSize($_builder, $0)">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// If '$1' is not NoneType, do subtraction '$1 - $2'.
 | 
					// If '$1' is not NoneType, do subtraction '$1 - $2'.
 | 
				
			||||||
// Otherwise, take the negative of '$2'.
 | 
					// Otherwise, take the negative of '$2'.
 | 
				
			||||||
def subtractOrNeg: NativeCodeCall<
 | 
					def subtractOrNeg: NativeCodeCall<
 | 
				
			||||||
| 
						 | 
					@ -172,4 +180,25 @@ def FuseBatchNormTestModeConvPattern: Pat<
 | 
				
			||||||
     $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides)
 | 
					     $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides)
 | 
				
			||||||
>;
 | 
					>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def IsStaticShapeTensor:
 | 
				
			||||||
 | 
					   Constraint<
 | 
				
			||||||
 | 
					     CPred<
 | 
				
			||||||
 | 
					       "$_self.getType().cast<::mlir::ShapedType>().hasStaticShape()">,
 | 
				
			||||||
 | 
					     "hasStaticShape">;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ShapeToConstantPattern: Pat<
 | 
				
			||||||
 | 
					     (ONNXShapeOp $A),
 | 
				
			||||||
 | 
					     (ONNXConstantOp
 | 
				
			||||||
 | 
					        (GetNullAttr),
 | 
				
			||||||
 | 
					        (createDenseElementsAttrFromShape $A)),
 | 
				
			||||||
 | 
					     [(IsStaticShapeTensor:$A)]
 | 
				
			||||||
 | 
					>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def SizeToConstantPattern: Pat<
 | 
				
			||||||
 | 
					     (ONNXSizeOp $A),
 | 
				
			||||||
 | 
					     (ONNXConstantOp
 | 
				
			||||||
 | 
					        (GetNullAttr),
 | 
				
			||||||
 | 
					        (createDenseElementsAttrFromSize $A)),
 | 
				
			||||||
 | 
					     [(IsStaticShapeTensor:$A)]
 | 
				
			||||||
 | 
					>;
 | 
				
			||||||
#endif // ONNX_REWRITE
 | 
					#endif // ONNX_REWRITE
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -222,3 +222,49 @@ func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10
 | 
				
			||||||
  // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32>
 | 
					  // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32>
 | 
				
			||||||
  "std.return"(%1) : (tensor<10x11x12x13xf32>) -> ()
 | 
					  "std.return"(%1) : (tensor<10x11x12x13xf32>) -> ()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_shape1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Shape"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64>
 | 
				
			||||||
 | 
					  return %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: @test_shape1
 | 
				
			||||||
 | 
					  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[2, 4, 8, 16]> : tensor<4xi64>} : () -> tensor<*xi64>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_shape2(%arg0 : tensor<?x4x8x16xf32>) -> tensor<*xi64> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Shape"(%arg0) : (tensor<?x4x8x16xf32>) -> tensor<*xi64>
 | 
				
			||||||
 | 
					  return %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: @test_shape2
 | 
				
			||||||
 | 
					  // CHECK-NEXT: %0 = "onnx.Shape"(%arg0) : (tensor<?x4x8x16xf32>) -> tensor<*xi64>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: return %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_size1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Size"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64>
 | 
				
			||||||
 | 
					  return %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: @test_size1
 | 
				
			||||||
 | 
					  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1024> : tensor<1xi64>} : () -> tensor<*xi64>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_size2(%arg0 : tensor<*xf32>) -> tensor<*xi64> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
 | 
				
			||||||
 | 
					  return %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: @test_size2
 | 
				
			||||||
 | 
					  // CHECK-NEXT: %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: return %0 : tensor<*xi64>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -321,7 +321,7 @@ OpsWithShapeInference=[
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Operations supporting canonicalization.
 | 
					# Operations supporting canonicalization.
 | 
				
			||||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout']
 | 
					OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout', 'Shape', 'Size']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Operations who have operands that, if produced by constant operations, should
 | 
					# Operations who have operands that, if produced by constant operations, should
 | 
				
			||||||
# be promoted to become an attribute (via attribute promotion).
 | 
					# be promoted to become an attribute (via attribute promotion).
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue