[MLIR][MHLO] Merge assuming ops with compatible witnesses

PiperOrigin-RevId: 366018349
This commit is contained in:
A. Unique TensorFlower 2021-03-31 06:10:15 -07:00 committed by TensorFlow MLIR Team
parent c8157ba4df
commit bbe0aa204c
2 changed files with 118 additions and 2 deletions

View File

@ -98,7 +98,7 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
}
};
/// Move operation into a preceeding assuming op. This allows to process
/// Move operation into a preceding assuming op. This allows to process
/// operations that depend on the assuming op's results. It will eventually
/// allow to make assuming regions' constraints independent from each other.
template <typename OpTy>
@ -107,7 +107,7 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Only move into immediately preceeding `assuming` op.
// Only move into immediately preceding `assuming` op.
auto assuming_op =
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
if (!assuming_op) return failure();
@ -239,6 +239,81 @@ struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
}
};
/// Merge assuming regions if their constraints are independent from each other.
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::AssumingOp op,
PatternRewriter &rewriter) const override {
// Merge assuming op with directly preceding one.
auto preceding_op =
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
if (!preceding_op) return failure();
// For now, both witnesses must be cstr_broadcastable.
// TODO(frgossen): Generalize this.
auto bcastable_a = op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
auto bcastable_b =
preceding_op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
if (!bcastable_a || !bcastable_b) return failure();
// Merge witnesses.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(preceding_op);
SmallVector<Value, 8> new_operands;
new_operands.append(bcastable_a->getOperands().begin(),
bcastable_a->getOperands().end());
new_operands.append(bcastable_b->getOperands().begin(),
bcastable_b->getOperands().end());
Value new_witness = rewriter.create<shape::CstrBroadcastableOp>(
bcastable_a.getLoc(), new_operands);
// Merge assuming ops.
Block *body_a = preceding_op.getBody();
Block *body_b = op.getBody();
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
// Copy preceding op's body.
BlockAndValueMapping mapping;
for (auto &nested : body_a->without_terminator()) {
b.clone(nested, mapping);
}
// Map result values of preceding assuming op.
auto yield_op_a =
llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
for (auto pair :
llvm::zip(preceding_op->getResults(), yield_op_a.operands())) {
mapping.map(std::get<0>(pair), mapping.lookup(std::get<1>(pair)));
}
// Copy op's body.
for (auto &nested : body_b->without_terminator()) {
b.clone(nested, mapping);
}
// Collect merged assuming op's results.
SmallVector<Value, 4> mapped_results;
auto yield_op_b =
llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
for (Value v : yield_op_a.operands()) {
mapped_results.push_back(mapping.lookupOrDefault(v));
}
for (Value v : yield_op_b.operands()) {
mapped_results.push_back(mapping.lookupOrDefault(v));
}
return mapped_results;
});
// Replace the two assuming ops with the new corresponding results.
ValueRange new_results = new_assuming_op->getResults();
size_t split_at = preceding_op->getNumResults();
rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
rewriter.replaceOp(op, new_results.drop_front(split_at));
return success();
}
};
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
struct MoveUpBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> {
@ -311,6 +386,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
// clang-format off
patterns->insert<
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MergeAssumingOpsPattern,
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,

View File

@ -245,3 +245,43 @@ func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
}
return %0 : tensor<2xindex>
}
// -----
// CHECK: @merge_assuming_ops
// CHECK: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
%arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
// CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
// CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
// CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
// CHECK: %[[WITNESS_MERGED:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
// CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[WITNESS_MERGED]]
// CHECK-SAME: {
// CHECK: "some.op"
// CHECK: %[[RESULT0:.*]] = "some.producer"
// CHECK: "another.op"
// CHECK: %[[RESULT1:.*]] = "another.producer"
// CHECK: shape.assuming_yield %[[RESULT0]], %[[RESULT1]]
// CHECK: }
// CHECK: return %[[MERGED]]#1
%0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
%1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
%2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
%3 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
%4 = shape.cstr_broadcastable %0, %1, %2 : tensor<2xindex>, tensor<2xindex>,
tensor<3xindex>
%5 = shape.assuming %3 -> (tensor<?x32xf16>) {
"some.op"() : () -> ()
%6 = "some.producer"() : () -> tensor<?x32xf16>
shape.assuming_yield %6 : tensor<?x32xf16>
}
%7 = shape.assuming %4 -> (tensor<?x?x32xf16>) {
"another.op"() : () -> ()
%8 = "another.producer"() : () -> tensor<?x?x32xf16>
shape.assuming_yield %8 : tensor<?x?x32xf16>
}
return %7 : tensor<?x?x32xf16>
}