[MLIR][MHLO] Merge assuming ops with compatible witnesses
PiperOrigin-RevId: 366018349
This commit is contained in:
parent
c8157ba4df
commit
bbe0aa204c
|
@ -98,7 +98,7 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Move operation into a preceeding assuming op. This allows to process
|
||||
/// Move operation into a preceding assuming op. This allows to process
|
||||
/// operations that depend on the assuming op's results. It will eventually
|
||||
/// allow to make assuming regions' constraints independent from each other.
|
||||
template <typename OpTy>
|
||||
|
@ -107,7 +107,7 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
|||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only move into immediately preceeding `assuming` op.
|
||||
// Only move into immediately preceding `assuming` op.
|
||||
auto assuming_op =
|
||||
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||
if (!assuming_op) return failure();
|
||||
|
@ -239,6 +239,81 @@ struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Merge assuming regions if their constraints are independent from each other.
|
||||
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
|
||||
using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(shape::AssumingOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Merge assuming op with directly preceding one.
|
||||
auto preceding_op =
|
||||
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||
if (!preceding_op) return failure();
|
||||
|
||||
// For now, both witnesses must be cstr_broadcastable.
|
||||
// TODO(frgossen): Generalize this.
|
||||
auto bcastable_a = op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
|
||||
auto bcastable_b =
|
||||
preceding_op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
|
||||
if (!bcastable_a || !bcastable_b) return failure();
|
||||
|
||||
// Merge witnesses.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(preceding_op);
|
||||
SmallVector<Value, 8> new_operands;
|
||||
new_operands.append(bcastable_a->getOperands().begin(),
|
||||
bcastable_a->getOperands().end());
|
||||
new_operands.append(bcastable_b->getOperands().begin(),
|
||||
bcastable_b->getOperands().end());
|
||||
Value new_witness = rewriter.create<shape::CstrBroadcastableOp>(
|
||||
bcastable_a.getLoc(), new_operands);
|
||||
|
||||
// Merge assuming ops.
|
||||
Block *body_a = preceding_op.getBody();
|
||||
Block *body_b = op.getBody();
|
||||
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
|
||||
preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
|
||||
// Copy preceding op's body.
|
||||
BlockAndValueMapping mapping;
|
||||
for (auto &nested : body_a->without_terminator()) {
|
||||
b.clone(nested, mapping);
|
||||
}
|
||||
|
||||
// Map result values of preceding assuming op.
|
||||
auto yield_op_a =
|
||||
llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
|
||||
for (auto pair :
|
||||
llvm::zip(preceding_op->getResults(), yield_op_a.operands())) {
|
||||
mapping.map(std::get<0>(pair), mapping.lookup(std::get<1>(pair)));
|
||||
}
|
||||
|
||||
// Copy op's body.
|
||||
for (auto &nested : body_b->without_terminator()) {
|
||||
b.clone(nested, mapping);
|
||||
}
|
||||
|
||||
// Collect merged assuming op's results.
|
||||
SmallVector<Value, 4> mapped_results;
|
||||
auto yield_op_b =
|
||||
llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
|
||||
for (Value v : yield_op_a.operands()) {
|
||||
mapped_results.push_back(mapping.lookupOrDefault(v));
|
||||
}
|
||||
for (Value v : yield_op_b.operands()) {
|
||||
mapped_results.push_back(mapping.lookupOrDefault(v));
|
||||
}
|
||||
return mapped_results;
|
||||
});
|
||||
|
||||
// Replace the two assuming ops with the new corresponding results.
|
||||
ValueRange new_results = new_assuming_op->getResults();
|
||||
size_t split_at = preceding_op->getNumResults();
|
||||
rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
|
||||
rewriter.replaceOp(op, new_results.drop_front(split_at));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||
struct MoveUpBroadcastInDimOpPattern
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
|
@ -311,6 +386,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
|||
// clang-format off
|
||||
patterns->insert<
|
||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||
MergeAssumingOpsPattern,
|
||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||
|
|
|
@ -245,3 +245,43 @@ func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
|
|||
}
|
||||
return %0 : tensor<2xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: @merge_assuming_ops
|
||||
// CHECK: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
|
||||
// Test: two directly adjacent `shape.assuming` ops whose witnesses are both
// produced by `shape.cstr_broadcastable` are expected to be merged into one
// `shape.assuming` op. The merged op is guarded by a single
// `shape.cstr_broadcastable` witness, contains the bodies of both original ops
// in their original order, and yields both results; the function then returns
// the second result of the merged op.
func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
|
||||
%arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
|
||||
// CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
|
||||
// CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
|
||||
// CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
|
||||
// CHECK: %[[WITNESS_MERGED:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
|
||||
// CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[WITNESS_MERGED]]
|
||||
// CHECK-SAME: {
|
||||
// CHECK: "some.op"
|
||||
// CHECK: %[[RESULT0:.*]] = "some.producer"
|
||||
// CHECK: "another.op"
|
||||
// CHECK: %[[RESULT1:.*]] = "another.producer"
|
||||
// CHECK: shape.assuming_yield %[[RESULT0]], %[[RESULT1]]
|
||||
// CHECK: }
|
||||
// CHECK: return %[[MERGED]]#1
|
||||
%0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
|
||||
%1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
|
||||
%2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
|
||||
%3 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
|
||||
%4 = shape.cstr_broadcastable %0, %1, %2 : tensor<2xindex>, tensor<2xindex>,
|
||||
tensor<3xindex>
|
||||
%5 = shape.assuming %3 -> (tensor<?x32xf16>) {
|
||||
"some.op"() : () -> ()
|
||||
%6 = "some.producer"() : () -> tensor<?x32xf16>
|
||||
shape.assuming_yield %6 : tensor<?x32xf16>
|
||||
}
|
||||
%7 = shape.assuming %4 -> (tensor<?x?x32xf16>) {
|
||||
"another.op"() : () -> ()
|
||||
%8 = "another.producer"() : () -> tensor<?x?x32xf16>
|
||||
shape.assuming_yield %8 : tensor<?x?x32xf16>
|
||||
}
|
||||
return %7 : tensor<?x?x32xf16>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue