Update GetDimensionSize and SetDimensionSize ops to use I64 attribute for dimension

This is to match with HLO semantics and general dimension semantics in MLIR.

Also,

* Define minimal verifier for these ops.
* Add folder for SetDimensionSize op on static shaped dimension.
* Fix assumption of ranked shape in GetDimensionSize op.

PiperOrigin-RevId: 341150923
This commit is contained in:
Smit Hinsu 2020-11-06 18:02:03 -08:00 committed by TensorFlow MLIR Team
parent 3238f8226f
commit 4ef12aa000
4 changed files with 105 additions and 5 deletions

View File

@ -1017,7 +1017,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
BASE_HLO_GetDimensionSizeOp {
let arguments = (ins
HLO_Tensor:$operand,
I32Attr:$dimension
I64Attr:$dimension
);
// TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
// XLA semantics is available. This limitation is because of the current XLA
@ -1129,9 +1129,11 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
let arguments = (ins
HLO_Tensor:$operand,
I32Tensor:$size,
I32Attr:$dimension
I64Attr:$dimension
);
let results = (outs HLO_Tensor);
let hasFolder = 1;
}
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {

View File

@ -86,6 +86,26 @@ namespace {
// Utilities for the canonicalize patterns
//===----------------------------------------------------------------------===//
// Verifies that dimension attribute for the op correctly indexes in operand or
// result shape.
template <typename OpT>
static LogicalResult VerifyDimAttr(OpT op) {
int64_t rank = -1;
if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
rank = ty.getRank();
} else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
rank = ty.getRank();
} else {
return success();
}
int64_t dim = op.dimension();
if (dim < 0 || dim >= rank)
return op.emitOpError() << "requires dimension attribute in range [0, "
<< rank << "); found (" << dim << ")";
return success();
}
// Returns 1D 64-bit dense elements attribute with the given values.
DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
Builder* builder) {
@ -245,10 +265,14 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//
//
static LogicalResult Verify(GetDimensionSizeOp op) { return VerifyDimAttr(op); }
/// Fold get_dimension_size when the said shape dimension is a constant.
OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
RankedTensorType type = operand().getType().cast<RankedTensorType>();
RankedTensorType type = operand().getType().dyn_cast<RankedTensorType>();
if (!type) return {};
int32_t dim = dimension();
if (type.isDynamic(dim)) return {};
// The result type is always is a 0-d i32 tensor.
@ -1724,6 +1748,35 @@ LogicalResult SelectOp::reifyReturnTypeShapes(
&reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// SetDimensionSizeOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(SetDimensionSizeOp op) {
if (auto size = op.size().getType().dyn_cast<RankedTensorType>()) {
if (size.getRank() != 0)
return op.emitOpError() << "size operand should be of rank-0";
}
return VerifyDimAttr(op);
}
OpFoldResult SetDimensionSizeOp::fold(ArrayRef<Attribute> operands) {
DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
if (input) return input;
DenseElementsAttr size = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (!size || !size.isSplat()) return {};
auto ty = getType().dyn_cast<RankedTensorType>();
if (!ty) return {};
int64_t dim_size = ty.getDimSize(dimension());
if (dim_size == size.getSplatValue().cast<IntegerAttr>().getInt())
return operand();
return {};
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

View File

@ -909,13 +909,23 @@ func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf
}
// CHECK-LABEL: func @fold_get_dimension_size
func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32>
func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
return %size : tensor<i32>
// CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
// CHECK-NEXT: return %[[C]]
}
// CHECK-LABEL: func @fold_set_dimension_size
// CHECK-SAME: (%[[I:.*]]: tensor<1x128x512xf32>)
func @fold_set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
%dim = mhlo.constant dense<512> : tensor<i32>
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
return %result : tensor<1x128x512xf32>
// CHECK-NEXT: return %[[I]]
}
// CHECK-LABEL: func @fold_select_same
func @fold_select_same(%arg0 : tensor<f32>, %arg1 : tensor<i1>) -> tensor<f32> {
%1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

View File

@ -1246,3 +1246,38 @@ func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}
// -----
func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
// expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
%size = "mhlo.get_dimension_size"(%I) {dimension = 3 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
return %size : tensor<i32>
}
// -----
func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<i32> {
%size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i64} : (tensor<1x128x512xf32>) -> tensor<i32>
return %size : tensor<i32>
}
// -----
func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
%dim = mhlo.constant dense<512> : tensor<1xi32>
// expected-error@+1 {{size operand should be of rank-0}}
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor<1xi32>) -> tensor<1x128x512xf32>
return %result : tensor<1x128x512xf32>
}
// -----
func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
%dim = mhlo.constant dense<512> : tensor<i32>
// expected-error@+1 {{requires dimension attribute in range [0, 3); found (3)}}
%result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
return %result : tensor<1x128x512xf32>
}