Update GetDimensionSize and SetDimensionSize ops to use I64 attribute for dimension
This is to match the HLO semantics and the general dimension semantics in MLIR. Also:

* Define a minimal verifier for these ops.
* Add a folder for the SetDimensionSize op on a static shaped dimension.
* Fix the assumption of a ranked shape in the GetDimensionSize op.

PiperOrigin-RevId: 341150923
commit 4ef12aa000
parent 3238f8226f
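For context, a minimal usage sketch of the two ops after this change, with shapes and constants that mirror the tests further down (the function name @dimension_size_example is hypothetical). The dimension attribute is now an i64, and both ops fold away when the indexed dimension is static:

func @dimension_size_example(%I: tensor<1x128x512xf32>) -> (tensor<i32>, tensor<1x128x512xf32>) {
  // Folds to mhlo.constant dense<512> : tensor<i32>, since dimension 2 is statically 512.
  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>

  // Folds to %I, since the requested size equals the static size of dimension 2.
  %dim = mhlo.constant dense<512> : tensor<i32>
  %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>

  return %size, %result : tensor<i32>, tensor<1x128x512xf32>
}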
@@ -1017,7 +1017,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
     BASE_HLO_GetDimensionSizeOp {
   let arguments = (ins
     HLO_Tensor:$operand,
-    I32Attr:$dimension
+    I64Attr:$dimension
   );
   // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
   // XLA semantics is available. This limitation is because of the current XLA
@@ -1129,9 +1129,11 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
   let arguments = (ins
     HLO_Tensor:$operand,
     I32Tensor:$size,
-    I32Attr:$dimension
+    I64Attr:$dimension
   );
   let results = (outs HLO_Tensor);
+
+  let hasFolder = 1;
 }
 
 def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
@@ -86,6 +86,26 @@ namespace {
 // Utilities for the canonicalize patterns
 //===----------------------------------------------------------------------===//
 
+// Verifies that dimension attribute for the op correctly indexes in operand or
+// result shape.
+template <typename OpT>
+static LogicalResult VerifyDimAttr(OpT op) {
+  int64_t rank = -1;
+  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
+    rank = ty.getRank();
+  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
+    rank = ty.getRank();
+  } else {
+    return success();
+  }
+
+  int64_t dim = op.dimension();
+  if (dim < 0 || dim >= rank)
+    return op.emitOpError() << "requires dimension attribute in range [0, "
+                            << rank << "); found (" << dim << ")";
+  return success();
+}
+
 // Returns 1D 64-bit dense elements attribute with the given values.
 DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                         Builder* builder) {
@@ -245,10 +265,14 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 //===----------------------------------------------------------------------===//
 // GetDimensionSizeOp
 //===----------------------------------------------------------------------===//
+//
+static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }
 
 /// Fold get_dimension_size when the said shape dimension is a constant.
 OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
-  RankedTensorType type = operand().getType().cast<RankedTensorType>();
+  RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
+  if (!type) return {};
+
   int32_t dim = dimension();
   if (type.isDynamic(dim)) return {};
   // The result type is always is a 0-d i32 tensor.
@@ -1724,6 +1748,35 @@ LogicalResult SelectOp::reifyReturnTypeShapes(
       &reifiedReturnShapes);
 }
 
+//===----------------------------------------------------------------------===//
+// SetDimensionSizeOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(SetDimensionSizeOp op) {
+  if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
+    if (size.getRank() != 0)
+      return op.emitOpError() << "size operand should be of rank-0";
+  }
+
+  return VerifyDimAttr(op);
+}
+
+OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
+  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  if (input) return input;
+
+  DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  if (!size || !size.isSplat()) return {};
+
+  auto ty = getType().dyn_cast<RankedTensorType>();
+  if (!ty) return {};
+
+  int64_t dim_size = ty.getDimSize(dimension());
+  if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
+    return operand();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//
@@ -910,12 +910,22 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf
 
 // CHECK-LABEL: func @fold_get_dimension_size
 func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
-  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32>
+  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
   return %size : tensor<i32>
   // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
   // CHECK-NEXT: return %[[C]]
 }
 
+// CHECK-LABEL: func @fold_set_dimension_size
+// CHECK-SAME: (%[[I:.*]]: tensor<1x128x512xf32>)
+func @fold_set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
+  %dim = mhlo.constant dense<512> : tensor<i32>
+  %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
+  return %result : tensor<1x128x512xf32>
+
+  // CHECK-NEXT: return %[[I]]
+}
+
 // CHECK-LABEL: func @fold_select_same
 func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
   %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -1246,3 +1246,38 @@ func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
   %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
   return %0 : tensor<2x4xf32>
 }
+
+// -----
+
+func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
+  // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
+  %size = "mhlo.get_dimension_size"(%I) {dimension = 3 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
+  return %size : tensor<i32>
+}
+
+// -----
+
+func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
+  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
+  return %size : tensor<i32>
+}
+
+// -----
+
+func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
+  %dim = mhlo.constant dense<512> : tensor<1xi32>
+
+  // expected-error@+1 {{size operand should be of rank-0}}
+  %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<1xi32>) -> tensor<1x128x512xf32>
+  return %result : tensor<1x128x512xf32>
+}
+
+// -----
+
+func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
+  %dim = mhlo.constant dense<512> : tensor<i32>
+
+  // expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
+  %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
+  return %result : tensor<1x128x512xf32>
+}