diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 3e591c1..389c579 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic void DynamicReshapeOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); + results.insert( + context); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/IR/hlo_patterns.td b/lib/Dialect/mhlo/IR/hlo_patterns.td index bdb3e3c..776732b 100644 --- a/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat< (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr), (replaceWithValue $x)>; +def ShapeOfDynamicReshape : Pat< + (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)), + (replaceWithValue $shape)>; diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 0fde1db..8470f36 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor< return %0 : tensor<4x1xf32> } +// CHECK-LABEL: func @shape_of_dynamic_reshape +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] +// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] +func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> { + // CHECK: return [[ARG1]] + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor<2xindex> + return %1 : tensor<2xindex> +} + // CHECK-LABEL: do_not_dce_while_with_outfeed func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { // CHECK: mhlo.while