diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index 31e6f22..35e4845 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -39,6 +39,10 @@ namespace mlir { namespace mhlo { namespace { +bool IsShapeOfOpMovable(Value arg) { + return arg.getDefiningOp(); +} + struct ShapeOfOpConversion : public OpConversionPattern { explicit ShapeOfOpConversion(MLIRContext *context) : OpConversionPattern(context) {} @@ -48,10 +52,11 @@ struct ShapeOfOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { shape::ShapeOfOp::Adaptor transformed(operands); - auto shape_origin = llvm::dyn_cast_or_null( - transformed.arg().getDefiningOp()); - if (!shape_origin) return failure(); + // Only reify shape computation if operand allows for it. + if (!IsShapeOfOpMovable(transformed.arg())) return failure(); + auto shape_origin = + transformed.arg().getDefiningOp(); llvm::SmallVector reified_shapes; if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes))) return failure(); @@ -96,8 +101,10 @@ struct MoveUpDynamicBroadcastsForFusionPass void PopulateMoveUpDynamicBroadcastsForFusionLegality( ConversionTarget *target) { - target->addLegalDialectaddLegalDialect(); + target->addDynamicallyLegalOp( + [](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); }); } void PopulateMoveUpDynamicBroadcastsForFusionPatterns(