[MLIR][MHLO] Merge assuming ops with compatible witnesses
PiperOrigin-RevId: 366018349
This commit is contained in:
parent
c8157ba4df
commit
bbe0aa204c
|
@ -98,7 +98,7 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Move operation into a preceeding assuming op. This allows to process
|
/// Move operation into a preceding assuming op. This allows to process
|
||||||
/// operations that depend on the assuming op's results. It will eventually
|
/// operations that depend on the assuming op's results. It will eventually
|
||||||
/// allow to make assuming regions' constraints independent from each other.
|
/// allow to make assuming regions' constraints independent from each other.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
|
@ -107,7 +107,7 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(OpTy op,
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// Only move into immediately preceeding `assuming` op.
|
// Only move into immediately preceding `assuming` op.
|
||||||
auto assuming_op =
|
auto assuming_op =
|
||||||
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||||
if (!assuming_op) return failure();
|
if (!assuming_op) return failure();
|
||||||
|
@ -239,6 +239,81 @@ struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Merge assuming regions if their constraints are independent from each other.
|
||||||
|
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
|
||||||
|
using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(shape::AssumingOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Merge assuming op with directly preceding one.
|
||||||
|
auto preceding_op =
|
||||||
|
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
|
||||||
|
if (!preceding_op) return failure();
|
||||||
|
|
||||||
|
// For now, both witnesses must be cstr_broadcastable.
|
||||||
|
// TODO(frgossen): Generalize this.
|
||||||
|
auto bcastable_a = op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
|
||||||
|
auto bcastable_b =
|
||||||
|
preceding_op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
|
||||||
|
if (!bcastable_a || !bcastable_b) return failure();
|
||||||
|
|
||||||
|
// Merge witnesses.
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPoint(preceding_op);
|
||||||
|
SmallVector<Value, 8> new_operands;
|
||||||
|
new_operands.append(bcastable_a->getOperands().begin(),
|
||||||
|
bcastable_a->getOperands().end());
|
||||||
|
new_operands.append(bcastable_b->getOperands().begin(),
|
||||||
|
bcastable_b->getOperands().end());
|
||||||
|
Value new_witness = rewriter.create<shape::CstrBroadcastableOp>(
|
||||||
|
bcastable_a.getLoc(), new_operands);
|
||||||
|
|
||||||
|
// Merge assuming ops.
|
||||||
|
Block *body_a = preceding_op.getBody();
|
||||||
|
Block *body_b = op.getBody();
|
||||||
|
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
|
||||||
|
preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
|
||||||
|
// Copy preceding op's body.
|
||||||
|
BlockAndValueMapping mapping;
|
||||||
|
for (auto &nested : body_a->without_terminator()) {
|
||||||
|
b.clone(nested, mapping);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map result values of preceding assuming op.
|
||||||
|
auto yield_op_a =
|
||||||
|
llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
|
||||||
|
for (auto pair :
|
||||||
|
llvm::zip(preceding_op->getResults(), yield_op_a.operands())) {
|
||||||
|
mapping.map(std::get<0>(pair), mapping.lookup(std::get<1>(pair)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy op's body.
|
||||||
|
for (auto &nested : body_b->without_terminator()) {
|
||||||
|
b.clone(nested, mapping);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect merged assuming op's results.
|
||||||
|
SmallVector<Value, 4> mapped_results;
|
||||||
|
auto yield_op_b =
|
||||||
|
llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
|
||||||
|
for (Value v : yield_op_a.operands()) {
|
||||||
|
mapped_results.push_back(mapping.lookupOrDefault(v));
|
||||||
|
}
|
||||||
|
for (Value v : yield_op_b.operands()) {
|
||||||
|
mapped_results.push_back(mapping.lookupOrDefault(v));
|
||||||
|
}
|
||||||
|
return mapped_results;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Replace the two assuming ops with the new corresponding results.
|
||||||
|
ValueRange new_results = new_assuming_op->getResults();
|
||||||
|
size_t split_at = preceding_op->getNumResults();
|
||||||
|
rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
|
||||||
|
rewriter.replaceOp(op, new_results.drop_front(split_at));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||||
struct MoveUpBroadcastInDimOpPattern
|
struct MoveUpBroadcastInDimOpPattern
|
||||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||||
|
@ -311,6 +386,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||||
|
MergeAssumingOpsPattern,
|
||||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||||
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||||
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
|
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||||
|
|
|
@ -245,3 +245,43 @@ func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
|
||||||
}
|
}
|
||||||
return %0 : tensor<2xindex>
|
return %0 : tensor<2xindex>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test: two directly adjacent `shape.assuming` ops, each guarded by a
// `shape.cstr_broadcastable` witness, are expected to be merged into a single
// `shape.assuming` op. The merged op is guarded by a new broadcastability
// witness, contains the operations of both bodies in their original order, and
// yields both results; the function then returns the second merged result.
// The two original witnesses are still expected in the output.
// CHECK: @merge_assuming_ops
|
||||||
|
// CHECK: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
|
||||||
|
func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
|
||||||
|
%arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
|
||||||
|
// CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
|
||||||
|
// CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
|
||||||
|
// CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
|
||||||
|
// CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
|
||||||
|
// CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
|
||||||
|
// CHECK: %[[WITNESS_MERGED:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
|
||||||
|
// CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[WITNESS_MERGED]]
|
||||||
|
// CHECK-SAME: {
|
||||||
|
// CHECK: "some.op"
|
||||||
|
// CHECK: %[[RESULT0:.*]] = "some.producer"
|
||||||
|
// CHECK: "another.op"
|
||||||
|
// CHECK: %[[RESULT1:.*]] = "another.producer"
|
||||||
|
// CHECK: shape.assuming_yield %[[RESULT0]], %[[RESULT1]]
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: return %[[MERGED]]#1
|
||||||
|
%0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
|
||||||
|
%1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
|
||||||
|
%2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
|
||||||
|
%3 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
|
||||||
|
%4 = shape.cstr_broadcastable %0, %1, %2 : tensor<2xindex>, tensor<2xindex>,
|
||||||
|
tensor<3xindex>
|
||||||
|
%5 = shape.assuming %3 -> (tensor<?x32xf16>) {
|
||||||
|
"some.op"() : () -> ()
|
||||||
|
|
||||||
|
%6 = "some.producer"() : () -> tensor<?x32xf16>
|
||||||
|
shape.assuming_yield %6 : tensor<?x32xf16>
|
||||||
|
}
|
||||||
|
%7 = shape.assuming %4 -> (tensor<?x?x32xf16>) {
|
||||||
|
"another.op"() : () -> ()
|
||||||
|
%8 = "another.producer"() : () -> tensor<?x?x32xf16>
|
||||||
|
shape.assuming_yield %8 : tensor<?x?x32xf16>
|
||||||
|
}
|
||||||
|
return %7 : tensor<?x?x32xf16>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue