Improve error message for improperly shaped slice indices.
The slice indices must be rank-1 and have the same number of elements of the rank of the operand. Give reasonable error messages for violations of these requirements instead of a misleading error message that the types of the indices don't all match. PiperOrigin-RevId: 340660822
This commit is contained in:
parent
f9c87731d9
commit
82031b356c
|
@ -2173,10 +2173,21 @@ LogicalResult SliceOp::inferReturnTypes(
|
|||
return success();
|
||||
}
|
||||
|
||||
int64_t rank = ranked_ty.getRank();
|
||||
ShapedType attr_ty = slice.start_indices().getType();
|
||||
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
|
||||
!attr_ty.getElementType().isSignlessInteger(64) ||
|
||||
if (attr_ty.getRank() != 1) {
|
||||
return emitOptionalError(location, "start_indices has rank ",
|
||||
attr_ty.getRank(), " instead of required rank 1");
|
||||
}
|
||||
|
||||
int64_t rank = ranked_ty.getRank();
|
||||
if (attr_ty.getNumElements() != rank) {
|
||||
return emitOptionalError(
|
||||
location, "the number of elements in start_indices (",
|
||||
attr_ty.getNumElements(), ") does not match the rank of the operand (",
|
||||
rank, ")");
|
||||
}
|
||||
|
||||
if (!attr_ty.getElementType().isSignlessInteger(64) ||
|
||||
slice.limit_indices().getType() != attr_ty ||
|
||||
slice.strides().getType() != attr_ty) {
|
||||
// Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp
|
||||
|
|
|
@ -700,7 +700,7 @@ func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
|||
|
||||
func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
|
||||
// expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}}
|
||||
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2]> : tensor<2xi64>, limit_indices = dense<[2, 4, 1]> : tensor<3xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
|
||||
return %0 : tensor<1x4xi32>
|
||||
}
|
||||
|
||||
|
@ -714,6 +714,30 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
func @slice_indices_not_rank_1(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||
// expected-error@+1 {{start_indices has rank 2 instead of required rank 1}}
|
||||
%0 = "mhlo.slice"(%arg0) {
|
||||
start_indices = dense<[[1, 0]]> : tensor<1x2xi64>,
|
||||
limit_indices = dense<[[2, 4]]> : tensor<1x2xi64>,
|
||||
strides = dense<[[1, 2]]> : tensor<1x2xi64>
|
||||
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @slice_indices_wrong_size(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
|
||||
// expected-error@+1 {{the number of elements in start_indices (3) does not match the rank of the operand (2)}}
|
||||
%0 = "mhlo.slice"(%arg0) {
|
||||
start_indices = dense<[1, 0, 0]> : tensor<3xi64>,
|
||||
limit_indices = dense<[2, 4, 0]> : tensor<3xi64>,
|
||||
strides = dense<[1, 2, 0]> : tensor<3xi64>
|
||||
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dynamic_slice
|
||||
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||
%0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
|
||||
|
|
Loading…
Reference in New Issue