diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index edc0ab2..0dd2f8e 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2173,10 +2173,21 @@ LogicalResult SliceOp::inferReturnTypes( return success(); } - int64_t rank = ranked_ty.getRank(); ShapedType attr_ty = slice.start_indices().getType(); - if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank || - !attr_ty.getElementType().isSignlessInteger(64) || + if (attr_ty.getRank() != 1) { + return emitOptionalError(location, "start_indices has rank ", + attr_ty.getRank(), " instead of required rank 1"); + } + + int64_t rank = ranked_ty.getRank(); + if (attr_ty.getNumElements() != rank) { + return emitOptionalError( + location, "the number of elements in start_indices (", + attr_ty.getNumElements(), ") does not match the rank of the operand (", + rank, ")"); + } + + if (!attr_ty.getElementType().isSignlessInteger(64) || slice.limit_indices().getType() != attr_ty || slice.strides().getType() != attr_ty) { // Unfortunately we can't rely on the AllTypesMatch trait for the SliceOp diff --git a/tests/ops.mlir b/tests/ops.mlir index d22f7d1..2e05697 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -700,7 +700,7 @@ func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}} - %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 2]> : tensor<2xi64>, limit_indices = dense<[2, 4, 1]> : tensor<3xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } @@ -714,6 +714,30 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // ----- +func @slice_indices_not_rank_1(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { + // expected-error@+1 {{start_indices has rank 2 instead of required rank 1}} + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[[1, 0]]> : tensor<1x2xi64>, + limit_indices = dense<[[2, 4]]> : tensor<1x2xi64>, + strides = dense<[[1, 2]]> : tensor<1x2xi64> + } : (tensor<3x4xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> +} + +// ----- + +func @slice_indices_wrong_size(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { + // expected-error@+1 {{the number of elements in start_indices (3) does not match the rank of the operand (2)}} + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 0, 0]> : tensor<3xi64>, + limit_indices = dense<[2, 4, 0]> : tensor<3xi64>, + strides = dense<[1, 2, 0]> : tensor<3xi64> + } : (tensor<3x4xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> +} + +// ----- + // CHECK-LABEL: func @dynamic_slice func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { %0 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32>