[MLIR][HLO] Use canonicalization patterns in broadcast propagation pass

Replace local canonicalization patterns with those from upstream.

PiperOrigin-RevId: 376707588
This commit is contained in:
A. Unique TensorFlower 2021-05-31 11:43:37 -07:00 committed by TensorFlow MLIR Team
parent 1f786eb934
commit 5f5db13715
1 changed files with 1 additions and 57 deletions

View File

@ -306,61 +306,6 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
} }
}; };
// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
// type where possible.
struct CanonicalizeCastedShapeOfOpPattern
: public OpRewritePattern<tensor::CastOp> {
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::CastOp op,
PatternRewriter &rewriter) const override {
// Only merge tensor cast into `shape_of` ops.
auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
if (!shape_of_op) return failure();
// Desired type must be an extent tensor type.
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
if (!result_ty || result_ty.getRank() != 1 ||
!result_ty.getElementType().isIndex())
return failure();
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
shape_of_op.arg());
if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
return success();
}
};
// TODO(frgossen): Remove this once it has landed upstream.
struct CanonicalizeBroadcastPattern
: public OpRewritePattern<shape::BroadcastOp> {
using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::BroadcastOp op,
PatternRewriter &rewriter) const override {
// Only concretize dynamic extent tensor result types.
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
// Infer resulting shape rank if possible.
int64_t maxRank = 0;
for (Value shape : op.shapes()) {
if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
// Cannot infer resulting shape rank if any operand is dynamically
// ranked.
if (extentTensorTy.isDynamicDim(0)) return failure();
maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
}
}
auto newOp = rewriter.create<shape::BroadcastOp>(
op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
op.shapes());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
};
// TODO(frgossen): Only move up broadcasting operations if there is a consumer. // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
struct MoveUpBroadcastInDimOpPattern struct MoveUpBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> { : public OpRewritePattern<DynamicBroadcastInDimOp> {
@ -432,8 +377,6 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) { MLIRContext *context, OwningRewritePatternList *patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
CanonicalizeBroadcastPattern,
CanonicalizeCastedShapeOfOpPattern,
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>, InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MergeAssumingOpsPattern, MergeAssumingOpsPattern,
MoveIntoAssumingOpPattern<shape::ShapeOfOp>, MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
@ -443,6 +386,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MoveUpBroadcastInDimOpPattern, MoveUpBroadcastInDimOpPattern,
ShapeReificationPattern>(context); ShapeReificationPattern>(context);
// clang-format on // clang-format on
shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
tensor::CastOp::getCanonicalizationPatterns(*patterns, context); tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
} }