diff --git a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
index 25979e9..49d4652 100644
--- a/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
+++ b/lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc
@@ -98,7 +98,7 @@ struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
   }
 };
 
-/// Move operation into a preceeding assuming op. This allows to process
+/// Move operation into a preceding assuming op. This allows us to process
 /// operations that depend on the assuming op's results. It will eventually
 /// allow to make assuming regions' constraints independent from each other.
 template <typename OpTy>
@@ -107,7 +107,7 @@ struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Only move into immediately preceeding `assuming` op.
+    // Only move into immediately preceding `assuming` op.
     auto assuming_op =
         llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
     if (!assuming_op) return failure();
@@ -239,6 +239,81 @@ struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Merge assuming regions if their constraints are independent of each other.
+struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
+  using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    // Merge assuming op with directly preceding one.
+    auto preceding_op =
+        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
+    if (!preceding_op) return failure();
+
+    // For now, both witnesses must be cstr_broadcastable.
+    // TODO(frgossen): Generalize this.
+    auto bcastable_a = op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
+    auto bcastable_b =
+        preceding_op.witness().getDefiningOp<shape::CstrBroadcastableOp>();
+    if (!bcastable_a || !bcastable_b) return failure();
+
+    // Merge witnesses.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(preceding_op);
+    SmallVector<Value, 8> new_operands;
+    new_operands.append(bcastable_a->getOperands().begin(),
+                        bcastable_a->getOperands().end());
+    new_operands.append(bcastable_b->getOperands().begin(),
+                        bcastable_b->getOperands().end());
+    Value new_witness = rewriter.create<shape::CstrBroadcastableOp>(
+        bcastable_a.getLoc(), new_operands);
+
+    // Merge assuming ops.
+    Block *body_a = preceding_op.getBody();
+    Block *body_b = op.getBody();
+    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
+        preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
+          // Copy preceding op's body.
+          BlockAndValueMapping mapping;
+          for (auto &nested : body_a->without_terminator()) {
+            b.clone(nested, mapping);
+          }
+
+          // Map result values of preceding assuming op.
+          auto yield_op_a =
+              llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
+          for (auto pair :
+               llvm::zip(preceding_op->getResults(), yield_op_a.operands())) {
+            mapping.map(std::get<0>(pair), mapping.lookup(std::get<1>(pair)));
+          }
+
+          // Copy op's body.
+          for (auto &nested : body_b->without_terminator()) {
+            b.clone(nested, mapping);
+          }
+
+          // Collect merged assuming op's results.
+          SmallVector<Value, 2> mapped_results;
+          auto yield_op_b =
+              llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
+          for (Value v : yield_op_a.operands()) {
+            mapped_results.push_back(mapping.lookupOrDefault(v));
+          }
+          for (Value v : yield_op_b.operands()) {
+            mapped_results.push_back(mapping.lookupOrDefault(v));
+          }
+          return mapped_results;
+        });
+
+    // Replace the two assuming ops with the corresponding new results.
+    ValueRange new_results = new_assuming_op->getResults();
+    size_t split_at = preceding_op->getNumResults();
+    rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
+    rewriter.replaceOp(op, new_results.drop_front(split_at));
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -311,6 +386,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
   // clang-format off
   patterns->insert<
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
+      MergeAssumingOpsPattern,
      MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
      MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
diff --git a/tests/move_up_dynamic_broadcasts_for_fusion.mlir b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
index 97c5822..e64ff71 100644
--- a/tests/move_up_dynamic_broadcasts_for_fusion.mlir
+++ b/tests/move_up_dynamic_broadcasts_for_fusion.mlir
@@ -245,3 +245,43 @@ func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
   }
   return %0 : tensor<2xindex>
 }
+
+// -----
+
+// CHECK: @merge_assuming_ops
+// CHECK: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
+func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
+    %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
+  // CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
+  // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
+  // CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
+  // CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
+  // CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
+  // CHECK: %[[WITNESS_MERGED:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
+  // CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[WITNESS_MERGED]]
+  // CHECK-SAME: {
+  // CHECK: "some.op"
+  // CHECK: %[[RESULT0:.*]] = "some.producer"
+  // CHECK: "another.op"
+  // CHECK: %[[RESULT1:.*]] = "another.producer"
+  // CHECK: shape.assuming_yield %[[RESULT0]], %[[RESULT1]]
+  // CHECK: }
+  // CHECK: return %[[MERGED]]#1
+  %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
+  %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
+  %2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
+  %3 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
+  %4 = shape.cstr_broadcastable %0, %1, %2 : tensor<2xindex>, tensor<2xindex>,
+      tensor<3xindex>
+  %5 = shape.assuming %3 -> (tensor<?x32xf16>) {
+    "some.op"() : () -> ()
+    %6 = "some.producer"() : () -> tensor<?x32xf16>
+    shape.assuming_yield %6 : tensor<?x32xf16>
+  }
+  %7 = shape.assuming %4 -> (tensor<?x?x32xf16>) {
+    "another.op"() : () -> ()
+    %8 = "another.producer"() : () -> tensor<?x?x32xf16>
+    shape.assuming_yield %8 : tensor<?x?x32xf16>
+  }
+  return %7 : tensor<?x?x32xf16>
+}
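
Note (illustrative sketch, not part of the patch; all SSA names and element types below are hypothetical): at the IR level, MergeAssumingOpsPattern takes two adjacent `shape.assuming` regions whose witnesses both come from `shape.cstr_broadcastable`, e.g.

  %w0 = shape.cstr_broadcastable %s0, %s1 : tensor<2xindex>, tensor<2xindex>
  %w1 = shape.cstr_broadcastable %s0, %s1, %s2 : tensor<2xindex>, tensor<2xindex>, tensor<3xindex>
  %a = shape.assuming %w0 -> (tensor<?x?xf32>) { ... }
  %b = shape.assuming %w1 -> (tensor<?x?x?xf32>) { ... }

and rewrites them into a single region guarded by one witness whose operands are the concatenation of both operand lists (the matched op's witness operands first, then the preceding op's; duplicates are not removed by this pattern):

  %w = shape.cstr_broadcastable %s0, %s1, %s2, %s0, %s1 : tensor<2xindex>, tensor<2xindex>, tensor<3xindex>, tensor<2xindex>, tensor<2xindex>
  %m:2 = shape.assuming %w -> (tensor<?x?xf32>, tensor<?x?x?xf32>) {
    // Cloned body of the preceding region, then of the matched one.
    ...
  }

Uses of %a become %m#0 and uses of %b become %m#1, matching the take_front/drop_front split on preceding_op->getNumResults().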