From 0569b7f7a4e951c80acd59b3d2fe47a88f8f736b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 23 Apr 2021 10:51:38 -0700 Subject: [PATCH] [MLIR][MHLO] Generalize extent tensor cast elimination in bcast moving PiperOrigin-RevId: 370112887 --- .../move_up_dynamic_broadcasts_for_fusion.cc | 20 +++++++++---------- ...move_up_dynamic_broadcasts_for_fusion.mlir | 17 ++-------------- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index a6b5e90..689c7bb 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -305,18 +305,17 @@ struct MergeAssumingOpsPattern : public OpRewritePattern { } }; -// Eliminate extent tensor casts. Instead, produce the concrete extent tensor +// Eliminate casted extent tensors. Instead, produce the concrete extent tensor // type where possible. -template -struct CanonicalizeCastedExtentTensorOpPattern +struct CanonicalizeCastedShapeOfOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::CastOp op, PatternRewriter &rewriter) const override { - // Only merge tensor cast into a producer op if we know it supports it. - auto producer_op = op.source().getDefiningOp(); - if (!producer_op) return failure(); + // Only merge tensor cast into `shape_of` ops. + auto shape_of_op = op.source().getDefiningOp(); + if (!shape_of_op) return failure(); // Desired type must be an extent tensor type. auto result_ty = op.getType().dyn_cast(); @@ -324,9 +323,9 @@ struct CanonicalizeCastedExtentTensorOpPattern !result_ty.getElementType().isIndex()) return failure(); - rewriter.replaceOpWithNewOp(op, result_ty, producer_op->getOperands(), - producer_op->getAttrs()); - if (producer_op->getUses().empty()) rewriter.eraseOp(producer_op); + rewriter.replaceOpWithNewOp(op, result_ty, + shape_of_op.arg()); + if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op); return success(); } }; @@ -402,8 +401,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { // clang-format off patterns->insert< - CanonicalizeCastedExtentTensorOpPattern, - CanonicalizeCastedExtentTensorOpPattern, + CanonicalizeCastedShapeOfOpPattern, InlineBroadcastedShapeOperandsPattern, MergeAssumingOpsPattern, MoveIntoAssumingOpPattern, diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir index d76923a..df20e06 100644 --- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir +++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir @@ -311,9 +311,9 @@ func @do_not_merge_assuming_ops() { // ----- -// CHECK: @merge_extent_tensor_cast_into_shape_of +// CHECK: @eliminate_extent_tensor_cast // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>) -func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) { +func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) { // CHECK-NOT: shape_of // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex> // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> () @@ -325,19 +325,6 @@ func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) { // ----- -// CHECK: @merge_extent_tensor_cast_into_broadcast -// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xindex>, %[[ARG1:.*]]: tensor<3xindex>) -func @merge_extent_tensor_cast_into_broadcast(%arg0 : tensor<3xindex>, %arg1 : tensor<3xindex>) { - // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> - // CHECK: "use"(%[[RESULT]]) : (tensor<3xindex>) -> () - %0 = shape.broadcast %arg0, %arg1 : tensor<3xindex>, tensor<3xindex> -> tensor - %1 = tensor.cast %0 : tensor to tensor<3xindex> - "use"(%1) : (tensor<3xindex>) -> () - return -} - -// ----- - // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops. // CHECK-LABEL: @sub_sub // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor)