Fix mhlo::SliceOp::fold to not crash on unknown shapes

PiperOrigin-RevId: 329504383
This commit is contained in:
A. Unique TensorFlower 2020-09-01 07:36:57 -07:00 committed by TensorFlow MLIR Team
parent dcab119cec
commit a622bf479b
2 changed files with 10 additions and 3 deletions

View File

@ -1772,11 +1772,11 @@ static Attribute FoldSlice(SliceOp* op, I values) {
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) { OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
// Check if the SliceOp is a NoOp operation. // Check if the SliceOp is a NoOp operation.
auto operand_shape = getOperand().getType().cast<ShapedType>().getShape(); auto operand_type = getOperand().getType().cast<ShapedType>();
auto result_type = getResult().getType().cast<ShapedType>(); auto result_type = getResult().getType().cast<ShapedType>();
auto result_shape = result_type.getShape();
if (result_type.hasStaticShape() && (operand_shape == result_shape)) { if (operand_type.hasStaticShape() && result_type.hasStaticShape() &&
(operand_type.getShape() == result_type.getShape())) {
return getOperand(); return getOperand();
} }

View File

@ -301,6 +301,13 @@ func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
return %1 : tensor<4x1xi64> return %1 : tensor<4x1xi64>
} }
// CHECK-LABEL: slice_unknown_shape
func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: slice_concat_fold_first // CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>