diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/KrnlOps.td index baea4bf..32b2c13 100644 --- a/src/Dialect/Krnl/KrnlOps.td +++ b/src/Dialect/Krnl/KrnlOps.td @@ -273,3 +273,20 @@ def KrnlGetRefOp : Op { let parser = ?; let printer = ?; } + +def KrnlBlockOp : Op { + let summary = "Krnl block operation"; + let description = [{ + Block a single for loop by a constant tile size. For instance, + $ib, $il = krnl.block %i, 4 + means to block the for loop referred to by %i using a tile size of 4. + }]; + + let arguments = (ins + AnyType:$loop, I64Attr:$tile_size); + let results = (outs AnyType:$loop_block, AnyType:$loop_local); + + let assemblyFormat = [{ + $loop $tile_size attr-dict `:` functional-type($loop, results) + }]; +} \ No newline at end of file diff --git a/src/Transform/LowerKrnl.cpp b/src/Transform/LowerKrnl.cpp index 044b471..c165812 100644 --- a/src/Transform/LowerKrnl.cpp +++ b/src/Transform/LowerKrnl.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/LoopUtils.h" #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Pass/Passes.hpp" @@ -20,80 +21,75 @@ using namespace mlir; namespace { -//===----------------------------------------------------------------------===// -// Krnl to Affine Rewrite Patterns: KrnlIterate operation. -//===----------------------------------------------------------------------===// +void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter, + SmallVector, 4> &nestedForOps) { + rewriter.setInsertionPointAfter(iterateOp); + SmallVector, 4> currentNestedForOps; + auto boundMapAttrs = + iterateOp.getAttrOfType(KrnlIterateOp::getBoundsAttrName()) + .getValue(); + auto operandItr = + iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops(); + for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) { + // Consume the input loop operand; it is paired with the new for op below. 
+ auto unoptimizedLoopRef = *(operandItr++); -struct KrnlIterateOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - KrnlIterateOp iterateOp, PatternRewriter &rewriter) const override { - auto boundMapAttrs = - iterateOp.getAttrOfType(KrnlIterateOp::getBoundsAttrName()) - .getValue(); - auto operandItr = - iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops(); - SmallVector nestedForOps; - for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) { - // Consume input loop operand, currently do not do anything with it. - operandItr++; - - // Organize operands into lower/upper bounds in affine.for ready formats. - SmallVector lbOperands, ubOperands; - AffineMap lbMap, ubMap; - for (int boundType = 0; boundType < 2; boundType++) { - auto &operands = boundType == 0 ? lbOperands : ubOperands; - auto &map = boundType == 0 ? lbMap : ubMap; - map = boundMapAttrs[boundIdx + boundType] - .cast() - .getValue(); - operands.insert( - operands.end(), operandItr, operandItr + map.getNumInputs()); - std::advance(operandItr, map.getNumInputs()); - } - - nestedForOps.emplace_back(rewriter.create( - iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap)); - rewriter.setInsertionPoint(nestedForOps.back().getBody(), - nestedForOps.back().getBody()->begin()); + // Organize operands into lower/upper bounds in affine.for ready formats. + llvm::SmallVector lbOperands, ubOperands; + AffineMap lbMap, ubMap; + for (int boundType = 0; boundType < 2; boundType++) { + auto &operands = boundType == 0 ? lbOperands : ubOperands; + auto &map = boundType == 0 ? 
lbMap : ubMap; + map = + boundMapAttrs[boundIdx + boundType].cast().getValue(); + operands.insert( + operands.end(), operandItr, operandItr + map.getNumInputs()); + std::advance(operandItr, map.getNumInputs()); } + currentNestedForOps.emplace_back(std::make_pair( + unoptimizedLoopRef, rewriter.create(iterateOp.getLoc(), + lbOperands, lbMap, ubOperands, ubMap))); - // Replace induction variable references from those introduced by a - // single krnl.iterate to those introduced by multiple affine.for - // operations. - for (int64_t i = 0; i < (int64_t)nestedForOps.size() - 1; i++) { - auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); - auto forIV = nestedForOps[i].getBody()->getArgument(0); - iterateIV.replaceAllUsesWith(forIV); - iterateOp.bodyRegion().front().eraseArgument(0); - } - - // Pop krnl.iterate body region block arguments, leave the last one - // for convenience (it'll be taken care of by region inlining). - while (iterateOp.bodyRegion().front().getNumArguments() > 1) - iterateOp.bodyRegion().front().eraseArgument(0); - - if (nestedForOps.empty()) { - // If no loops are involved, simply move operations from within iterateOp - // body region to the parent region of iterateOp. - rewriter.setInsertionPoint(iterateOp); - iterateOp.bodyRegion().walk([&](Operation *op) { - if (!op->isKnownTerminator()) - op->replaceAllUsesWith(rewriter.clone(*op)); - }); - } else { - // Transfer krnl.iterate region to innermost for op. 
- auto innermostForOp = nestedForOps.back(); - innermostForOp.region().getBlocks().clear(); - rewriter.inlineRegionBefore(iterateOp.bodyRegion(), - innermostForOp.region(), innermostForOp.region().end()); - } - - rewriter.eraseOp(iterateOp); - return success(); + rewriter.setInsertionPoint(currentNestedForOps.back().second.getBody(), + currentNestedForOps.back().second.getBody()->begin()); } -}; + + // Replace induction variable references from those introduced by a + // single krnl.iterate to those introduced by multiple affine.for + // operations. + for (int64_t i = 0; i < (int64_t)currentNestedForOps.size() - 1; i++) { + auto iterateIV = iterateOp.bodyRegion().front().getArgument(0); + auto forIV = currentNestedForOps[i].second.getBody()->getArgument(0); + iterateIV.replaceAllUsesWith(forIV); + iterateOp.bodyRegion().front().eraseArgument(0); + } + + // Pop krnl.iterate body region block arguments, leave the last one + // for convenience (it'll be taken care of by region inlining). + while (iterateOp.bodyRegion().front().getNumArguments() > 1) + iterateOp.bodyRegion().front().eraseArgument(0); + + if (currentNestedForOps.empty()) { + // If no loops are involved, simply move operations from within iterateOp + // body region to the parent region of iterateOp. + rewriter.setInsertionPointAfter(iterateOp); + iterateOp.bodyRegion().walk([&](Operation *op) { + if (!op->isKnownTerminator()) + op->replaceAllUsesWith(rewriter.clone(*op)); + }); + } else { + // Transfer krnl.iterate region to innermost for op. 
+ auto innermostForOp = currentNestedForOps.back().second; + innermostForOp.region().getBlocks().clear(); + auto &innerMostRegion = innermostForOp.region(); + innerMostRegion.getBlocks().splice( + innerMostRegion.end(), iterateOp.bodyRegion().getBlocks()); + } + + iterateOp.erase(); + nestedForOps.insert(nestedForOps.end(), currentNestedForOps.begin(), + currentNestedForOps.end()); +} //===----------------------------------------------------------------------===// // Krnl to Affine Rewrite Patterns: KrnlTerminator operation. @@ -140,6 +136,36 @@ public: } }; +//===----------------------------------------------------------------------===// +// Krnl to Affine Rewrite Patterns: KrnlBlock operation. +//===----------------------------------------------------------------------===// + +class KrnlBlockOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + KrnlBlockOp op, PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Krnl to Affine Rewrite Patterns: KrnlReturnLoops operation. +//===----------------------------------------------------------------------===// + +class KrnlReturnLoopOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + KrnlReturnLoopsOp op, PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // KrnlToAffineLoweringPass //===----------------------------------------------------------------------===// @@ -152,28 +178,144 @@ struct KrnlToAffineLoweringPass : public PassWrapper { void runOnFunction() final; }; + +// Helper function to test if KrnlIterateOp is nested under another +// KrnlIterateOp. 
+bool isIterateOpNested(KrnlIterateOp iterateOp) { + // Walk up the chain of parent operations; report "nested" as soon as an + // enclosing op passes the dyn_cast check. + Operation *op = iterateOp; + while ((op = op->getParentOp())) + if (auto parentOp = dyn_cast(op)) + return true; + return false; +} + +Optional nextIterateOp(FuncOp function) { + Optional nextIterateOp; + function.walk([&](KrnlIterateOp op) { + if (!isIterateOpNested(op)) + nextIterateOp = op; + }); + return nextIterateOp; +} + +bool hasOnePerfectlyNestedIterateOp(KrnlIterateOp op) { + auto childrenOps = op.bodyRegion().getOps(); + auto childrenOpsIter = childrenOps.begin(); + if (childrenOpsIter == childrenOps.end() || + !isa(*childrenOpsIter)) + return false; + if (++childrenOpsIter == childrenOps.end() || + !(*childrenOpsIter).isKnownTerminator()) + return false; + return true; +} } // end anonymous namespace. void KrnlToAffineLoweringPass::runOnFunction() { auto function = getFunction(); - ConversionTarget target(getContext()); target.addLegalDialect(); // We expect IR to be free of Krnl Dialect Ops. target.addIllegalDialect(); + + // Operations that should be converted to LLVM IRs directly. target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); OwningRewritePatternList patterns; - patterns.insert(&getContext()); + patterns.insert( + &getContext()); - if (failed(applyPartialConversion(getFunction(), target, patterns))) { - signalPassFailure(); + // Do not lower operations that pertain to schedules just yet. + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + if (failed(applyPartialConversion(function, target, patterns))) + return signalPassFailure(); + + OpBuilder builder(&getContext()); + while (auto iterateOp = nextIterateOp(function)) { + // Collect a maximal band of loops to lower. 
They must be a perfectly + // nested sequence of for loops (this limitation follows from the + // precondition of current loop manipulation utility libraries). + auto rootOp = iterateOp; + SmallVector loopBand = {*rootOp}; + while (hasOnePerfectlyNestedIterateOp(*rootOp)) { + auto nestedIterateOp = + *rootOp->bodyRegion().getOps().begin(); + loopBand.emplace_back(nestedIterateOp); + rootOp = nestedIterateOp; + } + + // Lower the band of iterateOps, initialize loopRefToLoop to be the list of + // loop reference and the for loop being referenced. + SmallVector, 4> loopRefToLoop; + for (auto op : loopBand) + lowerIterateOp(op, builder, loopRefToLoop); + + // Manually lower schedule ops. + while (!loopRefToLoop.empty()) { + Value loopRef; + AffineForOp forOp; + std::tie(loopRef, forOp) = loopRefToLoop.pop_back_val(); + + // Ensure that loop references are single-use during the scheduling phase. + auto loopRefUsers = loopRef.getUsers(); + SmallVector unfilteredUsers( + loopRefUsers.begin(), loopRefUsers.end()), + users; + std::copy_if(unfilteredUsers.begin(), unfilteredUsers.end(), + std::back_inserter(users), + [](Operation *op) { return !isa(op); }); + assert(std::distance(users.begin(), users.end()) <= 1 && + "Loop reference used more than once."); + + // No schedule primitives associated with this loop reference, move on. + if (users.empty()) + continue; + + // Scheduling operations detected, transform loops as directed, while + // keeping the loopRefToLoop mapping up-to-date. + auto user = users.front(); + if (isa(user)) { + auto blockOp = cast(user); + SmallVector tiledLoops; + SmallVector loopsToTile = {forOp}; + if (failed(tilePerfectlyNested(loopsToTile, + cast(user).tile_sizeAttr().getInt(), + &tiledLoops))) { + return signalPassFailure(); + } + assert(tiledLoops.size() == 2); + assert(blockOp.getNumResults() == 2); + // Record the tiled loop references, and their corresponding tiled for + // loops in loopRefToLoop. 
+ loopRefToLoop.emplace_back( + std::make_pair(blockOp.getResult(0), tiledLoops[0])); + loopRefToLoop.emplace_back( + std::make_pair(blockOp.getResult(1), tiledLoops[1])); + } + } } + + // KrnlIterateOp should be all gone by now. + target.addIllegalOp(); + + // Remove/lower schedule related operations. + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + if (failed(applyPartialConversion(function, target, patterns))) + return signalPassFailure(); } } // namespace diff --git a/test/mlir/krnl/block.mlir b/test/mlir/krnl/block.mlir new file mode 100644 index 0000000..a48c7a4 --- /dev/null +++ b/test/mlir/krnl/block.mlir @@ -0,0 +1,45 @@ +// RUN: onnx-mlir-opt --lower-krnl %s -split-input-file | FileCheck %s + +// CHECK-DAG: #{{.*}} = affine_map<(d0) -> (d0)> +// CHECK-DAG: #{{.*}} = affine_map<(d0) -> (d0 + 2)> +// CHECK-DAG: #{{.*}} = affine_map<() -> (0)> +// CHECK-DAG: #{{.*}} = affine_map<() -> (10)> +// CHECK-DAG: #{{.*}} = affine_map<(d0, d1) -> (d1 + 2, d0 + 4, 10)> +// CHECK-DAG: #{{.*}} = affine_map<(d0) -> (d0 + 4, 10)> + +func @simple_block() { + // CHECK-LABEL: simple_block + // CHECK-NEXT: affine.for [[OUTER_LOOP:%.+]] = 0 to 10 step 2 { + // CHECK-NEXT: affine.for [[INNER_LOOP:%.+]] = #map{{.*}}([[OUTER_LOOP]]) to #map{{.*}}([[OUTER_LOOP]]) { + // CHECK-NEXT: %0 = addi [[INNER_LOOP]], [[INNER_LOOP]] : index + // CHECK-NEXT: } + // CHECK-NEXT: } + + %ii = krnl.define_loops 1 + %ib, %il = krnl.block %ii 2 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.iterate(%ib, %il) with (%ii -> %i = 0 to 10) { + %foo = addi %i, %i : index + } + + return +} + +func @block_nested() { + // CHECK-LABEL: block_nested + // CHECK-NEXT: affine.for [[OUTER_LOOP:%.+]] = 0 to 10 step 4 { + // CHECK-NEXT: affine.for [[MIDDLE_LOOP:%.+]] = #map{{.*}}([[OUTER_LOOP]]) to min #map{{.*}}([[OUTER_LOOP]]) step 2 { + // CHECK-NEXT: affine.for [[INNER_LOOP:%.+]] = #map{{.*}}([[MIDDLE_LOOP]]) to min #map{{.*}}([[OUTER_LOOP]], 
[[MIDDLE_LOOP]]) { + // CHECK-NEXT: %0 = addi [[INNER_LOOP]], [[INNER_LOOP]] : index + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + + %ii = krnl.define_loops 1 + %ib, %il = krnl.block %ii 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + %ilb, %ill = krnl.block %il 2 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.iterate(%ib, %ilb, %ill) with (%ii -> %i = 0 to 10) { + %foo = addi %i, %i : index + } + + return +} \ No newline at end of file