[MLIR][MHLO] Do not yield results of ops that were moved out of assuming regions

When an op is moved out of an assuming region we already know statically that it
is independent of the assuming region. Hence, there is no need to yield its
results.

PiperOrigin-RevId: 366001405
This commit is contained in:
A. Unique TensorFlower 2021-03-31 03:49:29 -07:00 committed by TensorFlow MLIR Team
parent 8ade5d78c8
commit 9206805c58
2 changed files with 76 additions and 13 deletions

View File

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
@@ -184,10 +185,56 @@ struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
};
if (!llvm::all_of(op->getOperands(), is_available)) return failure();
// Move op before the assuming region.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(assuming_op);
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
op->getOperands(), op->getAttrs());
Operation *new_op = rewriter.clone(*op);
rewriter.replaceOp(op, new_op->getResults());
// If the assuming region yields none of the new op's results, these values
// are exclusively used in the assuming op's body. In these cases there is
// no need for further rewrites.
auto is_new_op_result = [&](Value v) {
return llvm::is_contained(new_op->getResults(), v);
};
auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
if (llvm::none_of(yield_op.operands(), is_new_op_result)) return success();
// If the assuming region yields any of the new op's results, these values
// can instead bypass the assuming region. There is no need to yield them
// explicitly as they are assumed to be independent. The assuming op is
// rewritten accordingly.
SmallVector<Value, 2> replacement_values;
auto new_assuming_op = rewriter.create<shape::AssumingOp>(
assuming_op.getLoc(), assuming_op.witness(),
[&](OpBuilder &b, Location) {
// Copy body.
BlockAndValueMapping mapping;
for (Operation &nested : body->without_terminator()) {
b.clone(nested, mapping);
}
// Collect new yield operands.
SmallVector<Value, 2> new_yield_operands;
for (Value result : yield_op.operands()) {
if (is_new_op_result(result)) {
replacement_values.push_back(result);
} else {
new_yield_operands.push_back(mapping.lookup(result));
replacement_values.push_back(nullptr);
}
}
return new_yield_operands;
});
// Use the assuming op's results for the missing replacement values.
auto src = new_assuming_op.getResults().begin();
for (auto &dst : replacement_values) {
if (dst) continue;
dst = *src++;
}
rewriter.replaceOp(assuming_op, replacement_values);
return success();
}
};

View File

@@ -179,11 +179,9 @@ func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (!shape.witness) {
// CHECK-NOT: cstr_broadcastable
// CHECK: shape.assuming_yield %[[WITNESS]]
// CHECK: }
// CHECK: return %[[ASSUMING_RESULT]]
// CHECK-NOT: assuming
// CHECK-NOT: cstr_broadcastable
// CHECK: return %[[WITNESS]]
%0 = shape.assuming %arg0 -> (!shape.witness) {
%1 = shape.cstr_broadcastable %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
shape.assuming_yield %1 : !shape.witness
@@ -198,12 +196,10 @@ func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (tensor<2xindex>) {
// CHECK-NOT: shape_of
// CHECK: shape.assuming_yield %[[SHAPE]]
// CHECK: }
// CHECK: return %[[ASSUMING_RESULT]]
%0 = shape.assuming %arg0 -> (tensor<2xindex>) {
// CHECK-NOT: assuming
// CHECK-NOT: shape_of
// CHECK: return %[[SHAPE]]
%0 = shape.assuming %arg0 -> (tensor<2xindex>) {
%1 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
shape.assuming_yield %1 : tensor<2xindex>
}
@@ -212,6 +208,26 @@ func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
// -----
// CHECK-LABEL: @move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
// Exercises the multi-result case: the shape_of result is moved out of and
// bypasses the assuming region (no longer yielded), while the unrelated
// "some.op" result is still produced and yielded by the rebuilt region.
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
// CHECK: %{{.*}} = shape.assuming %[[ARG0]] -> (tensor<2x?xf32>) {
// CHECK: %[[SOME_VAL:.*]] = "some.op"() : () -> tensor<2x?xf32>
// CHECK: shape.assuming_yield %[[SOME_VAL]] : tensor<2x?xf32>
// CHECK: }
// CHECK: return %[[SHAPE]]
%0:2 = shape.assuming %arg0 -> (tensor<2x?xf32>, tensor<2xindex>) {
%1 = "some.op"() : () -> (tensor<2x?xf32>)
%2 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
shape.assuming_yield %1, %2 : tensor<2x?xf32>, tensor<2xindex>
}
return %0#1 : tensor<2xindex>
}
// -----
// CHECK-LABEL: @not_move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,