[MLIR][MHLO] Move `cstr_broadcastable` and `shape_of` out of `assuming` regions

Add a pattern to move operations out of assuming ops. This is only valid for
constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It will
eventually allow making assuming regions' constraints independent from each
other so that they can be merged.

PiperOrigin-RevId: 365993145
This commit is contained in:
A. Unique TensorFlower 2021-03-31 02:38:29 -07:00 committed by TensorFlow MLIR Team
parent af2aaa6144
commit 8ade5d78c8
2 changed files with 106 additions and 12 deletions

View File

@ -162,6 +162,36 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
}
};
/// Move an operation out of its enclosing assuming op. This is only valid for
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`,
/// which do not depend on the witness being satisfied. It will eventually
/// allow making assuming regions' constraints independent from each other so
/// that the regions can be merged.
template <typename OpTy>
struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Must be inside of an assuming op; otherwise there is nothing to hoist
// out of.
auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
if (!assuming_op) return failure();
// Operands must not be defined within the assuming op, or the hoisted op
// would use values that are out of scope. Block arguments and values from
// enclosing scopes (null defining op, or a defining op in a different
// block) are fine. NOTE(review): this compares against the single body
// block only — presumably the assuming region is single-block; confirm.
Block *body = assuming_op.getBody();
auto is_available = [&](Value v) {
Operation *def = v.getDefiningOp();
return def == nullptr || def->getBlock() != body;
};
if (!llvm::all_of(op->getOperands(), is_available)) return failure();
// Recreate the op immediately before the assuming op and replace all uses
// of the old result; uses inside the region remain dominated by the new,
// hoisted op.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(assuming_op);
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
op->getOperands(), op->getAttrs());
return success();
}
};
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
struct MoveUpBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> {
@ -236,6 +266,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
MoveUpBroadcastInDimOpPattern,
ShapeReificationPattern>(context);
// clang-format on

View File

@ -114,36 +114,40 @@ func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
// -----
// CHECK-LABEL: @move_shape_of_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>, %[[ARG2:.*]]: tensor<?x32xf32>)
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>)
func @move_shape_of_into_assuming(%arg0 : !shape.witness,
%arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
%arg1 : tensor<?x32xf32>) -> tensor<3xindex> {
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG2]]
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[SHAPE]]
// CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<?x32xf32>
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[DUMMY_TENSOR]]
// CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[SHAPE]]
// CHECK: }
// CHECK-NOT: shape_of
// CHECK: return %[[ASSUMING_RESULTS]]#2
%0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
%1 = "dummy.tensor"() : () -> tensor<?x32xf32>
shape.assuming_yield %arg1, %1 : tensor<?x32xf32>, tensor<?x32xf32>
}
%1 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
return %1 : tensor<3xindex>
%2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
return %2 : tensor<3xindex>
}
// -----
// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>)
func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
%arg1 : tensor<2xindex>) -> !shape.witness {
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[WITNESS]]
// CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<3xindex>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[DUMMY_TENSOR]]
// CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[WITNESS]]
// CHECK: }
// CHECK-NOT: cstr_broadcastable
// CHECK: return %[[ASSUMING_RESULTS]]#2
%0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
shape.assuming_yield %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
%1 = "dummy.tensor"() : () -> tensor<3xindex>
shape.assuming_yield %arg1, %1 : tensor<2xindex>, tensor<3xindex>
}
%1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
return %1 : !shape.witness
@ -167,3 +171,61 @@ func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
%2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
return %2 : tensor<3xindex>
}
// -----
// Positive test for MoveOutOfAssumingOpPattern: the cstr_broadcastable op
// uses only values defined outside the assuming region (%arg1, %arg2), so it
// is hoisted above the assuming op and its result is yielded instead.
// CHECK-LABEL: @move_cstr_broadcastable_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (!shape.witness) {
// CHECK-NOT: cstr_broadcastable
// CHECK: shape.assuming_yield %[[WITNESS]]
// CHECK: }
// CHECK: return %[[ASSUMING_RESULT]]
%0 = shape.assuming %arg0 -> (!shape.witness) {
%1 = shape.cstr_broadcastable %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
shape.assuming_yield %1 : !shape.witness
}
return %0 : !shape.witness
}
// -----
// Positive test for MoveOutOfAssumingOpPattern: shape_of operates on a value
// from outside the assuming region (%arg1), so it is hoisted above the
// assuming op and its result is yielded instead.
// CHECK-LABEL: @move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (tensor<2xindex>) {
// CHECK-NOT: shape_of
// CHECK: shape.assuming_yield %[[SHAPE]]
// CHECK: }
// CHECK: return %[[ASSUMING_RESULT]]
%0 = shape.assuming %arg0 -> (tensor<2xindex>) {
%1 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
shape.assuming_yield %1 : tensor<2xindex>
}
return %0 : tensor<2xindex>
}
// -----
// Negative test for MoveOutOfAssumingOpPattern: the shape_of operand is
// produced by "some.tensor" inside the assuming region, so hoisting is
// invalid and the op must stay inside the region.
// CHECK-LABEL: @not_move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
%arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
// CHECK-NOT: shape_of
// CHECK: shape.assuming
// CHECK-SAME: {
// CHECK: "some.tensor"
// CHECK: shape_of
// CHECK: }
%0 = shape.assuming %arg0 -> (tensor<2xindex>) {
%1 = "some.tensor"() : () -> tensor<2x?xf32>
%2 = shape.shape_of %1 : tensor<2x?xf32> -> tensor<2xindex>
shape.assuming_yield %2 : tensor<2xindex>
}
return %0 : tensor<2xindex>
}