diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 2f80545..b8378fe 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -45,9 +45,7 @@ def HLO_Token : Type()">, "token">; def HLO_IntTensor : TensorOf<[HLO_Int]>; // Any integer tensor type with rank 0 (i.e. representing a single integer). -def HLO_ScalarIntTensor : ShapedContainerType< - [HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>, - "a 0-dim integer tensor">; +def HLO_ScalarIntTensor : 0DTensorOf<[HLO_Int]>; // Any floating-point tensor types def HLO_FpTensor : TensorOf<[AnyFloat]>; @@ -67,10 +65,7 @@ def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>; def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>; // Dynamic representation of a shape vector as a tensor. -def HLO_DimensionTensor : ShapedContainerType< - [HLO_DimensionValue], - And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, - "a 1D tensor of dimensions">; +def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>; // In general, static shaped tensor constraints should be avoided unless // it is for a legacy op which is only correct with static shapes. diff --git a/tests/ops.mlir b/tests/ops.mlir index 0120a7a..aff2f7f 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -739,7 +739,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor // ----- func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + // expected-error@+1 {{operand #1 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} %0 = "mhlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -755,7 +755,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta // ----- func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { - // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + // expected-error@+1 {{operand #2 must be 0D tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} %0 = "mhlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> return %0 : tensor<3x4xi64> }