[MLIR][HLO] Add `shape.broadcast` canonicalization to unblock broadcast moving
PiperOrigin-RevId: 372120309
parent 6bc854f5d9
commit d8c40b691c
@@ -330,6 +330,36 @@ struct CanonicalizeCastedShapeOfOpPattern
   }
 };
 
+// TODO(frgossen): Remove this once it has landed upstream.
+struct CanonicalizeBroadcastPattern
+    : public OpRewritePattern<shape::BroadcastOp> {
+  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only concretize dynamic extent tensor result types.
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
+
+    // Infer resulting shape rank if possible.
+    int64_t maxRank = 0;
+    for (Value shape : op.shapes()) {
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        // Cannot infer resulting shape rank if any operand is dynamically
+        // ranked.
+        if (extentTensorTy.isDynamicDim(0)) return failure();
+        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
+      }
+    }
+
+    auto newOp = rewriter.create<shape::BroadcastOp>(
+        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
+        op.shapes());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
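To make the new pattern concrete, here is a minimal before/after sketch of the rewrite it performs; the operand names %a and %b and their extent-tensor sizes are hypothetical and not taken from this change.

    // Before: the broadcast result is an extent tensor of unknown size.
    %bcast = shape.broadcast %a, %b : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>

    // After: the result size is concretized to the largest operand size (3), and a
    // tensor.cast back to tensor<?xindex> stands in for the original op's result.
    %concrete = shape.broadcast %a, %b : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
    %cast = tensor.cast %concrete : tensor<3xindex> to tensor<?xindex>

The cast keeps the rewrite local: existing users still see the same tensor<?xindex> type, while new consumers can look through the cast to the statically sized shape.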
@@ -401,6 +431,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
+      CanonicalizeBroadcastPattern,
       CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
@@ -411,6 +442,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
+  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
 
 std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
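Registering the tensor.cast canonicalization patterns is what lets the casts introduced above fold away again once producers and consumers agree on a static size. A hypothetical cast chain of the kind that arises after CanonicalizeBroadcastPattern fires (value names are illustrative, not from the diff):

    %s = shape.broadcast %a, %b : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
    %c0 = tensor.cast %s : tensor<2xindex> to tensor<?xindex>
    %c1 = tensor.cast %c0 : tensor<?xindex> to tensor<2xindex>
    // The round-trip cast %c1 can fold back to %s, which is why the CHECK lines
    // in the test below expect shape.broadcast to feed mhlo.dynamic_broadcast_in_dim
    // directly.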
@@ -337,7 +337,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
 // CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
 // CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
-// CHECK-SAME: {
 // CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
 // CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
 // CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
@@ -346,7 +345,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
 // CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
 // CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
 // CHECK: shape.assuming_yield %[[RESULT]]
-// CHECK: }
 // CHECK: return %[[ASSUMING_RESULT]]
   %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
   %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
@@ -354,9 +352,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   %3 = shape.assuming %2 -> (tensor<?x32xf16>) {
     %8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
     %9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
+    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
     %14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
     shape.assuming_yield %14 : tensor<?x32xf16>
   }
@@ -366,9 +365,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   %7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
     %8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
     %9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
     %14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
     shape.assuming_yield %14 : tensor<?x?x32xf16>
   }