[MLIR] Fix merge of assuming ops

Assuming ops can only be merged if their witnesses will dominate the merged
assuming op. This is not the case if the second op's witness is a result of the
first.

PiperOrigin-RevId: 369192868
This commit is contained in:
A. Unique TensorFlower 2021-04-19 04:20:21 -07:00 committed by TensorFlow MLIR Team
parent 0bb866a799
commit 9374a1c0c5
2 changed files with 22 additions and 1 deletions

View File

@ -245,10 +245,12 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
LogicalResult matchAndRewrite(shape::AssumingOp op, LogicalResult matchAndRewrite(shape::AssumingOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
// Merge assuming op with directly preceding one. // Merge assuming op with directly preceding one if both witnesses are
// availiable.
auto preceding_op = auto preceding_op =
llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode()); llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
if (!preceding_op) return failure(); if (!preceding_op) return failure();
if (op.witness().getDefiningOp() == preceding_op) return failure();
// Merge witnesses. // Merge witnesses.
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);

View File

@ -292,6 +292,25 @@ func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
// ----- // -----
// Do not merge assuming ops if witness will not dominate use.
// CHECK: @do_not_merge_assuming_ops
func @do_not_merge_assuming_ops() {
// CHECK: shape.assuming
// CHECK: shape.assuming
%0 = "some.witness"() : () -> !shape.witness
%1 = shape.assuming %0 -> (!shape.witness) {
%2 = "some.witness"() : () -> !shape.witness
shape.assuming_yield %2 : !shape.witness
}
shape.assuming %1 {
"some.op"() : () -> ()
shape.assuming_yield
}
return
}
// -----
// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops. // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
// CHECK-LABEL: @sub_sub // CHECK-LABEL: @sub_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)