Fix mhlo::SliceOp::fold to not crash on unknown shapes
PiperOrigin-RevId: 329504383
This commit is contained in:
parent
dcab119cec
commit
a622bf479b
|
@ -1772,11 +1772,11 @@ static Attribute FoldSlice(SliceOp* op, I values) {
|
|||
|
||||
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Check if the SliceOp is a NoOp operation.
|
||||
auto operand_shape = getOperand().getType().cast<ShapedType>().getShape();
|
||||
auto operand_type = getOperand().getType().cast<ShapedType>();
|
||||
auto result_type = getResult().getType().cast<ShapedType>();
|
||||
auto result_shape = result_type.getShape();
|
||||
|
||||
if (result_type.hasStaticShape() && (operand_shape == result_shape)) {
|
||||
if (operand_type.hasStaticShape() && result_type.hasStaticShape() &&
|
||||
(operand_type.getShape() == result_type.getShape())) {
|
||||
return getOperand();
|
||||
}
|
||||
|
||||
|
|
|
@ -301,6 +301,13 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
|
|||
return %1 : tensor<4x1xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_unknown_shape
|
||||
func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_concat_fold_first
|
||||
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
|
||||
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
|
||||
|
|
Loading…
Reference in New Issue