Fix folding of HLO SliceOp with zero elements
This was causing division by zero in this case. PiperOrigin-RevId: 346920942
This commit is contained in:
parent
f232da1f9d
commit
ab6ee11813
|
@ -2332,6 +2332,12 @@ static Attribute FoldSlice(SliceOp* op, I values) {
|
||||||
|
|
||||||
auto shape = result_type.getShape();
|
auto shape = result_type.getShape();
|
||||||
int64_t count = result_type.getNumElements();
|
int64_t count = result_type.getNumElements();
|
||||||
|
if (count == 0) {
|
||||||
|
return DenseElementsAttr::get<E>(
|
||||||
|
op->getResult().getType().cast<ShapedType>(),
|
||||||
|
/*list=*/{});
|
||||||
|
}
|
||||||
|
|
||||||
// Compute the striding for each dimension.
|
// Compute the striding for each dimension.
|
||||||
llvm::SmallVector<int64_t, 6> sizes;
|
llvm::SmallVector<int64_t, 6> sizes;
|
||||||
sizes.reserve(shape.size());
|
sizes.reserve(shape.size());
|
||||||
|
|
|
@ -327,6 +327,15 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
|
||||||
return %1 : tensor<4x1xi64>
|
return %1 : tensor<4x1xi64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: slice_zero_elements
|
||||||
|
func @slice_zero_elements() -> tensor<0xi64> {
|
||||||
|
%0 = mhlo.constant dense<> : tensor<0xi64>
|
||||||
|
// CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64>
|
||||||
|
%1 = "mhlo.slice"(%0) { limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<0xi64>) -> (tensor<0xi64>)
|
||||||
|
// CHECK: return %[[CONST]] : tensor<0xi64>
|
||||||
|
return %1 : tensor<0xi64>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: slice_unknown_shape
|
// CHECK-LABEL: slice_unknown_shape
|
||||||
func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
|
// CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
Loading…
Reference in New Issue