diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 689c7bb..a6b5e90 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -305,17 +305,18 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
-// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
+// Eliminate extent tensor casts. Instead, produce the concrete extent tensor
 // type where possible.
-struct CanonicalizeCastedShapeOfOpPattern
+template <typename OpTy>
+struct CanonicalizeCastedExtentTensorOpPattern
     : public OpRewritePattern<tensor::CastOp> {
   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tensor::CastOp op,
                                 PatternRewriter &rewriter) const override {
-    // Only merge tensor cast into `shape_of` ops.
-    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
-    if (!shape_of_op) return failure();
+    // Only merge tensor cast into a producer op if we know it supports it.
+    auto producer_op = op.source().getDefiningOp<OpTy>();
+    if (!producer_op) return failure();
 
     // Desired type must be an extent tensor type.
     auto result_ty = op.getType().dyn_cast<RankedTensorType>();
@@ -323,9 +324,9 @@ struct CanonicalizeCastedShapeOfOpPattern
         !result_ty.getElementType().isIndex())
       return failure();
 
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
-                                                  shape_of_op.arg());
-    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
+    rewriter.replaceOpWithNewOp<OpTy>(op, result_ty, producer_op->getOperands(),
+                                      producer_op->getAttrs());
+    if (producer_op->getUses().empty()) rewriter.eraseOp(producer_op);
     return success();
   }
 };
@@ -401,7 +402,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
-      CanonicalizeCastedShapeOfOpPattern,
+      CanonicalizeCastedExtentTensorOpPattern<shape::ShapeOfOp>,
+      CanonicalizeCastedExtentTensorOpPattern<shape::BroadcastOp>,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
       MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index df20e06..d76923a 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -311,9 +311,9 @@ func @do_not_merge_assuming_ops() {
 
 // -----
 
-// CHECK: @eliminate_extent_tensor_cast
+// CHECK: @merge_extent_tensor_cast_into_shape_of
 // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
-func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
+func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) {
   // CHECK-NOT: shape_of
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
   // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
@@ -325,6 +325,19 @@ func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
 
 // -----
 
+// CHECK: @merge_extent_tensor_cast_into_broadcast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @merge_extent_tensor_cast_into_broadcast(%arg0 : tensor<3xindex>, %arg1 : tensor<3xindex>) {
+  // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+  // CHECK: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
+  %0 = shape.broadcast %arg0, %arg1 : tensor<3xindex>, tensor<3xindex> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  "use"(%1) : (tensor<3xindex>) -> ()
+  return
+}
+
+// -----
+
 // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
 // CHECK-LABEL: @sub_sub
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)