Krnl block (#170)

* Support krnl.block printing/parsing.

* Checkpoing, PoC working.

* Implement krnl.block operation.

* Make tuple -> make pair.

* Bug fix, white list krnl.iterate op while lowering.

* Add return loop op lowering.

* Bug fix.

* Allow using loop refs more than once if they are used by krnl.iterate op.

* More comments and include lit test.

* Make krnl.block definition more restrictive.

* Splitting tests creates modules, making affine_map matching more verbose, prefer not splitting since test cases are small.

* Use verbose mode for LIT test on Z.

* Use verbose build to diagnose.

* Missing libraries linking when building in shared mode.

* Fix whole-archive linkage.

* Try preloading affinetransforms.

* Try put AffineTransforms into LD_LIBRARY_PATH.

* Fix python syntax error.

* No need to link with whole-archive libs, as they are pre-loaded.

* Do not preload any library.

* Link with whole-archive libs.

* Explicitly shared linkage in CMake.

* Fix CMake syntax error.

* Restore test.py

* Update z13.sh

* Update z13.sh

* Provide krnl.block operation description.
This commit is contained in:
Tian Jin 2020-06-27 22:35:01 +08:00 committed by GitHub
parent fd3ee81bcf
commit f9cb113a84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 279 additions and 75 deletions

View File

@ -273,3 +273,20 @@ def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
let parser = ?;
let printer = ?;
}
def KrnlBlockOp : Op<Krnl_Dialect, "block"> {
let summary = "Krnl block operation";
let description = [{
Block a single for loop by a constant tile size. For instance,
$ib, $il = krnl.block %i, 4
means to block the for loop referred to by %i using a tile size of 4.
}];
let arguments = (ins
AnyType:$loop, I64Attr:$tile_size);
let results = (outs AnyType:$loop_block, AnyType:$loop_local);
let assemblyFormat = [{
$loop $tile_size attr-dict `:` functional-type($loop, results)
}];
}

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/LoopUtils.h"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
@ -20,51 +21,45 @@ using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlIterate operation.
//===----------------------------------------------------------------------===//
struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
using OpRewritePattern<KrnlIterateOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
KrnlIterateOp iterateOp, PatternRewriter &rewriter) const override {
void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter,
SmallVector<std::pair<Value, AffineForOp>, 4> &nestedForOps) {
rewriter.setInsertionPointAfter(iterateOp);
SmallVector<std::pair<Value, AffineForOp>, 4> currentNestedForOps;
auto boundMapAttrs =
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
.getValue();
auto operandItr =
iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops();
SmallVector<AffineForOp, 4> nestedForOps;
for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) {
// Consume input loop operand, currently do not do anything with it.
operandItr++;
auto unoptimizedLoopRef = *(operandItr++);
// Organize operands into lower/upper bounds in affine.for ready formats.
SmallVector<Value, 4> lbOperands, ubOperands;
llvm::SmallVector<Value, 4> lbOperands, ubOperands;
AffineMap lbMap, ubMap;
for (int boundType = 0; boundType < 2; boundType++) {
auto &operands = boundType == 0 ? lbOperands : ubOperands;
auto &map = boundType == 0 ? lbMap : ubMap;
map = boundMapAttrs[boundIdx + boundType]
.cast<AffineMapAttr>()
.getValue();
map =
boundMapAttrs[boundIdx + boundType].cast<AffineMapAttr>().getValue();
operands.insert(
operands.end(), operandItr, operandItr + map.getNumInputs());
std::advance(operandItr, map.getNumInputs());
}
currentNestedForOps.emplace_back(std::make_pair(
unoptimizedLoopRef, rewriter.create<AffineForOp>(iterateOp.getLoc(),
lbOperands, lbMap, ubOperands, ubMap)));
nestedForOps.emplace_back(rewriter.create<AffineForOp>(
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap));
rewriter.setInsertionPoint(nestedForOps.back().getBody(),
nestedForOps.back().getBody()->begin());
rewriter.setInsertionPoint(currentNestedForOps.back().second.getBody(),
currentNestedForOps.back().second.getBody()->begin());
}
// Replace induction variable references from those introduced by a
// single krnl.iterate to those introduced by multiple affine.for
// operations.
for (int64_t i = 0; i < (int64_t)nestedForOps.size() - 1; i++) {
for (int64_t i = 0; i < (int64_t)currentNestedForOps.size() - 1; i++) {
auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
auto forIV = nestedForOps[i].getBody()->getArgument(0);
auto forIV = currentNestedForOps[i].second.getBody()->getArgument(0);
iterateIV.replaceAllUsesWith(forIV);
iterateOp.bodyRegion().front().eraseArgument(0);
}
@ -74,26 +69,27 @@ struct KrnlIterateOpLowering : public OpRewritePattern<KrnlIterateOp> {
while (iterateOp.bodyRegion().front().getNumArguments() > 1)
iterateOp.bodyRegion().front().eraseArgument(0);
if (nestedForOps.empty()) {
if (currentNestedForOps.empty()) {
// If no loops are involved, simply move operations from within iterateOp
// body region to the parent region of iterateOp.
rewriter.setInsertionPoint(iterateOp);
rewriter.setInsertionPointAfter(iterateOp);
iterateOp.bodyRegion().walk([&](Operation *op) {
if (!op->isKnownTerminator())
op->replaceAllUsesWith(rewriter.clone(*op));
});
} else {
// Transfer krnl.iterate region to innermost for op.
auto innermostForOp = nestedForOps.back();
auto innermostForOp = currentNestedForOps.back().second;
innermostForOp.region().getBlocks().clear();
rewriter.inlineRegionBefore(iterateOp.bodyRegion(),
innermostForOp.region(), innermostForOp.region().end());
auto &innerMostRegion = innermostForOp.region();
innerMostRegion.getBlocks().splice(
innerMostRegion.end(), iterateOp.bodyRegion().getBlocks());
}
rewriter.eraseOp(iterateOp);
return success();
iterateOp.erase();
nestedForOps.insert(nestedForOps.end(), currentNestedForOps.begin(),
currentNestedForOps.end());
}
};
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
@ -140,6 +136,36 @@ public:
}
};
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
//===----------------------------------------------------------------------===//
class KrnlBlockOpLowering : public OpRewritePattern<KrnlBlockOp> {
public:
using OpRewritePattern<KrnlBlockOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
KrnlBlockOp op, PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
//===----------------------------------------------------------------------===//
// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
//===----------------------------------------------------------------------===//
class KrnlReturnLoopOpLowering : public OpRewritePattern<KrnlReturnLoopsOp> {
public:
using OpRewritePattern<KrnlReturnLoopsOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
KrnlReturnLoopsOp op, PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
//===----------------------------------------------------------------------===//
// KrnlToAffineLoweringPass
//===----------------------------------------------------------------------===//
@ -152,28 +178,144 @@ struct KrnlToAffineLoweringPass
: public PassWrapper<KrnlToAffineLoweringPass, FunctionPass> {
void runOnFunction() final;
};
// Helper function to test if KrnlIterateOp is nested under another
// KrnlIterateOp.
bool isIterateOpNested(KrnlIterateOp iterateOp) {
// krnl.iterate is dynamically legal, if and only if it is enclosed by
// another krnl.iterate.
Operation *op = iterateOp;
while ((op = op->getParentOp()))
if (auto parentOp = dyn_cast<KrnlIterateOp>(op))
return true;
return false;
}
Optional<KrnlIterateOp> nextIterateOp(FuncOp function) {
Optional<KrnlIterateOp> nextIterateOp;
function.walk([&](KrnlIterateOp op) {
if (!isIterateOpNested(op))
nextIterateOp = op;
});
return nextIterateOp;
}
bool hasOnePerfectlyNestedIterateOp(KrnlIterateOp op) {
auto childrenOps = op.bodyRegion().getOps();
auto childrenOpsIter = childrenOps.begin();
if (childrenOpsIter == childrenOps.end() ||
!isa<KrnlIterateOp>(*childrenOpsIter))
return false;
if (++childrenOpsIter == childrenOps.end() ||
!(*childrenOpsIter).isKnownTerminator())
return false;
return true;
}
} // end anonymous namespace.
void KrnlToAffineLoweringPass::runOnFunction() {
auto function = getFunction();
ConversionTarget target(getContext());
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
// We expect IR to be free of Krnl Dialect Ops.
target.addIllegalDialect<KrnlOpsDialect>();
// Operations that should be converted to LLVM IRs directly.
target.addLegalOp<KrnlMemcpyOp>();
target.addLegalOp<KrnlEntryPointOp>();
target.addLegalOp<KrnlGlobalOp>();
target.addLegalOp<KrnlGetRefOp>();
target.addLegalOp<KrnlIterateOp>();
OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
KrnlDefineLoopsLowering, KrnlOptimizeLoopsLowering>(&getContext());
patterns.insert<KrnlTerminatorLowering, KrnlDefineLoopsLowering,
KrnlOptimizeLoopsLowering, KrnlBlockOpLowering, KrnlReturnLoopOpLowering>(
&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
signalPassFailure();
// Do not lower operations that pertain to schedules just yet.
target.addLegalOp<KrnlBlockOp>();
target.addLegalOp<KrnlDefineLoopsOp>();
target.addLegalOp<KrnlOptimizeLoopsOp>();
target.addLegalOp<KrnlReturnLoopsOp>();
if (failed(applyPartialConversion(function, target, patterns)))
return signalPassFailure();
OpBuilder builder(&getContext());
while (auto iterateOp = nextIterateOp(function)) {
// Collect a maximal set of loop band to lower. They must be a perfectly
// nested sequence of for loops (this limitation follows from the
// precondition of current loop manupulation utility libraries).
auto rootOp = iterateOp;
SmallVector<KrnlIterateOp, 4> loopBand = {*rootOp};
while (hasOnePerfectlyNestedIterateOp(*rootOp)) {
auto nestedIterateOp =
*rootOp->bodyRegion().getOps<KrnlIterateOp>().begin();
loopBand.emplace_back(nestedIterateOp);
rootOp = nestedIterateOp;
}
// Lower the band of iterateOps, initialize loopRefToLoop to be the list of
// loop reference and the for loop being referenced.
SmallVector<std::pair<Value, AffineForOp>, 4> loopRefToLoop;
for (auto op : loopBand)
lowerIterateOp(op, builder, loopRefToLoop);
// Manually lower schedule ops.
while (!loopRefToLoop.empty()) {
Value loopRef;
AffineForOp forOp;
std::tie(loopRef, forOp) = loopRefToLoop.pop_back_val();
// Ensure that loop references are single-use during the scheduling phase.
auto loopRefUsers = loopRef.getUsers();
SmallVector<Operation *, 4> unfilteredUsers(
loopRefUsers.begin(), loopRefUsers.end()),
users;
std::copy_if(unfilteredUsers.begin(), unfilteredUsers.end(),
std::back_inserter(users),
[](Operation *op) { return !isa<KrnlIterateOp>(op); });
assert(std::distance(users.begin(), users.end()) <= 1 &&
"Loop reference used more than once.");
// No schedule primitives associated with this loop reference, move on.
if (users.empty())
continue;
// Scheduling operations detected, transform loops as directed, while
// keeping the loopRefToLoop mapping up-to-date.
auto user = users.front();
if (isa<KrnlBlockOp>(user)) {
auto blockOp = cast<KrnlBlockOp>(user);
SmallVector<AffineForOp, 2> tiledLoops;
SmallVector<AffineForOp, 1> loopsToTile = {forOp};
if (failed(tilePerfectlyNested(loopsToTile,
cast<KrnlBlockOp>(user).tile_sizeAttr().getInt(),
&tiledLoops))) {
return signalPassFailure();
}
assert(tiledLoops.size() == 2);
assert(blockOp.getNumResults() == 2);
// Record the tiled loop references, and their corresponding tiled for
// loops in loopRefToLoop.
loopRefToLoop.emplace_back(
std::make_pair(blockOp.getResult(0), tiledLoops[0]));
loopRefToLoop.emplace_back(
std::make_pair(blockOp.getResult(1), tiledLoops[1]));
}
}
}
// KrnlIterateOp should be all gone by now.
target.addIllegalOp<KrnlIterateOp>();
// Remove/lower schedule related operations.
target.addIllegalOp<KrnlDefineLoopsOp>();
target.addIllegalOp<KrnlBlockOp>();
target.addIllegalOp<KrnlOptimizeLoopsOp>();
target.addIllegalOp<KrnlReturnLoopsOp>();
if (failed(applyPartialConversion(function, target, patterns)))
return signalPassFailure();
}
} // namespace

45
test/mlir/krnl/block.mlir Normal file
View File

@ -0,0 +1,45 @@
// RUN: onnx-mlir-opt --lower-krnl %s -split-input-file | FileCheck %s
// CHECK-DAG: #{{.*}} = affine_map<(d0) -> (d0)>
// CHECK-DAG: #{{.*}} = affine_map<(d0) -> (d0 + 2)>
// CHECK-DAG: #{{.*}} = affine_map<() -> (0)>
// CHECK-DAG: #{{.*}} = affine_map<() -> (10)>
// CHECK-DAG: #{{.*}} = affine_map<(d0, d1) -> (d1 + 2, d0 + 4, 10)>
// CHECK-DAG: #{{.*}} = affine_map<(d0) -> (d0 + 4, 10)>
func @simple_block() {
// CHECK-LABEL: simple_block
// CHECK-NEXT: affine.for [[OUTER_LOOP:%.+]] = 0 to 10 step 2 {
// CHECK-NEXT: affine.for [[INNER_LOOP:%.+]] = #map{{.*}}([[OUTER_LOOP]]) to #map{{.*}}([[OUTER_LOOP]]) {
// CHECK-NEXT: %0 = addi [[INNER_LOOP]], [[INNER_LOOP]] : index
// CHECK-NEXT: }
// CHECK-NEXT: }
%ii = krnl.define_loops 1
%ib, %il = krnl.block %ii 2 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
krnl.iterate(%ib, %il) with (%ii -> %i = 0 to 10) {
%foo = addi %i, %i : index
}
return
}
func @block_nested() {
// CHECK-LABEL: block_nested
// CHECK-NEXT: affine.for [[OUTER_LOOP:%.+]] = 0 to 10 step 4 {
// CHECK-NEXT: affine.for [[MIDDLE_LOOP:%.+]] = #map{{.*}}([[OUTER_LOOP]]) to min #map{{.*}}([[OUTER_LOOP]]) step 2 {
// CHECK-NEXT: affine.for [[INNER_LOOP:%.+]] = #map{{.*}}([[MIDDLE_LOOP]]) to min #map{{.*}}([[OUTER_LOOP]], [[MIDDLE_LOOP]]) {
// CHECK-NEXT: %0 = addi [[INNER_LOOP]], [[INNER_LOOP]] : index
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
%ii = krnl.define_loops 1
%ib, %il = krnl.block %ii 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%ilb, %ill = krnl.block %il 2 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
krnl.iterate(%ib, %ilb, %ill) with (%ii -> %i = 0 to 10) {
%foo = addi %i, %i : index
}
return
}