From d77b55682231f78d8c5deb2188c7c03da5883263 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Thu, 11 Mar 2021 02:51:27 -0800 Subject: [PATCH] [MLIR][MHLO] Allow recursion in the shape_of mover This allows it to push shape_of over a chain of ops all the way to the top. PiperOrigin-RevId: 362249009 --- .../mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc | 5 ++++- tests/move_up_dynamic_broadcasts_for_fusion.mlir | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index 35e4845..b8f14fc 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -45,7 +45,10 @@ bool IsShapeOfOpMovable(Value arg) { struct ShapeOfOpConversion : public OpConversionPattern { explicit ShapeOfOpConversion(MLIRContext *context) - : OpConversionPattern(context) {} + : OpConversionPattern(context) { + // Recursively reify until we hit an op that doesn't support it. + setHasBoundedRewriteRecursion(); + } LogicalResult matchAndRewrite( shape::ShapeOfOp op, ArrayRef operands, diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir index a9acfbd..449e755 100644 --- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir +++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir @@ -21,7 +21,8 @@ func @shape_of_nary(%arg0 : tensor, %arg1 : tensor) { // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor -> tensor<2xindex> // CHECK: "use"(%[[SHAPE]]) %0 = mhlo.subtract %arg0, %arg1 : tensor - %1 = shape.shape_of %0 : tensor -> tensor - "use"(%1) : (tensor) -> () + %1 = mhlo.subtract %0, %arg1 : tensor + %2 = shape.shape_of %1 : tensor -> tensor + "use"(%2) : (tensor) -> () return }