diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index 8d04e10..b59d819 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -69,6 +69,33 @@ struct ShapeReificationPattern : public OpRewritePattern { } }; +template +struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Find all the shape operands, direct and indirect. + SmallVector inlined_operands; + for (Value direct : op->getOperands()) { + if (auto bcast_op = direct.getDefiningOp()) { + for (Value indirect : bcast_op->getOperands()) + inlined_operands.push_back(indirect); + } else { + inlined_operands.push_back(direct); + } + } + + // Only rewrite if it makes a difference. + if (inlined_operands.size() == op.getNumOperands()) return failure(); + + // Inline shape operands. + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + inlined_operands, op->getAttrs()); + return success(); + } +}; + // TODO(frgossen): Only move up broadcasting operations if there is a consumer. struct MoveUpBroadcastInDimOpPattern : public OpRewritePattern { @@ -124,12 +151,9 @@ struct MoveUpDynamicBroadcastsForFusionPass } void runOnFunction() override { - // Populate rewrite patterns. MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns); - - // Apply transformation. if (failed( applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) { return signalPassFailure(); @@ -142,8 +166,10 @@ struct MoveUpDynamicBroadcastsForFusionPass void PopulateMoveUpDynamicBroadcastsForFusionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { // clang-format off - patterns->insert(context); + patterns->insert< + InlineBroadcastedShapeOperandsPattern, + MoveUpBroadcastInDimOpPattern, + ShapeReificationPattern>(context); // clang-format on } diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir index bb53553..3e97be9 100644 --- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir +++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir @@ -97,3 +97,18 @@ func @cast_sub(%arg0: tensor, %arg1: tensor) } return %4 : tensor } + +// ----- + +// CHECK-LABEL: @inline_bcasted_shape_operands +// CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor, %[[C:.*]]: tensor) +func @inline_bcasted_shape_operands(%a : tensor, %b : tensor, + %c : tensor) -> !shape.witness { + // CHECK-NOT: shape.broadcast + // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]] + // CHECK: return %[[WITNESS]] : !shape.witness + %0 = shape.broadcast %a, %b : tensor, tensor + -> tensor + %1 = shape.cstr_broadcastable %0, %c : tensor, tensor + return %1 : !shape.witness +}