diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 34a3ffc..8d04e10 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -33,76 +33,60 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 namespace mhlo {
 namespace {
 
-bool IsShapeOfOpMovable(Value arg) {
-  return arg.getDefiningOp<InferShapedTypeOpInterface>();
-}
-
-struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
-  explicit ShapeOfOpConversion(MLIRContext *context)
-      : OpConversionPattern<shape::ShapeOfOp>(context) {
+struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
+  explicit ShapeReificationPattern(MLIRContext *context)
+      : OpRewritePattern<shape::ShapeOfOp>(context) {
     // Recursively reify until we hit an op that doesn't support it.
     setHasBoundedRewriteRecursion();
   }
 
-  LogicalResult matchAndRewrite(
-      shape::ShapeOfOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    shape::ShapeOfOp::Adaptor transformed(operands);
-
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
     // Only reify shape computation if operand allows for it.
-    if (!IsShapeOfOpMovable(transformed.arg())) return failure();
+    auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
+    if (!shape_origin) return failure();
 
-    auto shape_origin =
-        transformed.arg().getDefiningOp<InferShapedTypeOpInterface>();
-    llvm::SmallVector<Value, 1> reified_shapes;
-    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes)))
+    llvm::SmallVector<Value, 1> reifications;
+    if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reifications)))
       return failure();
+    assert(reifications.size() == 1);
+    Value reified_shape = reifications.front();
 
-    assert(reified_shapes.size() == 1);
-    Value reified_shape = reified_shapes.front();
+    // Insert cast if needed.
     if (reified_shape.getType() != op.getType()) {
       reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                       reified_shape);
     }
 
-    rewriter.replaceOp(op, reified_shapes.front());
+    rewriter.replaceOp(op, reified_shape);
     return success();
   }
 };
 
-// We can only move up broadcasting ops that apply to the result of a
-// shape-preserving operation.
-bool isDynamicBroadcastInDimOpMovable(Value operand) {
-  Operation *producer_op = operand.getDefiningOp();
-  return producer_op != nullptr &&
-         producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() &&
-         producer_op->hasTrait<OpTrait::Elementwise>();
-}
-
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
-struct MoveUpBroadcastInDimOpConversion
-    : public OpConversionPattern<DynamicBroadcastInDimOp> {
-  explicit MoveUpBroadcastInDimOpConversion(MLIRContext *context)
-      : OpConversionPattern<DynamicBroadcastInDimOp>(context) {}
+struct MoveUpBroadcastInDimOpPattern
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(
-      DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    DynamicBroadcastInDimOp::Adaptor transformed(operands);
-    if (!isDynamicBroadcastInDimOpMovable(transformed.operand()))
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer_op = bcast_op.operand().getDefiningOp();
+    if (!producer_op ||
+        !producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
+        !producer_op->hasTrait<OpTrait::Elementwise>()) {
       return failure();
+    }
 
     // Materialize broadcast on operands.
     SmallVector<Value, 4> bcasted_operands;
     Location loc = bcast_op.getLoc();
     ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
-    Operation *producer_op = transformed.operand().getDefiningOp();
     for (Value operand : producer_op->getOperands()) {
       // The broadcast only works on ranked operations.
       auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
@@ -114,7 +98,7 @@ struct MoveUpBroadcastInDimOpConversion
       auto bcasted_operand_ty =
           RankedTensorType::get(ty_shape, operand_ty.getElementType());
       bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
-          loc, bcasted_operand_ty, operand, transformed.output_dimensions(),
+          loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
           bcast_op.broadcast_dimensions()));
     }
 
@@ -140,18 +124,14 @@ struct MoveUpDynamicBroadcastsForFusionPass
   }
 
   void runOnFunction() override {
-    // Setup target legality.
-    MLIRContext &ctx = getContext();
-    ConversionTarget target(ctx);
-    PopulateMoveUpDynamicBroadcastsForFusionLegality(&target);
-    // Populate rewrite patterns.
-    OwningRewritePatternList patterns(&ctx);
-    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns);
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
 
     // Apply transformation.
-    if (failed(applyPartialConversion(getFunction(), target,
-                                      std::move(patterns)))) {
+    if (failed(
+            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
       return signalPassFailure();
     }
   }
@@ -159,23 +139,11 @@ struct MoveUpDynamicBroadcastsForFusionPass
 
 }  // namespace
 
-void PopulateMoveUpDynamicBroadcastsForFusionLegality(
-    ConversionTarget *target) {
-  target->addLegalDialect<MhloDialect, StandardOpsDialect,
-                          shape::ShapeDialect, tensor::TensorDialect>();
-  target->addDynamicallyLegalOp<shape::ShapeOfOp>(
-      [](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); });
-  target->addDynamicallyLegalOp<DynamicBroadcastInDimOp>(
-      [](DynamicBroadcastInDimOp op) {
-        return !isDynamicBroadcastInDimOpMovable(op.operand());
-      });
-}
-
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
-  patterns->insert<ShapeOfOpConversion,
-                   MoveUpBroadcastInDimOpConversion>(context);
+  patterns->insert<ShapeReificationPattern,
+                   MoveUpBroadcastInDimOpPattern>(context);
   // clang-format on
 }
 
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index b97f1ab..bb53553 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -5,7 +5,8 @@
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>)
 func @shape_of_unary(%arg : tensor<?x32xi16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
   %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
   "use"(%1) : (tensor<?xindex>) -> ()
@@ -19,7 +20,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) {
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>)
 func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex>
-  // CHECK: "use"(%[[SHAPE]])
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: "use"(%[[CASTED]])
   %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
   %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16>
   %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex>