diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index a7ab806..689c7bb 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -305,6 +305,31 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
+// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
+// type where possible.
+struct CanonicalizeCastedShapeOfOpPattern
+    : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only merge tensor cast into `shape_of` ops.
+    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
+    if (!shape_of_op) return failure();
+
+    // Desired type must be an extent tensor type.
+    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
+    if (!result_ty || result_ty.getRank() != 1 ||
+        !result_ty.getElementType().isIndex())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
+                                                  shape_of_op.arg());
+    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -376,6 +401,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
+      CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
       MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index 9bc8dac..df20e06 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -311,6 +311,20 @@ func @do_not_merge_assuming_ops() {
 
 // -----
 
+// CHECK-LABEL: @eliminate_extent_tensor_cast
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
+func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
+  // CHECK-NOT: shape_of
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
+  // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
+  %0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  "use"(%1) : (tensor<3xindex>) -> ()
+  return
+}
+
+// -----
+
 // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
 // CHECK-LABEL: @sub_sub
 // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor)
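
Note: as a minimal before/after sketch of what the new pattern does (reusing the
toy "use" op from the test above; nothing here beyond the patch itself), the
rewrite folds the extent tensor cast into the producing `shape_of` op:

    // Before: shape_of yields a dynamically shaped extent tensor, which is
    // then cast to the concrete extent tensor type.
    %0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex>
    %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
    "use"(%1) : (tensor<3xindex>) -> ()

    // After: shape_of produces the concrete extent tensor type directly; the
    // cast and the now-unused original shape_of are gone.
    %0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<3xindex>
    "use"(%0) : (tensor<3xindex>) -> ()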