Update GetDimensionSize and SetDimensionSize ops to use I64 attribute for dimension
This is to match with HLO semantics and general dimension semantics in MLIR. Also, * Define minimal verifier for these ops. * Add folder for SetDimensionSize op on static shaped dimension. * Fix assumption of ranked shape in GetDimensionSize op. PiperOrigin-RevId: 341150923
This commit is contained in:
		
							parent
							
								
									3238f8226f
								
							
						
					
					
						commit
						4ef12aa000
					
				|  | @ -1017,7 +1017,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, | |||
|       BASE_HLO_GetDimensionSizeOp { | ||||
|   let arguments = (ins | ||||
|     HLO_Tensor:$operand, | ||||
|     I32Attr:$dimension | ||||
|     I64Attr:$dimension | ||||
|   ); | ||||
|   // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the | ||||
|   // XLA semantics is available. This limitation is because of the current XLA | ||||
|  | @ -1129,9 +1129,11 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, | |||
|   let arguments = (ins | ||||
|     HLO_Tensor:$operand, | ||||
|     I32Tensor:$size, | ||||
|     I32Attr:$dimension | ||||
|     I64Attr:$dimension | ||||
|   ); | ||||
|   let results = (outs HLO_Tensor); | ||||
| 
 | ||||
|   let hasFolder = 1; | ||||
| } | ||||
| 
 | ||||
| def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp { | ||||
|  |  | |||
|  | @ -86,6 +86,26 @@ namespace { | |||
| // Utilities for the canonicalize patterns
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| // Verifies that dimension attribute for the op correctly indexes in operand or
 | ||||
| // result shape.
 | ||||
| template <typename OpT> | ||||
| static LogicalResult VerifyDimAttr(OpT op) { | ||||
|   int64_t rank = -1; | ||||
|   if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) { | ||||
|     rank = ty.getRank(); | ||||
|   } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) { | ||||
|     rank = ty.getRank(); | ||||
|   } else { | ||||
|     return success(); | ||||
|   } | ||||
| 
 | ||||
|   int64_t dim = op.dimension(); | ||||
|   if (dim < 0 || dim >= rank) | ||||
|     return op.emitOpError() << "requires dimension attribute in range [0, " | ||||
|                             << rank << "); found (" << dim << ")"; | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| // Returns 1D 64-bit dense elements attribute with the given values.
 | ||||
| DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values, | ||||
|                                         Builder* builder) { | ||||
|  | @ -245,10 +265,14 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results, | |||
| //===----------------------------------------------------------------------===//
 | ||||
| // GetDimensionSizeOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| //
 | ||||
| static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); } | ||||
| 
 | ||||
| /// Fold get_dimension_size when the said shape dimension is a constant.
 | ||||
| OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) { | ||||
|   RankedTensorType type = operand().getType().cast<RankedTensorType>(); | ||||
|   RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>(); | ||||
|   if (!type) return {}; | ||||
| 
 | ||||
|   int32_t dim = dimension(); | ||||
|   if (type.isDynamic(dim)) return {}; | ||||
|   // The result type is always is a 0-d i32 tensor.
 | ||||
|  | @ -1724,6 +1748,35 @@ LogicalResult SelectOp::reifyReturnTypeShapes( | |||
|                                      &reifiedReturnShapes); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // SetDimensionSizeOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| static LogicalResult Verify(SetDimensionSizeOp op) { | ||||
|   if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) { | ||||
|     if (size.getRank() != 0) | ||||
|       return op.emitOpError() << "size operand should be of rank-0"; | ||||
|   } | ||||
| 
 | ||||
|   return VerifyDimAttr(op); | ||||
| } | ||||
| 
 | ||||
| OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) { | ||||
|   DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>(); | ||||
|   if (input) return input; | ||||
| 
 | ||||
|   DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>(); | ||||
|   if (!size || !size.isSplat()) return {}; | ||||
| 
 | ||||
|   auto ty = getType().dyn_cast<RankedTensorType>(); | ||||
|   if (!ty) return {}; | ||||
| 
 | ||||
|   int64_t dim_size = ty.getDimSize(dimension()); | ||||
|   if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt()) | ||||
|     return operand(); | ||||
|   return {}; | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // PadOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  |  | |||
|  | @ -909,13 +909,23 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf | |||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @fold_get_dimension_size | ||||
| func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> { | ||||
|   %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32> | ||||
| func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> { | ||||
|   %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32> | ||||
|   return %size : tensor<i32> | ||||
|   // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32> | ||||
|   // CHECK-NEXT: return %[[C]] | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @fold_set_dimension_size | ||||
| // CHECK-SAME: (%[[I:.*]]: tensor<1x128x512xf32>) | ||||
| func @fold_set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> { | ||||
|   %dim = mhlo.constant dense<512> : tensor<i32> | ||||
|   %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32> | ||||
|   return %result : tensor<1x128x512xf32> | ||||
| 
 | ||||
|   // CHECK-NEXT: return %[[I]] | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @fold_select_same | ||||
| func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> { | ||||
|   %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> | ||||
|  |  | |||
|  | @ -1246,3 +1246,38 @@ func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { | |||
|   %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32> | ||||
|   return %0 : tensor<2x4xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> { | ||||
|   // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}} | ||||
|   %size = "mhlo.get_dimension_size"(%I) {dimension = 3 : i64} : (tensor<1x128x512xf32>) -> tensor<i32> | ||||
|   return %size : tensor<i32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> { | ||||
|   %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32> | ||||
|   return %size : tensor<i32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> { | ||||
|   %dim = mhlo.constant dense<512> : tensor<1xi32> | ||||
| 
 | ||||
|   // expected-error@+1 {{size operand should be of rank-0}} | ||||
|   %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<1xi32>) -> tensor<1x128x512xf32> | ||||
|   return %result : tensor<1x128x512xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> { | ||||
|   %dim = mhlo.constant dense<512> : tensor<i32> | ||||
| 
 | ||||
|   // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}} | ||||
|   %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32> | ||||
|   return %result : tensor<1x128x512xf32> | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue