diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 0ac6fcd..51576a7 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -306,6 +306,61 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
+// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
+// type where possible.
+struct CanonicalizeCastedShapeOfOpPattern
+    : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only merge tensor cast into `shape_of` ops.
+    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
+    if (!shape_of_op) return failure();
+
+    // Desired type must be an extent tensor type.
+    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
+    if (!result_ty || result_ty.getRank() != 1 ||
+        !result_ty.getElementType().isIndex())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
+                                                  shape_of_op.arg());
+    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
+    return success();
+  }
+};
+
+// TODO(frgossen): Remove this once it has landed upstream.
+struct CanonicalizeBroadcastPattern
+    : public OpRewritePattern<shape::BroadcastOp> {
+  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only concretize dynamic extent tensor result types.
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
+
+    // Infer resulting shape rank if possible.
+    int64_t maxRank = 0;
+    for (Value shape : op.shapes()) {
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        // Cannot infer resulting shape rank if any operand is dynamically
+        // ranked.
+        if (extentTensorTy.isDynamicDim(0)) return failure();
+        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
+      }
+    }
+
+    auto newOp = rewriter.create<shape::BroadcastOp>(
+        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
+        op.shapes());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -377,6 +432,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
+      CanonicalizeBroadcastPattern,
+      CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
       MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
@@ -386,7 +443,6 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
-  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
   tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
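
Note for reviewers: a minimal sketch of the two new rewrites at the IR level. The value names, element types, and extents below are illustrative assumptions, not taken from this patch or its tests.

```mlir
// CanonicalizeCastedShapeOfOpPattern: a tensor.cast that only concretizes
// the extent tensor type of a shape.shape_of result is folded into the
// shape.shape_of op itself.
//
// Before:
//   %s = shape.shape_of %arg : tensor<?x32xf32> -> tensor<?xindex>
//   %c = tensor.cast %s : tensor<?xindex> to tensor<2xindex>
// After:
//   %c = shape.shape_of %arg : tensor<?x32xf32> -> tensor<2xindex>

// CanonicalizeBroadcastPattern: when every operand has a static extent
// tensor type, the result rank is the maximum of the operand ranks, so the
// broadcast can be given a static result type; a tensor.cast back to the
// original dynamic type keeps existing users intact and is itself subject
// to the tensor.cast canonicalizations still registered in
// PopulateMoveUpDynamicBroadcastsForFusionPatterns.
//
// Before:
//   %b = shape.broadcast %s0, %s1
//       : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
// After:
//   %n = shape.broadcast %s0, %s1
//       : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
//   %b = tensor.cast %n : tensor<3xindex> to tensor<?xindex>
```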