Improve error message for improperly shaped slice indices.

The slice indices must be rank-1 and have the same number of elements as the
rank of the operand. Give reasonable error messages for violations of these
requirements instead of a misleading error message that the types of the
indices don't all match.

PiperOrigin-RevId: 340660822
This commit is contained in:
Richard Uhler 2020-11-04 09:10:12 -08:00 committed by TensorFlow MLIR Team
parent f9c87731d9
commit 82031b356c
2 changed files with 39 additions and 4 deletions

View File

@ -2173,10 +2173,21 @@ LogicalResult SliceOp::inferReturnTypes(
return success(); return success();
} }
int64_t rank = ranked_ty.getRank();
ShapedType attr_ty = slice.start_indices().getType(); ShapedType attr_ty = slice.start_indices().getType();
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank || if (attr_ty.getRank() != 1) {
!attr_ty.getElementType().isSignlessInteger(64) || return emitOptionalError(location, "start_indices has rank ",
attr_ty.getRank(), " instead of required rank 1");
}
int64_t rank = ranked_ty.getRank();
if (attr_ty.getNumElements() != rank) {
return emitOptionalError(
location, "the number of elements in start_indices (",
attr_ty.getNumElements(), ") does not match the rank of the operand (",
rank, ")");
}
if (!attr_ty.getElementType().isSignlessInteger(64) ||
slice.limit_indices().getType() != attr_ty || slice.limit_indices().getType() != attr_ty ||
slice.strides().getType() != attr_ty) { slice.strides().getType() != attr_ty) {
// Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp

View File

@ -700,7 +700,7 @@ func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
// expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}} // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}}
%0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2]> : tensor<2xi64>, limit_indices = dense<[2, 4, 1]> : tensor<3xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32> return %0 : tensor<1x4xi32>
} }
@ -714,6 +714,30 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> {
// ----- // -----
func @slice_indices_not_rank_1(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
// expected-error@+1 {{start_indices has rank 2 instead of required rank 1}}
%0 = "mhlo.slice"(%arg0) {
start_indices = dense<[[1, 0]]> : tensor<1x2xi64>,
limit_indices = dense<[[2, 4]]> : tensor<1x2xi64>,
strides = dense<[[1, 2]]> : tensor<1x2xi64>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
// -----
func @slice_indices_wrong_size(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
// expected-error@+1 {{the number of elements in start_indices (3) does not match the rank of the operand (2)}}
%0 = "mhlo.slice"(%arg0) {
start_indices = dense<[1, 0, 0]> : tensor<3xi64>,
limit_indices = dense<[2, 4, 0]> : tensor<3xi64>,
strides = dense<[1, 2, 0]> : tensor<3xi64>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
// -----
// CHECK-LABEL: func @dynamic_slice // CHECK-LABEL: func @dynamic_slice
func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> { func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
%0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>