[MLIR][MHLO] Move `cstr_broadcastable` and `shape_of` out of `assuming` regions
Add pattern to move operations out of assuming op. This only valid for constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It will eventually allow to make assuming regions' constraints independent from each other so that they can be merged. PiperOrigin-RevId: 365993145
This commit is contained in:
parent
af2aaa6144
commit
8ade5d78c8
|
@ -162,6 +162,36 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Move operation out of assuming op. This is only valid for
|
||||
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
|
||||
/// will eventually allow to make assuming regions' constraints independent from
|
||||
/// each other.
|
||||
template <typename OpTy>
|
||||
struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Must be inside of an assuming op.
|
||||
auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
|
||||
if (!assuming_op) return failure();
|
||||
|
||||
// Operands must not be defined within the assuming op.
|
||||
Block *body = assuming_op.getBody();
|
||||
auto is_available = [&](Value v) {
|
||||
Operation *def = v.getDefiningOp();
|
||||
return def == nullptr || def->getBlock() != body;
|
||||
};
|
||||
if (!llvm::all_of(op->getOperands(), is_available)) return failure();
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(assuming_op);
|
||||
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
|
||||
op->getOperands(), op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||
struct MoveUpBroadcastInDimOpPattern
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
|
@ -236,6 +266,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
|||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||
MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||
MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
|
||||
MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
|
||||
MoveUpBroadcastInDimOpPattern,
|
||||
ShapeReificationPattern>(context);
|
||||
// clang-format on
|
||||
|
|
|
@ -114,36 +114,40 @@ func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_shape_of_into_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>, %[[ARG2:.*]]: tensor<?x32xf32>)
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>)
|
||||
func @move_shape_of_into_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
|
||||
%arg1 : tensor<?x32xf32>) -> tensor<3xindex> {
|
||||
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG2]]
|
||||
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[SHAPE]]
|
||||
// CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<?x32xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[DUMMY_TENSOR]]
|
||||
// CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[SHAPE]]
|
||||
// CHECK: }
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: return %[[ASSUMING_RESULTS]]#2
|
||||
%0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
|
||||
shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
|
||||
%1 = "dummy.tensor"() : () -> tensor<?x32xf32>
|
||||
shape.assuming_yield %arg1, %1 : tensor<?x32xf32>, tensor<?x32xf32>
|
||||
}
|
||||
%1 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
|
||||
return %1 : tensor<3xindex>
|
||||
%2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
|
||||
return %2 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>)
|
||||
func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
|
||||
%arg1 : tensor<2xindex>) -> !shape.witness {
|
||||
// CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
|
||||
// CHECK: shape.assuming_yield %[[ARG1]], %[[ARG2]], %[[WITNESS]]
|
||||
// CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<3xindex>
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[DUMMY_TENSOR]]
|
||||
// CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[WITNESS]]
|
||||
// CHECK: }
|
||||
// CHECK-NOT: cstr_broadcastable
|
||||
// CHECK: return %[[ASSUMING_RESULTS]]#2
|
||||
%0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
|
||||
shape.assuming_yield %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
|
||||
%1 = "dummy.tensor"() : () -> tensor<3xindex>
|
||||
shape.assuming_yield %arg1, %1 : tensor<2xindex>, tensor<3xindex>
|
||||
}
|
||||
%1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
|
||||
return %1 : !shape.witness
|
||||
|
@ -167,3 +171,61 @@ func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
|
|||
%2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
|
||||
return %2 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_cstr_broadcastable_out_of_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
|
||||
func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
|
||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (!shape.witness) {
|
||||
// CHECK-NOT: cstr_broadcastable
|
||||
// CHECK: shape.assuming_yield %[[WITNESS]]
|
||||
// CHECK: }
|
||||
// CHECK: return %[[ASSUMING_RESULT]]
|
||||
%0 = shape.assuming %arg0 -> (!shape.witness) {
|
||||
%1 = shape.cstr_broadcastable %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
|
||||
shape.assuming_yield %1 : !shape.witness
|
||||
}
|
||||
return %0 : !shape.witness
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @move_shape_of_out_of_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
|
||||
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
|
||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] -> (tensor<2xindex>) {
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: shape.assuming_yield %[[SHAPE]]
|
||||
// CHECK: }
|
||||
// CHECK: return %[[ASSUMING_RESULT]]
|
||||
%0 = shape.assuming %arg0 -> (tensor<2xindex>) {
|
||||
%1 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
|
||||
shape.assuming_yield %1 : tensor<2xindex>
|
||||
}
|
||||
return %0 : tensor<2xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @not_move_shape_of_out_of_assuming
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
|
||||
func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
|
||||
%arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: shape.assuming
|
||||
// CHECK-SAME: {
|
||||
// CHECK: "some.tensor"
|
||||
// CHECK: shape_of
|
||||
// CHECK: }
|
||||
%0 = shape.assuming %arg0 -> (tensor<2xindex>) {
|
||||
%1 = "some.tensor"() : () -> tensor<2x?xf32>
|
||||
%2 = shape.shape_of %1 : tensor<2x?xf32> -> tensor<2xindex>
|
||||
shape.assuming_yield %2 : tensor<2xindex>
|
||||
}
|
||||
return %0 : tensor<2xindex>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue