Fix folding of HLO SliceOp with zero elements

This was causing division by zero in this case.

PiperOrigin-RevId: 346920942
This commit is contained in:
Smit Hinsu 2020-12-10 20:21:49 -08:00 committed by TensorFlow MLIR Team
parent f232da1f9d
commit ab6ee11813
2 changed files with 15 additions and 0 deletions

View File

@ -2332,6 +2332,12 @@ static Attribute FoldSlice(SliceOp* op, I values) {
auto shape = result_type.getShape(); auto shape = result_type.getShape();
int64_t count = result_type.getNumElements(); int64_t count = result_type.getNumElements();
if (count == 0) {
return DenseElementsAttr::get<E>(
op->getResult().getType().cast<ShapedType>(),
/*list=*/{});
}
// Compute the striding for each dimension. // Compute the striding for each dimension.
llvm::SmallVector<int64_t, 6> sizes; llvm::SmallVector<int64_t, 6> sizes;
sizes.reserve(shape.size()); sizes.reserve(shape.size());

View File

@ -327,6 +327,15 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
return %1 : tensor<4x1xi64> return %1 : tensor<4x1xi64>
} }
// CHECK-LABEL: slice_zero_elements
func @slice_zero_elements() -> tensor<0xi64> {
%0 = mhlo.constant dense<> : tensor<0xi64>
// CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<0xi64>) -> (tensor<0xi64>)
// CHECK: return %[[CONST]] : tensor<0xi64>
return %1 : tensor<0xi64>
}
// CHECK-LABEL: slice_unknown_shape // CHECK-LABEL: slice_unknown_shape
func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> // CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>