[MLIR] Use MLIR-provided 0DTensorOf/1DTensorOf instead of custom constraints.

- Use the MLIR-provided constraints for HLO_ScalarIntTensor and HLO_DimensionTensor.
- Update unit tests to expect new error messages.

PiperOrigin-RevId: 333313131
Author: Rahul Joshi, 2020-09-23 09:56:10 -07:00 (committed by TensorFlow MLIR Team)
Commit: 87d9e6951e (parent: 7d01a60de8)
2 changed files with 4 additions and 9 deletions
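Note: 0DTensorOf and 1DTensorOf come from MLIR's include/mlir/IR/OpBase.td. A minimal sketch of what they amount to, paraphrased from the same predicates the old custom constraints already used rather than copied verbatim from upstream OpBase.td:

// Sketch only: paraphrase of the rank-restricted tensor constraints,
// not the verbatim OpBase.td definitions.
class 0DTensorOf<list<Type> allowedTypes> :
    Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<[0]>]>,
         "0D tensor of " # TensorOf<allowedTypes>.description>;

class 1DTensorOf<list<Type> allowedTypes> :
    Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<[1]>]>,
         "1D tensor of " # TensorOf<allowedTypes>.description>;

The "0D tensor of ..." summary carried by these helpers is what the ODS-generated verifiers print, which is why the expected-error strings in the tests below change.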


@@ -45,9 +45,7 @@ def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;
 def HLO_IntTensor : TensorOf<[HLO_Int]>;
 
 // Any integer tensor type with rank 0 (i.e. representing a single integer).
-def HLO_ScalarIntTensor : ShapedContainerType<
-    [HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>,
-    "a 0-dim integer tensor">;
+def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;
 
 // Any floating-point tensor types
 def HLO_FpTensor : TensorOf<[AnyFloat]>;
@@ -67,10 +65,7 @@ def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
 def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
 
 // Dynamic representation of a shape vector as a tensor.
-def HLO_DimensionTensor : ShapedContainerType<
-    [HLO_DimensionValue],
-    And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
-    "a 1D tensor of dimensions">;
+def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
 
 // In general, static shaped tensor constraints should be avoided unless
 // it is for a legacy op which is only correct with static shapes.


@@ -739,7 +739,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor
 // -----
 
 func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
-  // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
+  // expected-error@+1 {{operand #1 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
   %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
   return %0 : tensor<1x4xi32>
 }
@@ -755,7 +755,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta
 // -----
 
 func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
-  // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
+  // expected-error@+1 {{operand #2 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
   %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
   return %0 : tensor<3x4xi64>
 }
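For illustration, a hypothetical test-style function that satisfies the new constraint. This is a sketch, not part of the commit: it assumes the variadic rank-0 start-index form that the error messages above imply, with one scalar start index per dimension of the sliced operand.

func @dynamic_slice_valid_start(%arg0: tensor<3x4xi32>, %s0: tensor<i64>, %s1: tensor<i64>) -> tensor<1x4xi32> {
  // Each start index is a rank-0 (scalar) integer tensor, so the
  // 0DTensorOf<[HLO_Int]> operand constraint is satisfied and no
  // expected-error is needed.
  %0 = "mhlo.dynamic-slice"(%arg0, %s0, %s1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}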