[MLIR][HLO] Extend broadcast propagation pass to enable more fusion

Move element-wise operations into assuming regions. This enables fusion
opportunities within the region.

PiperOrigin-RevId: 378362725
This commit is contained in:
A. Unique TensorFlower 2021-06-09 03:02:17 -07:00 committed by TensorFlow MLIR Team
parent d828b457b3
commit b9e45007d5
2 changed files with 99 additions and 54 deletions

View File

@ -99,15 +99,8 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
} }
}; };
/// Move operation into a preceding assuming op. This allows to process LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op,
/// operations that depend on the assuming op's results. It will eventually PatternRewriter &rewriter) {
/// allow to make assuming regions' constraints independent from each other.
template <typename OpTy>
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Only move into immediately preceding `assuming` op. // Only move into immediately preceding `assuming` op.
auto assuming_op = auto assuming_op =
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode()); llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
@ -118,40 +111,39 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
// Find the operands to use if the op was within the assuming region. We // Find the operands to use if the op was within the assuming region. We
// will later use their copies, as we copy the assuming op and its body. // will later use their copies, as we copy the assuming op and its body.
SmallVector<Value, 8> new_operands_unmapped; SmallVector<Value, 8> new_operands_unmapped =
for (auto operand : op->getOperands()) { llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
new_operands_unmapped.push_back(operand);
for (auto result : llvm::enumerate(assuming_op->getResults())) { for (auto result : llvm::enumerate(assuming_op->getResults())) {
if (result.value() == operand) if (result.value() == v) return yield_op->getOperand(result.index());
new_operands_unmapped.back() = yield_op->getOperand(result.index());
}
} }
return v;
}));
// Insert the rewritten assuming op right before the old one. // Insert the rewritten assuming op right before the old one.
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(assuming_op); rewriter.setInsertionPoint(assuming_op);
auto new_assuming_op = rewriter.create<shape::AssumingOp>( auto new_assuming_op = rewriter.create<shape::AssumingOp>(
assuming_op.getLoc(), assuming_op.witness(), assuming_op.getLoc(), assuming_op.witness(), [&](OpBuilder &b, Location) {
[&](OpBuilder &b, Location loc) {
// Copy body. // Copy body.
BlockAndValueMapping mapping; BlockAndValueMapping mapping;
for (auto &nested : body->without_terminator()) for (auto &nested : body->without_terminator())
b.clone(nested, mapping); b.clone(nested, mapping);
// Copy op into the new body and use the mapped operands. // Copy op into the new body and use the mapped operands.
SmallVector<Value, 2> new_operands; for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
for (Value v_unmapped : new_operands_unmapped) { Value old_operand, new_operand_unmapped;
Value v = mapping.lookupOrDefault(v_unmapped); std::tie(old_operand, new_operand_unmapped) = it;
new_operands.push_back(v); mapping.map(old_operand,
mapping.lookupOrDefault(new_operand_unmapped));
} }
Value new_op = b.create<OpTy>(loc, op->getResultTypes(), new_operands, Operation *new_op = b.clone(*op, mapping);
op->getAttrs());
// Yield the previous results and also the new one. // Yield the previous results and also the new ones.
SmallVector<Value, 2> mapped_results; auto mapped_results = llvm::to_vector<8>(llvm::map_range(
for (auto result : yield_op.operands()) yield_op.operands(),
mapped_results.push_back(mapping.lookupOrDefault(result)); [&](Value v) { return mapping.lookupOrDefault(v); }));
mapped_results.push_back(new_op); mapped_results.append(new_op->getResults().begin(),
new_op->getResults().end());
return mapped_results; return mapped_results;
}); });
@ -162,6 +154,35 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
rewriter.replaceOp(op, new_assuming_op_results.back()); rewriter.replaceOp(op, new_assuming_op_results.back());
return success(); return success();
} }
/// Move operation into a preceding assuming op. This allows to process
/// operations that depend on the assuming op's results. It will eventually
/// allow to make assuming regions' constraints independent from each other.
template <typename OpTy>
struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
}
};
// Move elementwise operations into assuming regions. This will eventually allow
// for more fusion opportunities.
struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Apply to all elementwise and broadcasting elementwise operations.
if (!op->hasTrait<OpTrait::Elementwise>() &&
!op->hasTrait<chlo::OpTrait::BroadcastingElementwise>())
return failure();
return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
}
}; };
/// Move operation out of assuming op. This is only valid for /// Move operation out of assuming op. This is only valid for
@ -307,8 +328,7 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
} }
}; };
// TODO(frgossen): Only move up broadcasting operations if there is a consumer. struct EarlyBroadcastInDimOpPattern
struct MoveUpBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> { : public OpRewritePattern<DynamicBroadcastInDimOp> {
using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern; using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
@ -380,11 +400,12 @@ void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
patterns->insert< patterns->insert<
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>, InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MergeAssumingOpsPattern, MergeAssumingOpsPattern,
MoveIntoAssumingOpPattern<shape::ShapeOfOp>, MoveElementwiseOpsIntoAssumingOpPattern,
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>, MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>, MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveOutOfAssumingOpPattern<shape::ShapeOfOp>, MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
MoveUpBroadcastInDimOpPattern, EarlyBroadcastInDimOpPattern,
ShapeReificationPattern>(context); ShapeReificationPattern>(context);
// clang-format on // clang-format on
mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns, mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,

View File

@ -193,6 +193,30 @@ func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
// ----- // -----
// CHECK-LABEL: @move_elementwise_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?xf32>)
func @move_elementwise_into_assuming(%arg0 : !shape.witness,
%arg1 : tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[RES:.*]] = shape.assuming %[[ARG0]]
// CHECK: %[[SOME:.*]] = "some.op"
// CHECK: %[[TANH:.*]] = "mhlo.tanh"(%[[ARG1]])
// CHECK: %[[BCAST_ADD:.*]] = chlo.broadcast_add %[[TANH]], %[[SOME]]
// CHECK: shape.assuming_yield %[[BCAST_ADD]]
// CHECK-NOT: tanh
// CHECK-NOT: broadcast_add
// CHECK: return %[[RES]]
%0:2 = shape.assuming %arg0 -> (tensor<?xf32>, tensor<?xf32>) {
%1 = "some.op"() : () -> tensor<?xf32>
shape.assuming_yield %arg1, %1 : tensor<?xf32>, tensor<?xf32>
}
%1 = "mhlo.tanh"(%arg1) : (tensor<?xf32>) -> tensor<?xf32>
%2 = chlo.broadcast_add %1, %0#1
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %2 : tensor<?xf32>
}
// -----
// CHECK-LABEL: @move_shape_of_out_of_assuming // CHECK-LABEL: @move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>) // CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness, func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,