From 0f341012c6a4a368fe4184061ca1092dbe8e0cb4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 31 May 2021 13:07:16 -0700
Subject: [PATCH] [MLIR][HLO] Eliminate duplicate broadcastable constraints

PiperOrigin-RevId: 376715240
---
 .../move_up_dynamic_broadcasts_for_fusion.cc  | 65 +++++--------------
 ...move_up_dynamic_broadcasts_for_fusion.mlir | 18 +++++
 2 files changed, 34 insertions(+), 49 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 51576a7..29fffa0 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -306,58 +306,25 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
-// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
-// type where possible.
-struct CanonicalizeCastedShapeOfOpPattern
-    : public OpRewritePattern<tensor::CastOp> {
-  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+struct EliminateDuplicateCstrBroadcastableOps
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+  using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tensor::CastOp op,
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                 PatternRewriter &rewriter) const override {
-    // Only merge tensor cast into `shape_of` ops.
-    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
-    if (!shape_of_op) return failure();
-
-    // Desired type must be an extent tensor type.
-    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
-    if (!result_ty || result_ty.getRank() != 1 ||
-        !result_ty.getElementType().isIndex())
-      return failure();
-
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
-                                                  shape_of_op.arg());
-    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
-    return success();
-  }
-};
-
-// TODO(frgossen): Remove this once it has landed upstream.
-struct CanonicalizeBroadcastPattern
-    : public OpRewritePattern<shape::BroadcastOp> {
-  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(shape::BroadcastOp op,
-                                PatternRewriter &rewriter) const override {
-    // Only concretize dynamic extent tensor result types.
-    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
-    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
-
-    // Infer resulting shape rank if possible.
-    int64_t maxRank = 0;
-    for (Value shape : op.shapes()) {
-      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
-        // Cannot infer resulting shape rank if any operand is dynamically
-        // ranked.
-        if (extentTensorTy.isDynamicDim(0)) return failure();
-        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
+    // Search for previous occurrence of the same constraint.
+    Operation *it = op->getPrevNode();
+    while (it != nullptr) {
+      if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
+        if (candidate.shapes() == op.shapes()) {
+          rewriter.replaceOp(op, candidate.result());
+          return success();
+        }
       }
+      it = it->getPrevNode();
     }
-    auto newOp = rewriter.create<shape::BroadcastOp>(
-        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
-        op.shapes());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
-    return success();
+    return failure();
   }
 };
 
@@ -432,8 +399,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
-      CanonicalizeBroadcastPattern,
-      CanonicalizeCastedShapeOfOpPattern,
+      EliminateDuplicateCstrBroadcastableOps,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
       MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
@@ -443,6 +409,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
+  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
   tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
 
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index eb56b65..6527ba9 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -374,3 +374,21 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x?x32xf16>,
   }
   return %7 : tensor<?x?x32xf16>
 }
+
+// -----
+
+// CHECK-LABEL: @redundant_cstr_broadcastable
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
+func @redundant_cstr_broadcastable(%arg0: tensor<?xindex>,
+    %arg1 : tensor<?xindex>) {
+  // CHECK-DAG: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
+  // CHECK:     shape.assuming %[[WITNESS]]
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  %1 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  %2 = shape.assuming_all %0, %1
+  shape.assuming %2 -> () {
+    "some.op"() : () -> ()
+    shape.assuming_yield
+  }
+  return
+}