Handle empty tensors in SimplifyConcatSlice.
If the result of the slice is an empty tensor, do nothing. This fixes a crash: we can't create a `concat` with an empty operand range. PiperOrigin-RevId: 378354956
This commit is contained in:
parent
1770ed455f
commit
d828b457b3
|
@ -3030,6 +3030,13 @@ struct SimplifyConcatSlice : public OpRewritePattern<SliceOp> {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If there's nothing to slice that means the output is an empty tensor and
|
||||||
|
// there is dead code. We do nothing here and rely on other passes to clean
|
||||||
|
// this up.
|
||||||
|
if (subset_size == 0) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
if (subset_size > 1 && !concat.getResult().hasOneUse()) {
|
if (subset_size > 1 && !concat.getResult().hasOneUse()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -391,6 +391,16 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg
|
||||||
return %1 : tensor<2x5xf32>
|
return %1 : tensor<2x5xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: slice_concat_empty
|
||||||
|
func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
|
||||||
|
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||||
|
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<0x5xf32>)
|
||||||
|
%2 = "mhlo.concatenate"(%1, %arg2) { dimension = 0 : i64 } : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32>
|
||||||
|
|
||||||
|
// CHECK: return %arg2
|
||||||
|
return %2 : tensor<1x5xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @broadcast_in_dim_identity
|
// CHECK-LABEL: func @broadcast_in_dim_identity
|
||||||
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
|
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
|
||||||
// CHECK: return %arg0
|
// CHECK: return %arg0
|
||||||
|
|
Loading…
Reference in New Issue