[MLIR][HLO] Extend broadcast propagation pass to enable more fusion
Move element-wise operations into assuming regions. This enables fusion opportunities within the region.

PiperOrigin-RevId: 378362725
This commit is contained in:
parent
d828b457b3
commit
b9e45007d5
|
@ -99,15 +99,8 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Move operation into a preceding assuming op. This allows to process
|
LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op,
|
||||||
/// operations that depend on the assuming op's results. It will eventually
|
PatternRewriter &rewriter) {
|
||||||
/// allow to make assuming regions' constraints independent from each other.
|
|
||||||
template <typename OpTy>
|
|
||||||
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
|
||||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(OpTy op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
// Only move into immediately preceding `assuming` op.
|
// Only move into immediately preceding `assuming` op.
|
||||||
auto assuming_op =
|
auto assuming_op =
|
||||||
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||||
|
@ -118,40 +111,39 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||||
|
|
||||||
// Find the operands to use if the op was within the assuming region. We
|
// Find the operands to use if the op was within the assuming region. We
|
||||||
// will later use their copies, as we copy the assuming op and its body.
|
// will later use their copies, as we copy the assuming op and its body.
|
||||||
SmallVector<Value, 8> new_operands_unmapped;
|
SmallVector<Value, 8> new_operands_unmapped =
|
||||||
for (auto operand : op->getOperands()) {
|
llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
|
||||||
new_operands_unmapped.push_back(operand);
|
|
||||||
for (auto result : llvm::enumerate(assuming_op->getResults())) {
|
for (auto result : llvm::enumerate(assuming_op->getResults())) {
|
||||||
if (result.value() == operand)
|
if (result.value() == v) return yield_op->getOperand(result.index());
|
||||||
new_operands_unmapped.back() = yield_op->getOperand(result.index());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return v;
|
||||||
|
}));
|
||||||
|
|
||||||
// Insert the rewritten assuming op right before the old one.
|
// Insert the rewritten assuming op right before the old one.
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
rewriter.setInsertionPoint(assuming_op);
|
rewriter.setInsertionPoint(assuming_op);
|
||||||
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
|
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
|
||||||
assuming_op.getLoc(), assuming_op.witness(),
|
assuming_op.getLoc(), assuming_op.witness(), [&](OpBuilder &b, Location) {
|
||||||
[&](OpBuilder &b, Location loc) {
|
|
||||||
// Copy body.
|
// Copy body.
|
||||||
BlockAndValueMapping mapping;
|
BlockAndValueMapping mapping;
|
||||||
for (auto &nested : body->without_terminator())
|
for (auto &nested : body->without_terminator())
|
||||||
b.clone(nested, mapping);
|
b.clone(nested, mapping);
|
||||||
|
|
||||||
// Copy op into the new body and use the mapped operands.
|
// Copy op into the new body and use the mapped operands.
|
||||||
SmallVector<Value, 2> new_operands;
|
for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
|
||||||
for (Value v_unmapped : new_operands_unmapped) {
|
Value old_operand, new_operand_unmapped;
|
||||||
Value v = mapping.lookupOrDefault(v_unmapped);
|
std::tie(old_operand, new_operand_unmapped) = it;
|
||||||
new_operands.push_back(v);
|
mapping.map(old_operand,
|
||||||
|
mapping.lookupOrDefault(new_operand_unmapped));
|
||||||
}
|
}
|
||||||
Value new_op = b.create<OpTy>(loc, op->getResultTypes(), new_operands,
|
Operation *new_op = b.clone(*op, mapping);
|
||||||
op->getAttrs());
|
|
||||||
|
|
||||||
// Yield the previous results and also the new one.
|
// Yield the previous results and also the new ones.
|
||||||
SmallVector<Value, 2> mapped_results;
|
auto mapped_results = llvm::to_vector<8>(llvm::map_range(
|
||||||
for (auto result : yield_op.operands())
|
yield_op.operands(),
|
||||||
mapped_results.push_back(mapping.lookupOrDefault(result));
|
[&](Value v) { return mapping.lookupOrDefault(v); }));
|
||||||
mapped_results.push_back(new_op);
|
mapped_results.append(new_op->getResults().begin(),
|
||||||
|
new_op->getResults().end());
|
||||||
return mapped_results;
|
return mapped_results;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -161,6 +153,35 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||||
rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
|
rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
|
||||||
rewriter.replaceOp(op, new_assuming_op_results.back());
|
rewriter.replaceOp(op, new_assuming_op_results.back());
|
||||||
return success();
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Move operation into a preceding assuming op. This allows to process
|
||||||
|
/// operations that depend on the assuming op's results. It will eventually
|
||||||
|
/// allow to make assuming regions' constraints independent from each other.
|
||||||
|
template <typename OpTy>
|
||||||
|
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||||
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Move elementwise operations into assuming regions. This will eventually allow
|
||||||
|
// for more fusion opportunities.
|
||||||
|
struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
|
||||||
|
explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx)
|
||||||
|
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Apply to all elementwise and broadcasting elementwise operations.
|
||||||
|
if (!op->hasTrait<OpTrait::Elementwise>() &&
|
||||||
|
!op->hasTrait<chlo::OpTrait::BroadcastingElementwise>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -307,8 +328,7 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
struct EarlyBroadcastInDimOpPattern
|
||||||
struct MoveUpBroadcastInDimOpPattern
|
|
||||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||||
using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
|
using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
@ -380,11 +400,12 @@ void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||||
MergeAssumingOpsPattern,
|
MergeAssumingOpsPattern,
|
||||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
MoveElementwiseOpsIntoAssumingOpPattern,
|
||||||
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||||
|
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||||
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
|
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||||
MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
|
MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
|
||||||
MoveUpBroadcastInDimOpPattern,
|
EarlyBroadcastInDimOpPattern,
|
||||||
ShapeReificationPattern>(context);
|
ShapeReificationPattern>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
|
mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
|
||||||
|
|
|
@ -193,6 +193,30 @@ func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @move_elementwise_into_assuming
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?xf32>)
|
||||||
|
func @move_elementwise_into_assuming(%arg0 : !shape.witness,
|
||||||
|
%arg1 : tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
// CHECK: %[[RES:.*]] = shape.assuming %[[ARG0]]
|
||||||
|
// CHECK: %[[SOME:.*]] = "some.op"
|
||||||
|
// CHECK: %[[TANH:.*]] = "mhlo.tanh"(%[[ARG1]])
|
||||||
|
// CHECK: %[[BCAST_ADD:.*]] = chlo.broadcast_add %[[TANH]], %[[SOME]]
|
||||||
|
// CHECK: shape.assuming_yield %[[BCAST_ADD]]
|
||||||
|
// CHECK-NOT: tanh
|
||||||
|
// CHECK-NOT: broadcast_add
|
||||||
|
// CHECK: return %[[RES]]
|
||||||
|
%0:2 = shape.assuming %arg0 -> (tensor<?xf32>, tensor<?xf32>) {
|
||||||
|
%1 = "some.op"() : () -> tensor<?xf32>
|
||||||
|
shape.assuming_yield %arg1, %1 : tensor<?xf32>, tensor<?xf32>
|
||||||
|
}
|
||||||
|
%1 = "mhlo.tanh"(%arg1) : (tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
%2 = chlo.broadcast_add %1, %0#1
|
||||||
|
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
return %2 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @move_shape_of_out_of_assuming
|
// CHECK-LABEL: @move_shape_of_out_of_assuming
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
|
||||||
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
|
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
|
||||||
|
|
Loading…
Reference in New Issue