[MLIR][HLO] Add `shape.broadcast` canonicalization to unblock broadcast moving

PiperOrigin-RevId: 372120309
This commit is contained in:
A. Unique TensorFlower 2021-05-05 07:15:58 -07:00 committed by TensorFlow MLIR Team
parent 6bc854f5d9
commit d8c40b691c
2 changed files with 40 additions and 8 deletions

View File

@ -330,6 +330,36 @@ struct CanonicalizeCastedShapeOfOpPattern
}
};
// TODO(frgossen): Remove this once it has landed upstream.
struct CanonicalizeBroadcastPattern
: public OpRewritePattern<shape::BroadcastOp> {
using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::BroadcastOp op,
PatternRewriter &rewriter) const override {
// Only concretize dynamic extent tensor result types.
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
// Infer resulting shape rank if possible.
int64_t maxRank = 0;
for (Value shape : op.shapes()) {
if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
// Cannot infer resulting shape rank if any operand is dynamically
// ranked.
if (extentTensorTy.isDynamicDim(0)) return failure();
maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
}
}
auto newOp = rewriter.create<shape::BroadcastOp>(
op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
op.shapes());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
};
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
struct MoveUpBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> {
@ -401,6 +431,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
CanonicalizeBroadcastPattern,
CanonicalizeCastedShapeOfOpPattern,
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MergeAssumingOpsPattern,
@ -411,6 +442,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MoveUpBroadcastInDimOpPattern,
ShapeReificationPattern>(context);
// clang-format on
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {

View File

@ -337,7 +337,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
// CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
// CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
// CHECK-SAME: {
// CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
// CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
// CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
@ -346,7 +345,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
// CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
// CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK: }
// CHECK: return %[[ASSUMING_RESULT]]
%0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
%1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
@ -354,9 +352,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
%3 = shape.assuming %2 -> (tensor<?x32xf16>) {
%8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
%9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
%10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
%13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
%10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
%11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
%13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
%14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
shape.assuming_yield %14 : tensor<?x32xf16>
}
@ -366,9 +365,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
%7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
%8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
%9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
%10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
%13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
%10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
%11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
%13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
%14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
shape.assuming_yield %14 : tensor<?x?x32xf16>
}