[MLIR] Use MLIR provided 0DTensorOf/1DTensorOf instead of custom constraints.
- Use MLIR provided constraints for HLO_ScalarIntTensor and HLO_DimensionTensor. - Update unit tests to expect new error messages. PiperOrigin-RevId: 333313131
This commit is contained in:
parent
7d01a60de8
commit
87d9e6951e
|
@ -45,9 +45,7 @@ def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;
|
||||||
def HLO_IntTensor : TensorOf<[HLO_Int]>;
|
def HLO_IntTensor : TensorOf<[HLO_Int]>;
|
||||||
|
|
||||||
// Any integer tensor type with rank 0 (i.e. representing a single integer).
|
// Any integer tensor type with rank 0 (i.e. representing a single integer).
|
||||||
def HLO_ScalarIntTensor : ShapedContainerType<
|
def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>;
|
||||||
[HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>,
|
|
||||||
"a 0-dim integer tensor">;
|
|
||||||
|
|
||||||
// Any floating-point tensor types
|
// Any floating-point tensor types
|
||||||
def HLO_FpTensor : TensorOf<[AnyFloat]>;
|
def HLO_FpTensor : TensorOf<[AnyFloat]>;
|
||||||
|
@ -67,10 +65,7 @@ def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
|
||||||
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
|
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
|
||||||
|
|
||||||
// Dynamic representation of a shape vector as a tensor.
|
// Dynamic representation of a shape vector as a tensor.
|
||||||
def HLO_DimensionTensor : ShapedContainerType<
|
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;
|
||||||
[HLO_DimensionValue],
|
|
||||||
And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
|
||||||
"a 1D tensor of dimensions">;
|
|
||||||
|
|
||||||
// In general, static shaped tensor constraints should be avoided unless
|
// In general, static shaped tensor constraints should be avoided unless
|
||||||
// it is for a legacy op which is only correct with static shapes.
|
// it is for a legacy op which is only correct with static shapes.
|
||||||
|
|
|
@ -739,7 +739,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
||||||
// expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
|
// expected-error@+1 {{operand #1 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
|
||||||
%0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
|
%0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
|
||||||
return %0 : tensor<1x4xi32>
|
return %0 : tensor<1x4xi32>
|
||||||
}
|
}
|
||||||
|
@ -755,7 +755,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
|
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
|
||||||
// expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
|
// expected-error@+1 {{operand #2 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
|
||||||
%0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
|
%0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
|
||||||
return %0 : tensor<3x4xi64>
|
return %0 : tensor<3x4xi64>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue