[MLIR][HLO] Extend broadcast propagation pass to enable more fusion
Move element-wise operations into assuming regions. This enables fusion opportunities within the region. PiperOrigin-RevId: 378362725
This commit is contained in:
parent
d828b457b3
commit
b9e45007d5
|
@ -99,15 +99,8 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Move operation into a preceding assuming op. This allows to process
|
||||
/// operations that depend on the assuming op's results. It will eventually
|
||||
/// allow to make assuming regions' constraints independent from each other.
|
||||
template <typename OpTy>
|
||||
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) {
|
||||
// Only move into immediately preceding `assuming` op.
|
||||
auto assuming_op =
|
||||
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||
|
@ -118,40 +111,39 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
|||
|
||||
// Find the operands to use if the op was within the assuming region. We
|
||||
// will later use their copies, as we copy the assuming op and its body.
|
||||
SmallVector<Value, 8> new_operands_unmapped;
|
||||
for (auto operand : op->getOperands()) {
|
||||
new_operands_unmapped.push_back(operand);
|
||||
SmallVector<Value, 8> new_operands_unmapped =
|
||||
llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
|
||||
for (auto result : llvm::enumerate(assuming_op->getResults())) {
|
||||
if (result.value() == operand)
|
||||
new_operands_unmapped.back() = yield_op->getOperand(result.index());
|
||||
}
|
||||
if (result.value() == v) return yield_op->getOperand(result.index());
|
||||
}
|
||||
return v;
|
||||
}));
|
||||
|
||||
// Insert the rewritten assuming op right before the old one.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(assuming_op);
|
||||
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
|
||||
assuming_op.getLoc(), assuming_op.witness(),
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
assuming_op.getLoc(), assuming_op.witness(), [&](OpBuilder &b, Location) {
|
||||
// Copy body.
|
||||
BlockAndValueMapping mapping;
|
||||
for (auto &nested : body->without_terminator())
|
||||
b.clone(nested, mapping);
|
||||
|
||||
// Copy op into the new body and use the mapped operands.
|
||||
SmallVector<Value, 2> new_operands;
|
||||
for (Value v_unmapped : new_operands_unmapped) {
|
||||
Value v = mapping.lookupOrDefault(v_unmapped);
|
||||
new_operands.push_back(v);
|
||||
for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
|
||||
Value old_operand, new_operand_unmapped;
|
||||
std::tie(old_operand, new_operand_unmapped) = it;
|
||||
mapping.map(old_operand,
|
||||
mapping.lookupOrDefault(new_operand_unmapped));
|
||||
}
|
||||
Value new_op = b.create<OpTy>(loc, op->getResultTypes(), new_operands,
|
||||
op->getAttrs());
|
||||
Operation *new_op = b.clone(*op, mapping);
|
||||
|
||||
// Yield the previous results and also the new one.
|
||||
SmallVector<Value, 2> mapped_results;
|
||||
for (auto result : yield_op.operands())
|
||||
mapped_results.push_back(mapping.lookupOrDefault(result));
|
||||
mapped_results.push_back(new_op);
|
||||
// Yield the previous results and also the new ones.
|
||||
auto mapped_results = llvm::to_vector<8>(llvm::map_range(
|
||||
yield_op.operands(),
|
||||
[&](Value v) { return mapping.lookupOrDefault(v); }));
|
||||
mapped_results.append(new_op->getResults().begin(),
|
||||
new_op->getResults().end());
|
||||
return mapped_results;
|
||||
});
|
||||
|
||||
|
@ -161,6 +153,35 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
|||
rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
|
||||
rewriter.replaceOp(op, new_assuming_op_results.back());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Move operation into a preceding assuming op. This allows to process
|
||||
/// operations that depend on the assuming op's results. It will eventually
|
||||
/// allow to make assuming regions' constraints independent from each other.
|
||||
template <typename OpTy>
|
||||
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
// Move elementwise operations into assuming regions. This will eventually allow
|
||||
// for more fusion opportunities.
|
||||
struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
|
||||
explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Apply to all elementwise and broadcasting elementwise operations.
|
||||
if (!op->hasTrait<OpTrait::Elementwise>() &&
|
||||
!op->hasTrait<chlo::OpTrait::BroadcastingElementwise>())
|
||||
return failure();
|
||||
|
||||
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -307,8 +328,7 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||
struct MoveUpBroadcastInDimOpPattern
|
||||
struct EarlyBroadcastInDimOpPattern
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
|
||||
|
||||
|
@ -380,11 +400,12 @@ void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
|
|||
patterns->insert<
|
||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||
MergeAssumingOpsPattern,
|
||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||
MoveElementwiseOpsIntoAssumingOpPattern,
|
||||
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||
MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
|
||||
MoveUpBroadcastInDimOpPattern,
|
||||
EarlyBroadcastInDimOpPattern,
|
||||
ShapeReificationPattern>(context);
|
||||
// clang-format on
|
||||
mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
|
||||
|
|
|
@ -193,6 +193,30 @@ func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_elementwise_into_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?xf32>)
|
||||
func @move_elementwise_into_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: %[[RES:.*]] = shape.assuming %[[ARG0]]
|
||||
// CHECK: %[[SOME:.*]] = "some.op"
|
||||
// CHECK: %[[TANH:.*]] = "mhlo.tanh"(%[[ARG1]])
|
||||
// CHECK: %[[BCAST_ADD:.*]] = chlo.broadcast_add %[[TANH]], %[[SOME]]
|
||||
// CHECK: shape.assuming_yield %[[BCAST_ADD]]
|
||||
// CHECK-NOT: tanh
|
||||
// CHECK-NOT: broadcast_add
|
||||
// CHECK: return %[[RES]]
|
||||
%0:2 = shape.assuming %arg0 -> (tensor<?xf32>, tensor<?xf32>) {
|
||||
%1 = "some.op"() : () -> tensor<?xf32>
|
||||
shape.assuming_yield %arg1, %1 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
%1 = "mhlo.tanh"(%arg1) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%2 = chlo.broadcast_add %1, %0#1
|
||||
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_shape_of_out_of_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
|
||||
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
|
||||
|
|
Loading…
Reference in New Issue