From b9e45007d53ca2c204c50c241977a01f5977c2fd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 Jun 2021 03:02:17 -0700 Subject: [PATCH] [MLIR][HLO] Extend broadcast propagation pass to enable more fusion Move element-wise operations into assuming regions. This enables fusion opportunities within the region. PiperOrigin-RevId: 378362725 --- .../mhlo/transforms/broadcast_propagation.cc | 129 ++++++++++-------- tests/broadcast_propagation.mlir | 24 ++++ 2 files changed, 99 insertions(+), 54 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/broadcast_propagation.cc b/lib/Dialect/mhlo/transforms/broadcast_propagation.cc index d8e6c95..a147855 100644 --- a/lib/Dialect/mhlo/transforms/broadcast_propagation.cc +++ b/lib/Dialect/mhlo/transforms/broadcast_propagation.cc @@ -99,6 +99,62 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern { } }; +LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op, + PatternRewriter &rewriter) { + // Only move into immediately preceding `assuming` op. + auto assuming_op = + llvm::dyn_cast_or_null(op->getPrevNode()); + if (!assuming_op) return failure(); + + Block *body = assuming_op.getBody(); + auto yield_op = cast(body->getTerminator()); + + // Find the operands to use if the op was within the assuming region. We + // will later use their copies, as we copy the assuming op and its body. + SmallVector new_operands_unmapped = + llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) { + for (auto result : llvm::enumerate(assuming_op->getResults())) { + if (result.value() == v) return yield_op->getOperand(result.index()); + } + return v; + })); + + // Insert the rewritten assuming op right before the old one. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(assuming_op); + auto new_assuming_op = rewriter.create( + assuming_op.getLoc(), assuming_op.witness(), [&](OpBuilder &b, Location) { + // Copy body. + BlockAndValueMapping mapping; + for (auto &nested : body->without_terminator()) + b.clone(nested, mapping); + + // Copy op into the new body and use the mapped operands. + for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) { + Value old_operand, new_operand_unmapped; + std::tie(old_operand, new_operand_unmapped) = it; + mapping.map(old_operand, + mapping.lookupOrDefault(new_operand_unmapped)); + } + Operation *new_op = b.clone(*op, mapping); + + // Yield the previous results and also the new ones. + auto mapped_results = llvm::to_vector<8>(llvm::map_range( + yield_op.operands(), + [&](Value v) { return mapping.lookupOrDefault(v); })); + mapped_results.append(new_op->getResults().begin(), + new_op->getResults().end()); + return mapped_results; + }); + + // Replace the assuming op and the root op with the corresponding result + // value. + ValueRange new_assuming_op_results = new_assuming_op->getResults(); + rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back()); + rewriter.replaceOp(op, new_assuming_op_results.back()); + return success(); +} + /// Move operation into a preceding assuming op. This allows to process /// operations that depend on the assuming op's results. It will eventually /// allow to make assuming regions' constraints independent from each other. @@ -108,59 +164,24 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - // Only move into immediately preceding `assuming` op. - auto assuming_op = - llvm::dyn_cast_or_null(op->getPrevNode()); - if (!assuming_op) return failure(); + return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter); + } +}; - Block *body = assuming_op.getBody(); - auto yield_op = cast(body->getTerminator()); +// Move elementwise operations into assuming regions. This will eventually allow +// for more fusion opportunities. +struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern { + explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} - // Find the operands to use if the op was within the assuming region. We - // will later use their copies, as we copy the assuming op and its body. - SmallVector new_operands_unmapped; - for (auto operand : op->getOperands()) { - new_operands_unmapped.push_back(operand); - for (auto result : llvm::enumerate(assuming_op->getResults())) { - if (result.value() == operand) - new_operands_unmapped.back() = yield_op->getOperand(result.index()); - } - } + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Apply to all elementwise and broadcasting elementwise operations. + if (!op->hasTrait() && + !op->hasTrait()) + return failure(); - // Insert the rewritten assuming op right before the old one. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(assuming_op); - auto new_assuming_op = rewriter.create( - assuming_op.getLoc(), assuming_op.witness(), - [&](OpBuilder &b, Location loc) { - // Copy body. - BlockAndValueMapping mapping; - for (auto &nested : body->without_terminator()) - b.clone(nested, mapping); - - // Copy op into the new body and use the mapped operands. - SmallVector new_operands; - for (Value v_unmapped : new_operands_unmapped) { - Value v = mapping.lookupOrDefault(v_unmapped); - new_operands.push_back(v); - } - Value new_op = b.create(loc, op->getResultTypes(), new_operands, - op->getAttrs()); - - // Yield the previous results and also the new one. - SmallVector mapped_results; - for (auto result : yield_op.operands()) - mapped_results.push_back(mapping.lookupOrDefault(result)); - mapped_results.push_back(new_op); - return mapped_results; - }); - - // Replace the assuming op and the root op with the corresponding result - // value. - ValueRange new_assuming_op_results = new_assuming_op->getResults(); - rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back()); - rewriter.replaceOp(op, new_assuming_op_results.back()); - return success(); + return MoveIntoAssumingOpMatchAndRewrite(op, rewriter); } }; @@ -307,8 +328,7 @@ struct MergeAssumingOpsPattern : public OpRewritePattern { } }; -// TODO(frgossen): Only move up broadcasting operations if there is a consumer. -struct MoveUpBroadcastInDimOpPattern +struct EarlyBroadcastInDimOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -380,11 +400,12 @@ void PopulateBroadcastsPropagationPatterns(MLIRContext *context, patterns->insert< InlineBroadcastedShapeOperandsPattern, MergeAssumingOpsPattern, - MoveIntoAssumingOpPattern, + MoveElementwiseOpsIntoAssumingOpPattern, MoveIntoAssumingOpPattern, + MoveIntoAssumingOpPattern, MoveOutOfAssumingOpPattern, MoveOutOfAssumingOpPattern, - MoveUpBroadcastInDimOpPattern, + EarlyBroadcastInDimOpPattern, ShapeReificationPattern>(context); // clang-format on mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns, diff --git a/tests/broadcast_propagation.mlir b/tests/broadcast_propagation.mlir index eb88a52..46e16f3 100644 --- a/tests/broadcast_propagation.mlir +++ b/tests/broadcast_propagation.mlir @@ -193,6 +193,30 @@ func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness, // ----- +// CHECK-LABEL: @move_elementwise_into_assuming +// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor) +func @move_elementwise_into_assuming(%arg0 : !shape.witness, + %arg1 : tensor) -> tensor { + // CHECK: %[[RES:.*]] = shape.assuming %[[ARG0]] + // CHECK: %[[SOME:.*]] = "some.op" + // CHECK: %[[TANH:.*]] = "mhlo.tanh"(%[[ARG1]]) + // CHECK: %[[BCAST_ADD:.*]] = chlo.broadcast_add %[[TANH]], %[[SOME]] + // CHECK: shape.assuming_yield %[[BCAST_ADD]] + // CHECK-NOT: tanh + // CHECK-NOT: broadcast_add + // CHECK: return %[[RES]] + %0:2 = shape.assuming %arg0 -> (tensor, tensor) { + %1 = "some.op"() : () -> tensor + shape.assuming_yield %arg1, %1 : tensor, tensor + } + %1 = "mhlo.tanh"(%arg1) : (tensor) -> tensor + %2 = chlo.broadcast_add %1, %0#1 + : (tensor, tensor) -> tensor + return %2 : tensor +} + +// ----- + // CHECK-LABEL: @move_shape_of_out_of_assuming // CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>) func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,