Update GetDimensionSize and SetDimensionSize ops to use I64 attribute for dimension
This is to match with HLO semantics and general dimension semantics in MLIR. Also, * Define minimal verifier for these ops. * Add folder for SetDimensionSize op on static shaped dimension. * Fix assumption of ranked shape in GetDimensionSize op. PiperOrigin-RevId: 341150923
This commit is contained in:
		
							parent
							
								
									3238f8226f
								
							
						
					
					
						commit
						4ef12aa000
					
				| 
						 | 
					@ -1017,7 +1017,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
 | 
				
			||||||
      BASE_HLO_GetDimensionSizeOp {
 | 
					      BASE_HLO_GetDimensionSizeOp {
 | 
				
			||||||
  let arguments = (ins
 | 
					  let arguments = (ins
 | 
				
			||||||
    HLO_Tensor:$operand,
 | 
					    HLO_Tensor:$operand,
 | 
				
			||||||
    I32Attr:$dimension
 | 
					    I64Attr:$dimension
 | 
				
			||||||
  );
 | 
					  );
 | 
				
			||||||
  // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
 | 
					  // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
 | 
				
			||||||
  // XLA semantics is available. This limitation is because of the current XLA
 | 
					  // XLA semantics is available. This limitation is because of the current XLA
 | 
				
			||||||
| 
						 | 
					@ -1129,9 +1129,11 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
 | 
				
			||||||
  let arguments = (ins
 | 
					  let arguments = (ins
 | 
				
			||||||
    HLO_Tensor:$operand,
 | 
					    HLO_Tensor:$operand,
 | 
				
			||||||
    I32Tensor:$size,
 | 
					    I32Tensor:$size,
 | 
				
			||||||
    I32Attr:$dimension
 | 
					    I64Attr:$dimension
 | 
				
			||||||
  );
 | 
					  );
 | 
				
			||||||
  let results = (outs HLO_Tensor);
 | 
					  let results = (outs HLO_Tensor);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let hasFolder = 1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
 | 
					def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -86,6 +86,26 @@ namespace {
 | 
				
			||||||
// Utilities for the canonicalize patterns
 | 
					// Utilities for the canonicalize patterns
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Verifies that dimension attribute for the op correctly indexes in operand or
 | 
				
			||||||
 | 
					// result shape.
 | 
				
			||||||
 | 
					template <typename OpT>
 | 
				
			||||||
 | 
					static LogicalResult VerifyDimAttr(OpT op) {
 | 
				
			||||||
 | 
					  int64_t rank = -1;
 | 
				
			||||||
 | 
					  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
 | 
				
			||||||
 | 
					    rank = ty.getRank();
 | 
				
			||||||
 | 
					  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
 | 
				
			||||||
 | 
					    rank = ty.getRank();
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int64_t dim = op.dimension();
 | 
				
			||||||
 | 
					  if (dim < 0 || dim >= rank)
 | 
				
			||||||
 | 
					    return op.emitOpError() << "requires dimension attribute in range [0, "
 | 
				
			||||||
 | 
					                            << rank << "); found (" << dim << ")";
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Returns 1D 64-bit dense elements attribute with the given values.
 | 
					// Returns 1D 64-bit dense elements attribute with the given values.
 | 
				
			||||||
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
 | 
					DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
 | 
				
			||||||
                                        Builder* builder) {
 | 
					                                        Builder* builder) {
 | 
				
			||||||
| 
						 | 
					@ -245,10 +265,14 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// GetDimensionSizeOp
 | 
					// GetDimensionSizeOp
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// Fold get_dimension_size when the said shape dimension is a constant.
 | 
					/// Fold get_dimension_size when the said shape dimension is a constant.
 | 
				
			||||||
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
 | 
					OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
 | 
				
			||||||
  RankedTensorType type = operand().getType().cast<RankedTensorType>();
 | 
					  RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
 | 
					  if (!type) return {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  int32_t dim = dimension();
 | 
					  int32_t dim = dimension();
 | 
				
			||||||
  if (type.isDynamic(dim)) return {};
 | 
					  if (type.isDynamic(dim)) return {};
 | 
				
			||||||
  // The result type is always is a 0-d i32 tensor.
 | 
					  // The result type is always is a 0-d i32 tensor.
 | 
				
			||||||
| 
						 | 
					@ -1724,6 +1748,35 @@ LogicalResult SelectOp::reifyReturnTypeShapes(
 | 
				
			||||||
                                     &reifiedReturnShapes);
 | 
					                                     &reifiedReturnShapes);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// SetDimensionSizeOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static LogicalResult Verify(SetDimensionSizeOp op) {
 | 
				
			||||||
 | 
					  if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
 | 
				
			||||||
 | 
					    if (size.getRank() != 0)
 | 
				
			||||||
 | 
					      return op.emitOpError() << "size operand should be of rank-0";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return VerifyDimAttr(op);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
 | 
				
			||||||
 | 
					  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
 | 
				
			||||||
 | 
					  if (input) return input;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
 | 
				
			||||||
 | 
					  if (!size || !size.isSplat()) return {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto ty = getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
 | 
					  if (!ty) return {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int64_t dim_size = ty.getDimSize(dimension());
 | 
				
			||||||
 | 
					  if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
 | 
				
			||||||
 | 
					    return operand();
 | 
				
			||||||
 | 
					  return {};
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// PadOp
 | 
					// PadOp
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -909,13 +909,23 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: func @fold_get_dimension_size
 | 
					// CHECK-LABEL: func @fold_get_dimension_size
 | 
				
			||||||
func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
 | 
					func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
 | 
				
			||||||
  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32>
 | 
					  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
 | 
				
			||||||
  return %size : tensor<i32>
 | 
					  return %size : tensor<i32>
 | 
				
			||||||
  // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
 | 
					  // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
 | 
				
			||||||
  // CHECK-NEXT: return %[[C]]
 | 
					  // CHECK-NEXT: return %[[C]]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @fold_set_dimension_size
 | 
				
			||||||
 | 
					// CHECK-SAME: (%[[I:.*]]: tensor<1x128x512xf32>)
 | 
				
			||||||
 | 
					func @fold_set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
 | 
				
			||||||
 | 
					  %dim = mhlo.constant dense<512> : tensor<i32>
 | 
				
			||||||
 | 
					  %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
 | 
				
			||||||
 | 
					  return %result : tensor<1x128x512xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-NEXT: return %[[I]]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: func @fold_select_same
 | 
					// CHECK-LABEL: func @fold_select_same
 | 
				
			||||||
func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
 | 
					func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
 | 
				
			||||||
  %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
 | 
					  %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1246,3 +1246,38 @@ func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
 | 
				
			||||||
  %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
 | 
					  %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
 | 
				
			||||||
  return %0 : tensor<2x4xf32>
 | 
					  return %0 : tensor<2x4xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
 | 
				
			||||||
 | 
					  // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
 | 
				
			||||||
 | 
					  %size = "mhlo.get_dimension_size"(%I) {dimension = 3 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
 | 
				
			||||||
 | 
					  return %size : tensor<i32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
 | 
				
			||||||
 | 
					  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
 | 
				
			||||||
 | 
					  return %size : tensor<i32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
 | 
				
			||||||
 | 
					  %dim = mhlo.constant dense<512> : tensor<1xi32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // expected-error@+1 {{size operand should be of rank-0}}
 | 
				
			||||||
 | 
					  %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<1xi32>) -> tensor<1x128x512xf32>
 | 
				
			||||||
 | 
					  return %result : tensor<1x128x512xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
 | 
				
			||||||
 | 
					  %dim = mhlo.constant dense<512> : tensor<i32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
 | 
				
			||||||
 | 
					  %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
 | 
				
			||||||
 | 
					  return %result : tensor<1x128x512xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue