From d8c40b691cca952ac1ae6cc8d322143f6fc9d464 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 5 May 2021 07:15:58 -0700
Subject: [PATCH] [MLIR][HLO] Add `shape.broadcast` canonicalization to unblock broadcast moving

PiperOrigin-RevId: 372120309
---
 .../move_up_dynamic_broadcasts_for_fusion.cc  | 32 +++++++++++++++++++
 ...move_up_dynamic_broadcasts_for_fusion.mlir | 16 +++++-----
 2 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 689c7bb..82bb69c 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -330,6 +330,36 @@ struct CanonicalizeCastedShapeOfOpPattern
   }
 };
 
+// TODO(frgossen): Remove this once it has landed upstream.
+struct CanonicalizeBroadcastPattern
+    : public OpRewritePattern<shape::BroadcastOp> {
+  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only concretize dynamic extent tensor result types.
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
+
+    // Infer resulting shape rank if possible.
+    int64_t maxRank = 0;
+    for (Value shape : op.shapes()) {
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        // Cannot infer resulting shape rank if any operand is dynamically
+        // ranked.
+        if (extentTensorTy.isDynamicDim(0)) return failure();
+        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
+      }
+    }
+
+    auto newOp = rewriter.create<shape::BroadcastOp>(
+        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
+        op.shapes());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -401,6 +431,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
+      CanonicalizeBroadcastPattern,
       CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
@@ -411,6 +442,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
+  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
 
 std::unique_ptr<OperationPass<FuncOp>> createMoveUpDynamicBroadcastsForFusionPass() {
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index df20e06..eb56b65 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -337,7 +337,6 @@ func @sub_sub(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   // CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
   // CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
   // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
-  // CHECK-SAME: {
   // CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
   // CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
   // CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
@@ -346,7 +345,6 @@ func @sub_sub(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   // CHECK: %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
   // CHECK: %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
   // CHECK: shape.assuming_yield %[[RESULT]]
-  // CHECK: }
   // CHECK: return %[[ASSUMING_RESULT]]
   %0 = shape.shape_of %arg0 : tensor<?x?xf32> -> tensor<2xindex>
   %1 = shape.shape_of %arg1 : tensor<?x?xf32> -> tensor<2xindex>
@@ -354,9 +352,10 @@ func @sub_sub(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   %3 = shape.assuming %2 -> (tensor<?x?xf32>) {
     %8 = shape.shape_of %arg0 : tensor<?x?xf32> -> tensor<2xindex>
     %9 = shape.shape_of %arg1 : tensor<?x?xf32> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
     %14 = mhlo.subtract %12, %13 : tensor<?x?xf32>
     shape.assuming_yield %14 : tensor<?x?xf32>
   }
@@ -366,9 +365,10 @@ func @sub_sub(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   %7 = shape.assuming %6 -> (tensor<?x?x?xf32>) {
     %8 = shape.shape_of %arg2 : tensor<?x?x?xf32> -> tensor<3xindex>
     %9 = shape.shape_of %3 : tensor<?x?xf32> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
     %14 = mhlo.subtract %12, %13 : tensor<?x?x?xf32>
     shape.assuming_yield %14 : tensor<?x?x?xf32>
   }
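
The pattern added above only fires when the `shape.broadcast` result is a dynamically sized extent tensor and every operand has a statically sized extent-tensor type; it then recreates the op with the maximal operand rank as the static result size and casts back to the original type, which is what the updated test expects. A minimal before/after sketch of that rewrite (SSA names and operand types are illustrative, not taken from the patch):

  // Before: the result extent tensor is dynamically sized.
  %s = shape.broadcast %a, %b : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>

  // After: the inferred maximal rank (3) is made static, and a tensor.cast
  // restores the original type for existing users.
  %c = shape.broadcast %a, %b : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
  %s = tensor.cast %c : tensor<3xindex> to tensor<?xindex>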