diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 35e4845..b8f14fc 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -45,7 +45,10 @@ bool IsShapeOfOpMovable(Value arg) {
 
 struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
   explicit ShapeOfOpConversion(MLIRContext *context)
-      : OpConversionPattern(context) {}
+      : OpConversionPattern(context) {
+    // Recursively reify until we hit an op that doesn't support it.
+    setHasBoundedRewriteRecursion();
+  }
 
   LogicalResult matchAndRewrite(
       shape::ShapeOfOp op, ArrayRef<Value> operands,
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index a9acfbd..449e755 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -21,7 +21,8 @@ func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
   // CHECK: "use"(%[[SHAPE]])
   %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
-  %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
-  "use"(%1) : (tensor<?xindex>) -> ()
+  %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
+  %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>
+  "use"(%2) : (tensor<?xindex>) -> ()
   return
 }
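
For reference, a sketch of what the bounded recursion buys on the new test case, written as a before/after pair in MLIR. The rank-2 element type tensor<?x32xf16> is an assumption chosen to be consistent with the tensor<2xindex> shape in the CHECK lines; the step-by-step rewriting order is an illustration, and only the final CHECK'd form is actually verified by the test.

// Before the pass (the test input above): the shape is taken of a value
// that is two mhlo.subtract ops removed from the block arguments.
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
  %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
  %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>
  "use"(%2) : (tensor<?xindex>) -> ()
  return
}

// After the pass: each application of ShapeOfOpConversion reifies the shape
// through one producer, and setHasBoundedRewriteRecursion() lets the driver
// re-apply the pattern until shape.shape_of sits directly on a block
// argument, which is the form the CHECK lines above expect. The now-dead
// subtract ops are kept here since the test does not check their removal.
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
  %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
  %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
  %shape = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
  "use"(%shape) : (tensor<2xindex>) -> ()
  return
}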