diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 771827e..f48b0a5 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -3030,6 +3030,13 @@ struct SimplifyConcatSlice : public OpRewritePattern { return failure(); } + // If there's nothing to slice that means the output is an empty tensor and + // there is dead code. We do nothing here and rely on other passes to clean + // this up. + if (subset_size == 0) { + return failure(); + } + if (subset_size > 1 && !concat.getResult().hasOneUse()) { return failure(); } diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index a7f2339..e07c8ba 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -391,6 +391,16 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg return %1 : tensor<2x5xf32> } +// CHECK-LABEL: slice_concat_empty +func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<0x5xf32>) + %2 = "mhlo.concatenate"(%1, %arg2) { dimension = 0 : i64 } : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32> + + // CHECK: return %arg2 + return %2 : tensor<1x5xf32> +} + // CHECK-LABEL: func @broadcast_in_dim_identity func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { // CHECK: return %arg0