Handle empty tensors in SimplifyConcatSlice.

If the result of the slice is an empty tensor, do nothing.
This fixes a crash: we can't create a `concat` with an
empty operand range.

PiperOrigin-RevId: 378354956
This commit is contained in:
A. Unique TensorFlower 2021-06-09 02:14:22 -07:00 committed by TensorFlow MLIR Team
parent 1770ed455f
commit d828b457b3
2 changed files with 17 additions and 0 deletions

View File

@ -3030,6 +3030,13 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
return failure();
}
// If there's nothing to slice that means the output is an empty tensor and
// there is dead code. We do nothing here and rely on other passes to clean
// this up.
if (subset_size == 0) {
return failure();
}
if (subset_size > 1 && !concat.getResult().hasOneUse()) {
return failure();
}

View File

@ -391,6 +391,16 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg
return %1 : tensor<2x5xf32>
}
// CHECK-LABEL: slice_concat_empty
func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<0x5xf32>)
%2 = "mhlo.concatenate"(%1, %arg2) { dimension = 0 : i64 } : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32>
// CHECK: return %arg2
return %2 : tensor<1x5xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK: return %arg0