From ab6ee1181351e837da486d69028cff572e0319e2 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Thu, 10 Dec 2020 20:21:49 -0800 Subject: [PATCH] Fix folding of HLO SliceOp with zero elements This was causing division by zero in this case. PiperOrigin-RevId: 346920942 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 6 ++++++ tests/canonicalize.mlir | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 6b7b235..082e202 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2332,6 +2332,12 @@ static Attribute FoldSlice(SliceOp* op, I values) { auto shape = result_type.getShape(); int64_t count = result_type.getNumElements(); + if (count == 0) { + return DenseElementsAttr::get( + op->getResult().getType().cast(), + /*list=*/{}); + } + // Compute the striding for each dimension. llvm::SmallVector sizes; sizes.reserve(shape.size()); diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 7f27252..e50f7f3 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -327,6 +327,15 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> { return %1 : tensor<4x1xi64> } +// CHECK-LABEL: slice_zero_elements +func @slice_zero_elements() -> tensor<0xi64> { + %0 = mhlo.constant dense<> : tensor<0xi64> + // CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64> + %1 = "mhlo.slice"(%0) { limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<0xi64>) -> (tensor<0xi64>) + // CHECK: return %[[CONST]] : tensor<0xi64> + return %1 : tensor<0xi64> +} + // CHECK-LABEL: slice_unknown_shape func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>