Fold away shape.shape_of(mhlo.dynamic_reshape(inp, shape))

This specific pattern can be replaced with the shape
passed to dynamic_reshape. This is implemented as a
canonicalization on mhlo.dynamic_reshape to fit in
the infrastructure of canonicalization.

PiperOrigin-RevId: 342009365
This commit is contained in:
Tres Popp 2020-11-12 02:47:40 -08:00 committed by TensorFlow MLIR Team
parent 7fc4985eae
commit 1dffa62fe9
3 changed files with 15 additions and 1 deletions

View File

@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
void DynamicReshapeOp::getCanonicalizationPatterns( void DynamicReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicReshapeOpNotActuallyDynamic>(context); results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
context);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr), (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
(replaceWithValue $x)>; (replaceWithValue $x)>;
def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape)>;

View File

@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<
return %0 : tensor<4x1xf32> return %0 : tensor<4x1xf32>
} }
// CHECK-LABEL: func @shape_of_dynamic_reshape
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> {
// CHECK: return [[ARG1]]
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
return %1 : tensor<2xindex>
}
// CHECK-LABEL: do_not_dce_while_with_outfeed // CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> { func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: mhlo.while // CHECK: mhlo.while