From f16e5a3a676f11ab757475ee8332dc2b78680b7f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Jun 2021 03:13:39 -0700 Subject: [PATCH] [MLIR][HLO] Use canonicalization patterns in broadcast propagation pass Replace local canonicalization patterns with those from upstream. PiperOrigin-RevId: 376794178 --- .../move_up_dynamic_broadcasts_for_fusion.cc | 61 +------------------ 1 file changed, 3 insertions(+), 58 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc index 51576a7..7f9c129 100644 --- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc +++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc @@ -276,7 +276,8 @@ struct MergeAssumingOpsPattern : public OpRewritePattern { llvm::dyn_cast(body_a->getTerminator()); for (auto pair : llvm::zip(preceding_op->getResults(), yield_op_a.operands())) { - mapping.map(std::get<0>(pair), mapping.lookup(std::get<1>(pair))); + mapping.map(std::get<0>(pair), + mapping.lookupOrDefault(std::get<1>(pair))); } // Copy op's body. @@ -306,61 +307,6 @@ struct MergeAssumingOpsPattern : public OpRewritePattern { } }; -// Eliminate casted extent tensors. Instead, produce the concrete extent tensor -// type where possible. -struct CanonicalizeCastedShapeOfOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::CastOp op, - PatternRewriter &rewriter) const override { - // Only merge tensor cast into `shape_of` ops. - auto shape_of_op = op.source().getDefiningOp(); - if (!shape_of_op) return failure(); - - // Desired type must be an extent tensor type. - auto result_ty = op.getType().dyn_cast(); - if (!result_ty || result_ty.getRank() != 1 || - !result_ty.getElementType().isIndex()) - return failure(); - - rewriter.replaceOpWithNewOp(op, result_ty, - shape_of_op.arg()); - if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op); - return success(); - } -}; - -// TODO(frgossen): Remove this once it has landed upstream. -struct CanonicalizeBroadcastPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(shape::BroadcastOp op, - PatternRewriter &rewriter) const override { - // Only concretize dynamic extent tensor result types. - auto resultTy = op.getType().dyn_cast(); - if (!resultTy || !resultTy.isDynamicDim(0)) return failure(); - - // Infer resulting shape rank if possible. - int64_t maxRank = 0; - for (Value shape : op.shapes()) { - if (auto extentTensorTy = shape.getType().dyn_cast()) { - // Cannot infer resulting shape rank if any operand is dynamically - // ranked. - if (extentTensorTy.isDynamicDim(0)) return failure(); - maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); - } - } - - auto newOp = rewriter.create( - op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()), - op.shapes()); - rewriter.replaceOpWithNewOp(op, op.getType(), newOp); - return success(); - } -}; - // TODO(frgossen): Only move up broadcasting operations if there is a consumer. struct MoveUpBroadcastInDimOpPattern : public OpRewritePattern { @@ -432,8 +378,6 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { // clang-format off patterns->insert< - CanonicalizeBroadcastPattern, - CanonicalizeCastedShapeOfOpPattern, InlineBroadcastedShapeOperandsPattern, MergeAssumingOpsPattern, MoveIntoAssumingOpPattern, @@ -443,6 +387,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns( MoveUpBroadcastInDimOpPattern, ShapeReificationPattern>(context); // clang-format on + shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context); tensor::CastOp::getCanonicalizationPatterns(*patterns, context); }