[MLIR][HLO] Use canonicalization patterns in broadcast propagation pass
Replace local canonicalization patterns with those from upstream.

PiperOrigin-RevId: 376708719
This commit is contained in:
parent 5f5db13715
commit 511a1db4f3
@@ -306,6 +306,61 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
  }
};

// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
// type where possible.
struct CanonicalizeCastedShapeOfOpPattern
    : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // Only merge tensor cast into `shape_of` ops.
    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
    if (!shape_of_op) return failure();

    // Desired type must be an extent tensor type.
    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
    if (!result_ty || result_ty.getRank() != 1 ||
        !result_ty.getElementType().isIndex())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
                                                  shape_of_op.arg());
    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
    return success();
  }
};
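For intuition, a hypothetical before/after of the rewrite above, written in the style of an MLIR doc comment; the assembly syntax is an approximation for the dialect versions at this revision, not taken from this commit:

  // Example (illustrative only):
  //
  //   %shape = shape.shape_of %arg : tensor<?x?xf32> -> tensor<?xindex>
  //   %cast = tensor.cast %shape : tensor<?xindex> to tensor<2xindex>
  //
  // folds into a single shape_of that yields the concrete extent tensor type:
  //
  //   %cast = shape.shape_of %arg : tensor<?x?xf32> -> tensor<2xindex>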
// TODO(frgossen): Remove this once it has landed upstream.
struct CanonicalizeBroadcastPattern
    : public OpRewritePattern<shape::BroadcastOp> {
  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.shapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0)) return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<shape::BroadcastOp>(
        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
        op.shapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
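Likewise, a sketch of the effect of CanonicalizeBroadcastPattern when every operand rank is static; the assembly below is an assumption for illustration only:

  // Example (illustrative only): operand ranks are 2 and 3, so the result
  // rank can be concretized to max(2, 3) = 3.
  //
  //   %res = shape.broadcast %a, %b
  //       : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
  //
  // becomes a broadcast at the inferred rank plus a cast back to the
  // original dynamic type:
  //
  //   %bc = shape.broadcast %a, %b
  //       : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
  //   %res = tensor.cast %bc : tensor<3xindex> to tensor<?xindex>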
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
struct MoveUpBroadcastInDimOpPattern
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
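The hunk ends inside MoveUpBroadcastInDimOpPattern, so its body is not visible here. Going by the name alone, a rough sketch of the intended rewrite direction, with assumed mhlo op names and all types and attributes elided:

  // Illustrative only; the actual pattern body lies outside this hunk.
  //
  //   %sum = "mhlo.add"(%a, %b) : ...
  //   %res = "mhlo.dynamic_broadcast_in_dim"(%sum, %shape) {...}
  //
  // is rewritten so the broadcasts happen before the element-wise op,
  // allowing the broadcasted operations to be fused:
  //
  //   %a_bc = "mhlo.dynamic_broadcast_in_dim"(%a, %shape) {...}
  //   %b_bc = "mhlo.dynamic_broadcast_in_dim"(%b, %shape) {...}
  //   %res = "mhlo.add"(%a_bc, %b_bc) : ...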
@@ -377,6 +432,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<
      CanonicalizeBroadcastPattern,
      CanonicalizeCastedShapeOfOpPattern,
      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
      MergeAssumingOpsPattern,
      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
@@ -386,7 +443,6 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
      MoveUpBroadcastInDimOpPattern,
      ShapeReificationPattern>(context);
  // clang-format on
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}
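For context, a minimal sketch of how a function pass might consume this populate function; the pass wrapper and the greedy driver call are assumptions for illustration, not part of this commit:

  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical pass that greedily applies the populated patterns.
  struct ExampleBroadcastPropagationPass
      : public PassWrapper<ExampleBroadcastPropagationPass, FunctionPass> {
    void runOnFunction() override {
      MLIRContext *context = &getContext();
      OwningRewritePatternList patterns(context);
      PopulateMoveUpDynamicBroadcastsForFusionPatterns(context, &patterns);
      if (failed(applyPatternsAndFoldGreedily(getFunction(),
                                              std::move(patterns))))
        signalPassFailure();
    }
  };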