From 31536431e0df504adf2b9e01966861a5bcca8107 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 31 May 2021 13:49:21 -0700
Subject: [PATCH] [MLIR][HLO] Eliminate duplicate broadcastable constraints

PiperOrigin-RevId: 376718433
---
 .../move_up_dynamic_broadcasts_for_fusion.cc  | 65 ++++++++++++++-----
 ...move_up_dynamic_broadcasts_for_fusion.mlir | 18 -----
 2 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 29fffa0..51576a7 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -306,25 +306,58 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
-struct EliminateDuplicateCstrBroadcastableOps
-    : public OpRewritePattern<shape::CstrBroadcastableOp> {
-  using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
+// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
+// type where possible.
+struct CanonicalizeCastedShapeOfOpPattern
+    : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+  LogicalResult matchAndRewrite(tensor::CastOp op,
                                 PatternRewriter &rewriter) const override {
-    // Search for previous occurence of the same constraint.
-    Operation *it = op->getPrevNode();
-    while (it != nullptr) {
-      if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
-        if (candidate.shapes() == op.shapes()) {
-          rewriter.replaceOp(op, candidate.result());
-          return success();
-        }
+    // Only merge tensor cast into `shape_of` ops.
+    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
+    if (!shape_of_op) return failure();
+
+    // Desired type must be an extent tensor type.
+    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
+    if (!result_ty || result_ty.getRank() != 1 ||
+        !result_ty.getElementType().isIndex())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
+                                                  shape_of_op.arg());
+    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
+    return success();
+  }
+};
+
+// TODO(frgossen): Remove this once it has landed upstream.
+struct CanonicalizeBroadcastPattern
+    : public OpRewritePattern<shape::BroadcastOp> {
+  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only concretize dynamic extent tensor result types.
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
+
+    // Infer resulting shape rank if possible.
+    int64_t maxRank = 0;
+    for (Value shape : op.shapes()) {
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        // Cannot infer resulting shape rank if any operand is dynamically
+        // ranked.
+        if (extentTensorTy.isDynamicDim(0)) return failure();
+        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
       }
-      it = it->getPrevNode();
     }
-    return failure();
+
+    auto newOp = rewriter.create<shape::BroadcastOp>(
+        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
+        op.shapes());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+    return success();
   }
 };
 
@@ -399,7 +432,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
-      EliminateDuplicateCstrBroadcastableOps,
+      CanonicalizeBroadcastPattern,
+      CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
       MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
@@ -409,7 +443,6 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
-  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
   tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
 
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index 6527ba9..eb56b65 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -374,21 +374,3 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x?x32xf16>,
   }
   return %7 : tensor<?x?x32xf16>
 }
-
-// -----
-
-// CHECK-LABEL: @redundant_cstr_broadcastable
-// CHECK-SAME:  (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
-func @redundant_cstr_broadcastable(%arg0: tensor<?xindex>,
-    %arg1 : tensor<?xindex>) {
-  // CHECK-DAG: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
-  // CHECK:     shape.assuming %[[WITNESS]]
-  %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
-  %1 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
-  %2 = shape.assuming_all %0, %1
-  shape.assuming %2 -> () {
-    "some.op"() : () -> ()
-    shape.assuming_yield
-  }
-  return
-}
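
Illustration (editor's sketch, not part of the patch; value names and tensor
shapes are hypothetical): the two new patterns drive shape computations toward
concrete extent tensor types wherever the ranks are statically known.

  // CanonicalizeCastedShapeOfOpPattern folds a tensor.cast of a shape_of
  // result by producing the concrete extent tensor type directly:
  %0 = shape.shape_of %arg : tensor<?x32xf16> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<2xindex>
  // becomes:
  %1 = shape.shape_of %arg : tensor<?x32xf16> -> tensor<2xindex>

  // CanonicalizeBroadcastPattern infers the result rank as the maximum rank
  // over the statically ranked operands, then casts back to the original
  // type so existing users are unaffected:
  %2 = shape.broadcast %a, %b : tensor<2xindex>, tensor<3xindex>
      -> tensor<?xindex>
  // becomes:
  %3 = shape.broadcast %a, %b : tensor<2xindex>, tensor<3xindex>
      -> tensor<3xindex>
  %2 = tensor.cast %3 : tensor<3xindex> to tensor<?xindex>

The residual tensor.cast is subsequently cleaned up by the tensor::CastOp
canonicalizations that the pass still registers, or merged into a shape_of by
CanonicalizeCastedShapeOfOpPattern.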