Update GetDimensionSize and SetDimensionSize ops to use I64 attribute for dimension
This is to match with HLO semantics and general dimension semantics in MLIR. Also, * Define minimal verifier for these ops. * Add folder for SetDimensionSize op on static shaped dimension. * Fix assumption of ranked shape in GetDimensionSize op. PiperOrigin-RevId: 341150923
This commit is contained in:
parent
3238f8226f
commit
4ef12aa000
|
@ -1017,7 +1017,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
|
|||
BASE_HLO_GetDimensionSizeOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
I32Attr:$dimension
|
||||
I64Attr:$dimension
|
||||
);
|
||||
// TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
|
||||
// XLA semantics is available. This limitation is because of the current XLA
|
||||
|
@ -1129,9 +1129,11 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
|
|||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
I32Tensor:$size,
|
||||
I32Attr:$dimension
|
||||
I64Attr:$dimension
|
||||
);
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
|
||||
|
|
|
@ -86,6 +86,26 @@ namespace {
|
|||
// Utilities for the canonicalize patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies that dimension attribute for the op correctly indexes in operand or
|
||||
// result shape.
|
||||
template <typename OpT>
|
||||
static LogicalResult VerifyDimAttr(OpT op) {
|
||||
int64_t rank = -1;
|
||||
if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
|
||||
rank = ty.getRank();
|
||||
} else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
|
||||
rank = ty.getRank();
|
||||
} else {
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t dim = op.dimension();
|
||||
if (dim < 0 || dim >= rank)
|
||||
return op.emitOpError() << "requires dimension attribute in range [0, "
|
||||
<< rank << "); found (" << dim << ")";
|
||||
return success();
|
||||
}
|
||||
|
||||
// Returns 1D 64-bit dense elements attribute with the given values.
|
||||
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
|
||||
Builder* builder) {
|
||||
|
@ -245,10 +265,14 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
// GetDimensionSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }
|
||||
|
||||
/// Fold get_dimension_size when the said shape dimension is a constant.
|
||||
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
|
||||
RankedTensorType type = operand().getType().cast<RankedTensorType>();
|
||||
RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
|
||||
if (!type) return {};
|
||||
|
||||
int32_t dim = dimension();
|
||||
if (type.isDynamic(dim)) return {};
|
||||
// The result type is always is a 0-d i32 tensor.
|
||||
|
@ -1724,6 +1748,35 @@ LogicalResult SelectOp::reifyReturnTypeShapes(
|
|||
&reifiedReturnShapes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SetDimensionSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(SetDimensionSizeOp op) {
|
||||
if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (size.getRank() != 0)
|
||||
return op.emitOpError() << "size operand should be of rank-0";
|
||||
}
|
||||
|
||||
return VerifyDimAttr(op);
|
||||
}
|
||||
|
||||
OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
|
||||
DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||
if (input) return input;
|
||||
|
||||
DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
|
||||
if (!size || !size.isSplat()) return {};
|
||||
|
||||
auto ty = getType().dyn_cast<RankedTensorType>();
|
||||
if (!ty) return {};
|
||||
|
||||
int64_t dim_size = ty.getDimSize(dimension());
|
||||
if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
|
||||
return operand();
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -909,13 +909,23 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf
|
|||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_get_dimension_size
|
||||
func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
|
||||
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32>
|
||||
func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
|
||||
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
|
||||
return %size : tensor<i32>
|
||||
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
|
||||
// CHECK-NEXT: return %[[C]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_set_dimension_size
|
||||
// CHECK-SAME: (%[[I:.*]]: tensor<1x128x512xf32>)
|
||||
func @fold_set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
|
||||
%dim = mhlo.constant dense<512> : tensor<i32>
|
||||
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
|
||||
return %result : tensor<1x128x512xf32>
|
||||
|
||||
// CHECK-NEXT: return %[[I]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_select_same
|
||||
func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
|
||||
%1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
|
|
|
@ -1246,3 +1246,38 @@ func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
|||
%0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
|
||||
// expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
|
||||
%size = "mhlo.get_dimension_size"(%I) {dimension = 3 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
|
||||
return %size : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
|
||||
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
|
||||
return %size : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
|
||||
%dim = mhlo.constant dense<512> : tensor<1xi32>
|
||||
|
||||
// expected-error@+1 {{size operand should be of rank-0}}
|
||||
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<1xi32>) -> tensor<1x128x512xf32>
|
||||
return %result : tensor<1x128x512xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
|
||||
%dim = mhlo.constant dense<512> : tensor<i32>
|
||||
|
||||
// expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
|
||||
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
|
||||
return %result : tensor<1x128x512xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue