[MLIR][HLO] Extend broadcast propagation pass to enable more fusion

Move element-wise operations into assuming regions. This enables fusion
opportunities within those regions.
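
As a minimal sketch of the rewrite (ops and value names here are
illustrative only, not taken from the pass): before propagation, an
element-wise op consumes a region result from outside the assuming op,
which hides it from fusion with its producer.

    %0 = shape.assuming %w -> (tensor<?xf32>) {
      %1 = "some.op"() : () -> tensor<?xf32>
      shape.assuming_yield %1 : tensor<?xf32>
    }
    // Outside the region; cannot fuse with its producer.
    %2 = "mhlo.tanh"(%0) : (tensor<?xf32>) -> tensor<?xf32>

After propagation, the op is cloned into the region, the region yields
both the old and the new results, and uses of the old result are
replaced with %0#1.

    %0:2 = shape.assuming %w -> (tensor<?xf32>, tensor<?xf32>) {
      %1 = "some.op"() : () -> tensor<?xf32>
      // Now inside the region, adjacent to its producer.
      %2 = "mhlo.tanh"(%1) : (tensor<?xf32>) -> tensor<?xf32>
      shape.assuming_yield %1, %2 : tensor<?xf32>, tensor<?xf32>
    }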

PiperOrigin-RevId: 378362725
A. Unique TensorFlower 2021-06-09 03:02:17 -07:00 committed by TensorFlow MLIR Team
parent d828b457b3
commit b9e45007d5
2 changed files with 99 additions and 54 deletions


@@ -99,6 +99,62 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
   }
 };
 
+LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op,
+                                                PatternRewriter &rewriter) {
+  // Only move into immediately preceding `assuming` op.
+  auto assuming_op =
+      llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
+  if (!assuming_op) return failure();
+
+  Block *body = assuming_op.getBody();
+  auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
+
+  // Find the operands to use if the op was within the assuming region. We
+  // will later use their copies, as we copy the assuming op and its body.
+  SmallVector<Value, 8> new_operands_unmapped =
+      llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
+        for (auto result : llvm::enumerate(assuming_op->getResults())) {
+          if (result.value() == v) return yield_op->getOperand(result.index());
+        }
+        return v;
+      }));
+
+  // Insert the rewritten assuming op right before the old one.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(assuming_op);
+  auto new_assuming_op = rewriter.create<shape::AssumingOp>(
+      assuming_op.getLoc(), assuming_op.witness(),
+      [&](OpBuilder &b, Location) {
+        // Copy body.
+        BlockAndValueMapping mapping;
+        for (auto &nested : body->without_terminator())
+          b.clone(nested, mapping);
+
+        // Copy op into the new body and use the mapped operands.
+        for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
+          Value old_operand, new_operand_unmapped;
+          std::tie(old_operand, new_operand_unmapped) = it;
+          mapping.map(old_operand,
+                      mapping.lookupOrDefault(new_operand_unmapped));
+        }
+        Operation *new_op = b.clone(*op, mapping);
+
+        // Yield the previous results and also the new ones.
+        auto mapped_results = llvm::to_vector<8>(llvm::map_range(
+            yield_op.operands(),
+            [&](Value v) { return mapping.lookupOrDefault(v); }));
+        mapped_results.append(new_op->getResults().begin(),
+                              new_op->getResults().end());
+        return mapped_results;
+      });
+
+  // Replace the assuming op and the root op with the corresponding result
+  // values.
+  ValueRange new_assuming_op_results = new_assuming_op->getResults();
+  rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
+  rewriter.replaceOp(op, new_assuming_op_results.back());
+  return success();
+}
+
 /// Move operation into a preceding assuming op. This allows to process
 /// operations that depend on the assuming op's results. It will eventually
 /// allow to make assuming regions' constraints independent from each other.
@@ -108,59 +164,24 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Only move into immediately preceding `assuming` op.
-    auto assuming_op =
-        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
-    if (!assuming_op) return failure();
-
-    Block *body = assuming_op.getBody();
-    auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
-
-    // Find the operands to use if the op was within the assuming region. We
-    // will later use their copies, as we copy the assuming op and its body.
-    SmallVector<Value, 8> new_operands_unmapped;
-    for (auto operand : op->getOperands()) {
-      new_operands_unmapped.push_back(operand);
-      for (auto result : llvm::enumerate(assuming_op->getResults())) {
-        if (result.value() == operand)
-          new_operands_unmapped.back() = yield_op->getOperand(result.index());
-      }
-    }
-
-    // Insert the rewritten assuming op right before the old one.
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(assuming_op);
-    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
-        assuming_op.getLoc(), assuming_op.witness(),
-        [&](OpBuilder &b, Location loc) {
-          // Copy body.
-          BlockAndValueMapping mapping;
-          for (auto &nested : body->without_terminator())
-            b.clone(nested, mapping);
-
-          // Copy op into the new body and use the mapped operands.
-          SmallVector<Value, 2> new_operands;
-          for (Value v_unmapped : new_operands_unmapped) {
-            Value v = mapping.lookupOrDefault(v_unmapped);
-            new_operands.push_back(v);
-          }
-          Value new_op = b.create<OpTy>(loc, op->getResultTypes(),
-                                        new_operands, op->getAttrs());
-
-          // Yield the previous results and also the new one.
-          SmallVector<Value, 2> mapped_results;
-          for (auto result : yield_op.operands())
-            mapped_results.push_back(mapping.lookupOrDefault(result));
-          mapped_results.push_back(new_op);
-          return mapped_results;
-        });
-
-    // Replace the assuming op and the root op with the corresponding result
-    // value.
-    ValueRange new_assuming_op_results = new_assuming_op->getResults();
-    rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
-    rewriter.replaceOp(op, new_assuming_op_results.back());
-    return success();
+    return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
+  }
+};
+
+// Move elementwise operations into assuming regions. This will eventually
+// allow for more fusion opportunities.
+struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
+  explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Apply to all elementwise and broadcasting elementwise operations.
+    if (!op->hasTrait<OpTrait::Elementwise>() &&
+        !op->hasTrait<chlo::OpTrait::BroadcastingElementwise>())
+      return failure();
+
+    return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
   }
 };
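
The trait-based match means no fixed op list is needed: any op carrying the
Elementwise or the CHLO BroadcastingElementwise trait qualifies. As a hedged
sketch (op and value names illustrative), both of the following would be
moved into a preceding assuming region:

    // Carries OpTrait::Elementwise.
    %t = "mhlo.tanh"(%x) : (tensor<?xf32>) -> tensor<?xf32>
    // Carries chlo::OpTrait::BroadcastingElementwise.
    %a = chlo.broadcast_add %t, %y
        : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>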
@@ -307,8 +328,7 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
-// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
-struct MoveUpBroadcastInDimOpPattern
+struct EarlyBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
   using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
@@ -380,11 +400,12 @@ void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
   patterns->insert<
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
-      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
+      MoveElementwiseOpsIntoAssumingOpPattern,
       MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
+      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
       MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
       MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
-      MoveUpBroadcastInDimOpPattern,
+      EarlyBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
   mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
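
Together with MergeAssumingOpsPattern and the move-out patterns for
constraints, the moved element-wise ops accumulate in a single assuming
region. A minimal sketch of the intended end state, under the assumption
that all constraints were merged into one witness %w (values illustrative):

    %0 = shape.assuming %w -> (tensor<?xf32>) {
      // Both element-wise ops now live in one region and can be fused.
      %1 = chlo.broadcast_add %a, %b
          : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
      %2 = "mhlo.tanh"(%1) : (tensor<?xf32>) -> tensor<?xf32>
      shape.assuming_yield %2 : tensor<?xf32>
    }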


@@ -193,6 +193,30 @@ func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
 
 // -----
 
+// CHECK-LABEL: @move_elementwise_into_assuming
+// CHECK-SAME:  (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?xf32>)
+func @move_elementwise_into_assuming(%arg0 : !shape.witness,
+    %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK:     %[[RES:.*]] = shape.assuming %[[ARG0]]
+  // CHECK:     %[[SOME:.*]] = "some.op"
+  // CHECK:     %[[TANH:.*]] = "mhlo.tanh"(%[[ARG1]])
+  // CHECK:     %[[BCAST_ADD:.*]] = chlo.broadcast_add %[[TANH]], %[[SOME]]
+  // CHECK:     shape.assuming_yield %[[BCAST_ADD]]
+  // CHECK-NOT: tanh
+  // CHECK-NOT: broadcast_add
+  // CHECK:     return %[[RES]]
+  %0:2 = shape.assuming %arg0 -> (tensor<?xf32>, tensor<?xf32>) {
+    %1 = "some.op"() : () -> tensor<?xf32>
+    shape.assuming_yield %arg1, %1 : tensor<?xf32>, tensor<?xf32>
+  }
+  %1 = "mhlo.tanh"(%arg1) : (tensor<?xf32>) -> tensor<?xf32>
+  %2 = chlo.broadcast_add %1, %0#1
+      : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %2 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @move_shape_of_out_of_assuming
 // CHECK-SAME:  (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
 func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,