Support krnl loop permutation (#215)
* Define krnl.permute op. * Support krnl.permute operation. * Properly remove loop references. * Re-push, Github was down. * Need to debug interpretOp error. * Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code. * Introduce permute, unroll operations. * More debug. * Remove std::set. * krnl.terminate fails to be converted. * Pass all tests, need to add legal ops as well as part of the conversion target. * Change test format to new permute spec. * Bug fix for nested iterate op lowering. * Simplify error reporting. * Fix compilation error. * Increase comments coverage. * Remove unnecessary imports. * Re-trigger Jenkins * Add permute/unroll tests. * Retrigger Jenkins * Using a non-trivial example. * Add more complex example/test case.
This commit is contained in:
parent
c9e3ba2d64
commit
2e8f012195
|
@ -250,4 +250,87 @@ def KrnlBlockOp : Op<Krnl_Dialect, "block"> {
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$loop $tile_size attr-dict `:` functional-type($loop, results)
|
$loop $tile_size attr-dict `:` functional-type($loop, results)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def KrnlPermuteOp : Op<Krnl_Dialect, "permute"> {
|
||||||
|
let summary = "Krnl permute operation";
|
||||||
|
let description = [{
|
||||||
|
Permute a set of affine for loops using a specified permutation map.
|
||||||
|
The permutation map `map` should be constructed in such a way that the
|
||||||
|
for loop referred to by the i-th operand to the permute operation is sent
|
||||||
|
to the `map[i]`-th position.
|
||||||
|
|
||||||
|
For example, the following krnl dialect IR:
|
||||||
|
```
|
||||||
|
%ii, %jj, %kk = krnl.define_loops 3
|
||||||
|
krnl.permute(%ii, %jj, %kk) [1, 2, 0] : !krnl.loop, !krnl.loop, !krnl.loop
|
||||||
|
krnl.iterate (%ii, %jj, %kk) with (%ii -> %i = 0 to 10, %jj -> %j = 0 to 20, %kk -> %k = 0 to 30) {}
|
||||||
|
```
|
||||||
|
will be lowered to:
|
||||||
|
```
|
||||||
|
// Referenced by %kk
|
||||||
|
affine.for %arg0 = 0 to 30 {
|
||||||
|
// Referenced by %ii
|
||||||
|
affine.for %arg1 = 0 to 10 {
|
||||||
|
// Referenced by %jj
|
||||||
|
affine.for %arg2 = 0 to 20 {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For a more complicated example, we demonstrate 3-D tiling using krnl.block in
|
||||||
|
conjunction with krnl.permute:
|
||||||
|
```
|
||||||
|
%ii, %jj, %kk = krnl.define_loops 3
|
||||||
|
// Blocking each loop by a factor of 4.
|
||||||
|
%ib, %il = krnl.block %ii 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
%jb, %jl = krnl.block %jj 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
%kb, %kl = krnl.block %kk 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
// Move iteration over tile coordinates to be the outer loops and iteration over
|
||||||
|
// the intra-tile elements to be the inner loops.
|
||||||
|
krnl.permute(%ib, %il, %jb, %jl, %kb, %kl) [0, 3, 1, 4, 2, 5] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
|
||||||
|
krnl.iterate(%ib, %il, %jb, %jl, %kb, %kl) with (%ii -> %i = 0 to 1024, %jj -> %j = 0 to 2048, %kk -> %k = 0 to 4096) {
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The above IR gets lowered to:
|
||||||
|
```
|
||||||
|
affine.for %arg0 = 0 to 1024 step 4 {
|
||||||
|
affine.for %arg1 = 0 to 2048 step 4 {
|
||||||
|
affine.for %arg2 = 0 to 4096 step 4 {
|
||||||
|
affine.for %arg3 = #map0(%arg0) to #map1(%arg0) {
|
||||||
|
affine.for %arg4 = #map0(%arg1) to #map1(%arg1) {
|
||||||
|
affine.for %arg5 = #map0(%arg2) to #map1(%arg2) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyType>:$loops, I64ArrayAttr:$map);
|
||||||
|
let results = (outs);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $loops `)` $map attr-dict `:` type($loops)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def KrnlUnrollOp : Op<Krnl_Dialect, "unroll"> {
|
||||||
|
let summary = "Krnl unroll operation";
|
||||||
|
let description = [{
|
||||||
|
Fully unroll the specified loops.
|
||||||
|
```
|
||||||
|
krnl.unroll %i
|
||||||
|
```
|
||||||
|
unrolls the loop referred to by %i fully.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyType:$loop);
|
||||||
|
let results = (outs);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$loop attr-dict `:` type($loop)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
|
@ -21,9 +21,24 @@ using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter,
|
//===----------------------------------------------------------------------===//
|
||||||
SmallVector<std::pair<Value, AffineForOp>, 4> &nestedForOps) {
|
// Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
|
||||||
rewriter.setInsertionPointAfter(iterateOp);
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
KrnlTerminatorOp op, PatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &builder,
|
||||||
|
llvm::SmallDenseMap<Value, AffineForOp, 4> &refToOps) {
|
||||||
|
builder.setInsertionPointAfter(iterateOp);
|
||||||
SmallVector<std::pair<Value, AffineForOp>, 4> currentNestedForOps;
|
SmallVector<std::pair<Value, AffineForOp>, 4> currentNestedForOps;
|
||||||
auto boundMapAttrs =
|
auto boundMapAttrs =
|
||||||
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
|
||||||
|
@ -31,7 +46,7 @@ void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter,
|
||||||
auto operandItr =
|
auto operandItr =
|
||||||
iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops();
|
iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops();
|
||||||
for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) {
|
for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) {
|
||||||
// Consume input loop operand, currently do not do anything with it.
|
// Consume input loop operand, at this stage, do not do anything with it.
|
||||||
auto unoptimizedLoopRef = *(operandItr++);
|
auto unoptimizedLoopRef = *(operandItr++);
|
||||||
|
|
||||||
// Organize operands into lower/upper bounds in affine.for ready formats.
|
// Organize operands into lower/upper bounds in affine.for ready formats.
|
||||||
|
@ -46,11 +61,11 @@ void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter,
|
||||||
operands.end(), operandItr, operandItr + map.getNumInputs());
|
operands.end(), operandItr, operandItr + map.getNumInputs());
|
||||||
std::advance(operandItr, map.getNumInputs());
|
std::advance(operandItr, map.getNumInputs());
|
||||||
}
|
}
|
||||||
currentNestedForOps.emplace_back(std::make_pair(
|
auto forOp = builder.create<AffineForOp>(
|
||||||
unoptimizedLoopRef, rewriter.create<AffineForOp>(iterateOp.getLoc(),
|
iterateOp.getLoc(), lbOperands, lbMap, ubOperands, ubMap);
|
||||||
lbOperands, lbMap, ubOperands, ubMap)));
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(currentNestedForOps.back().second.getBody(),
|
currentNestedForOps.emplace_back(std::make_pair(unoptimizedLoopRef, forOp));
|
||||||
|
builder.setInsertionPoint(currentNestedForOps.back().second.getBody(),
|
||||||
currentNestedForOps.back().second.getBody()->begin());
|
currentNestedForOps.back().second.getBody()->begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,10 +87,10 @@ void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter,
|
||||||
if (currentNestedForOps.empty()) {
|
if (currentNestedForOps.empty()) {
|
||||||
// If no loops are involved, simply move operations from within iterateOp
|
// If no loops are involved, simply move operations from within iterateOp
|
||||||
// body region to the parent region of iterateOp.
|
// body region to the parent region of iterateOp.
|
||||||
rewriter.setInsertionPointAfter(iterateOp);
|
builder.setInsertionPointAfter(iterateOp);
|
||||||
iterateOp.bodyRegion().walk([&](Operation *op) {
|
iterateOp.bodyRegion().walk([&](Operation *op) {
|
||||||
if (!op->isKnownTerminator())
|
if (!op->isKnownTerminator())
|
||||||
op->replaceAllUsesWith(rewriter.clone(*op));
|
op->replaceAllUsesWith(builder.clone(*op));
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Transfer krnl.iterate region to innermost for op.
|
// Transfer krnl.iterate region to innermost for op.
|
||||||
|
@ -86,56 +101,10 @@ void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter,
|
||||||
innerMostRegion.end(), iterateOp.bodyRegion().getBlocks());
|
innerMostRegion.end(), iterateOp.bodyRegion().getBlocks());
|
||||||
}
|
}
|
||||||
|
|
||||||
iterateOp.erase();
|
for (const auto &pair : currentNestedForOps)
|
||||||
nestedForOps.insert(nestedForOps.end(), currentNestedForOps.begin(),
|
refToOps.try_emplace(pair.first, pair.second);
|
||||||
currentNestedForOps.end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
|
||||||
KrnlTerminatorOp op, PatternRewriter &rewriter) const override {
|
|
||||||
rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
|
||||||
KrnlDefineLoopsOp op, PatternRewriter &rewriter) const override {
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
class KrnlBlockOpLowering : public OpRewritePattern<KrnlBlockOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<KrnlBlockOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
|
||||||
KrnlBlockOp op, PatternRewriter &rewriter) const override {
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// KrnlToAffineLoweringPass
|
// KrnlToAffineLoweringPass
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -148,141 +117,130 @@ struct KrnlToAffineLoweringPass
|
||||||
: public PassWrapper<KrnlToAffineLoweringPass, FunctionPass> {
|
: public PassWrapper<KrnlToAffineLoweringPass, FunctionPass> {
|
||||||
void runOnFunction() final;
|
void runOnFunction() final;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper function to test if KrnlIterateOp is nested under another
|
|
||||||
// KrnlIterateOp.
|
|
||||||
bool isIterateOpNested(KrnlIterateOp iterateOp) {
|
|
||||||
// krnl.iterate is dynamically legal, if and only if it is enclosed by
|
|
||||||
// another krnl.iterate.
|
|
||||||
Operation *op = iterateOp;
|
|
||||||
while ((op = op->getParentOp()))
|
|
||||||
if (auto parentOp = dyn_cast<KrnlIterateOp>(op))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
Optional<KrnlIterateOp> nextIterateOp(FuncOp function) {
|
|
||||||
Optional<KrnlIterateOp> nextIterateOp;
|
|
||||||
function.walk([&](KrnlIterateOp op) {
|
|
||||||
if (!isIterateOpNested(op))
|
|
||||||
nextIterateOp = op;
|
|
||||||
});
|
|
||||||
return nextIterateOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasOnePerfectlyNestedIterateOp(KrnlIterateOp op) {
|
|
||||||
auto childrenOps = op.bodyRegion().getOps();
|
|
||||||
auto childrenOpsIter = childrenOps.begin();
|
|
||||||
if (childrenOpsIter == childrenOps.end() ||
|
|
||||||
!isa<KrnlIterateOp>(*childrenOpsIter))
|
|
||||||
return false;
|
|
||||||
if (++childrenOpsIter == childrenOps.end() ||
|
|
||||||
!(*childrenOpsIter).isKnownTerminator())
|
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
void KrnlToAffineLoweringPass::runOnFunction() {
|
LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
|
||||||
auto function = getFunction();
|
llvm::SmallDenseMap<Value, AffineForOp, 4> &loopRefToOp,
|
||||||
ConversionTarget target(getContext());
|
llvm::SmallPtrSetImpl<Operation *> &opsToErase) {
|
||||||
|
// Recursively interpret nested operations.
|
||||||
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
|
for (auto &region : op->getRegions())
|
||||||
// We expect IR to be free of Krnl Dialect Ops.
|
for (auto &block : region.getBlocks()) {
|
||||||
target.addIllegalDialect<KrnlOpsDialect>();
|
auto &blockOps = block.getOperations();
|
||||||
|
for (auto itr = blockOps.begin(); itr != blockOps.end();)
|
||||||
// Operations that should be converted to LLVM IRs directly.
|
if (failed(interpretOperation(
|
||||||
target.addLegalOp<KrnlMemcpyOp>();
|
&(*itr), builder, loopRefToOp, opsToErase))) {
|
||||||
target.addLegalOp<KrnlEntryPointOp>();
|
return failure();
|
||||||
target.addLegalOp<KrnlGlobalOp>();
|
} else {
|
||||||
target.addLegalOp<KrnlGetRefOp>();
|
++itr;
|
||||||
target.addLegalOp<KrnlIterateOp>();
|
}
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
patterns.insert<KrnlTerminatorLowering, KrnlDefineLoopsLowering,
|
|
||||||
KrnlBlockOpLowering>(&getContext());
|
|
||||||
|
|
||||||
// Do not lower operations that pertain to schedules just yet.
|
|
||||||
target.addLegalOp<KrnlBlockOp>();
|
|
||||||
target.addLegalOp<KrnlDefineLoopsOp>();
|
|
||||||
if (failed(applyPartialConversion(function, target, patterns)))
|
|
||||||
return signalPassFailure();
|
|
||||||
|
|
||||||
OpBuilder builder(&getContext());
|
|
||||||
while (auto iterateOp = nextIterateOp(function)) {
|
|
||||||
// Collect a maximal set of loop band to lower. They must be a perfectly
|
|
||||||
// nested sequence of for loops (this limitation follows from the
|
|
||||||
// precondition of current loop manipulation utility libraries).
|
|
||||||
auto rootOp = iterateOp;
|
|
||||||
SmallVector<KrnlIterateOp, 4> loopBand = {*rootOp};
|
|
||||||
while (hasOnePerfectlyNestedIterateOp(*rootOp)) {
|
|
||||||
auto nestedIterateOp =
|
|
||||||
*rootOp->bodyRegion().getOps<KrnlIterateOp>().begin();
|
|
||||||
loopBand.emplace_back(nestedIterateOp);
|
|
||||||
rootOp = nestedIterateOp;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lower the band of iterateOps, initialize loopRefToLoop to be the list of
|
if (auto defineOp = dyn_cast_or_null<KrnlDefineLoopsOp>(op)) {
|
||||||
// loop reference and the for loop being referenced.
|
// Collect users of defineLoops operations that are iterate operations.
|
||||||
SmallVector<std::pair<Value, AffineForOp>, 4> loopRefToLoop;
|
std::vector<KrnlIterateOp> iterateOps;
|
||||||
for (auto op : loopBand)
|
for (auto result : op->getResults())
|
||||||
lowerIterateOp(op, builder, loopRefToLoop);
|
for (auto *user : result.getUsers())
|
||||||
|
if (auto iterateOp = dyn_cast_or_null<KrnlIterateOp>(user))
|
||||||
|
if (std::find(iterateOps.begin(), iterateOps.end(), iterateOp) ==
|
||||||
|
iterateOps.end())
|
||||||
|
iterateOps.push_back(dyn_cast<KrnlIterateOp>(user));
|
||||||
|
|
||||||
// Manually lower schedule ops.
|
// Lower iterate operations and record the mapping between loop references
|
||||||
while (!loopRefToLoop.empty()) {
|
// and affine for loop operations in loopRefToOp map.
|
||||||
Value loopRef;
|
if (!iterateOps.empty()) {
|
||||||
AffineForOp forOp;
|
for (auto opToLower : iterateOps) {
|
||||||
std::tie(loopRef, forOp) = loopRefToLoop.pop_back_val();
|
if (opsToErase.count(opToLower) == 0) {
|
||||||
|
lowerIterateOp(opToLower, builder, loopRefToOp);
|
||||||
// Ensure that loop references are single-use during the scheduling phase.
|
opsToErase.insert(opToLower);
|
||||||
auto loopRefUsers = loopRef.getUsers();
|
|
||||||
SmallVector<Operation *, 4> unfilteredUsers(
|
|
||||||
loopRefUsers.begin(), loopRefUsers.end()),
|
|
||||||
users;
|
|
||||||
std::copy_if(unfilteredUsers.begin(), unfilteredUsers.end(),
|
|
||||||
std::back_inserter(users),
|
|
||||||
[](Operation *op) { return !isa<KrnlIterateOp>(op); });
|
|
||||||
assert(std::distance(users.begin(), users.end()) <= 1 &&
|
|
||||||
"Loop reference used more than once.");
|
|
||||||
|
|
||||||
// No schedule primitives associated with this loop reference, move on.
|
|
||||||
if (users.empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Scheduling operations detected, transform loops as directed, while
|
|
||||||
// keeping the loopRefToLoop mapping up-to-date.
|
|
||||||
auto user = users.front();
|
|
||||||
if (isa<KrnlBlockOp>(user)) {
|
|
||||||
auto blockOp = cast<KrnlBlockOp>(user);
|
|
||||||
SmallVector<AffineForOp, 2> tiledLoops;
|
|
||||||
SmallVector<AffineForOp, 1> loopsToTile = {forOp};
|
|
||||||
if (failed(tilePerfectlyNested(loopsToTile,
|
|
||||||
cast<KrnlBlockOp>(user).tile_sizeAttr().getInt(),
|
|
||||||
&tiledLoops))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
}
|
||||||
assert(tiledLoops.size() == 2);
|
|
||||||
assert(blockOp.getNumResults() == 2);
|
|
||||||
// Record the tiled loop references, and their corresponding tiled for
|
|
||||||
// loops in loopRefToLoop.
|
|
||||||
loopRefToLoop.emplace_back(
|
|
||||||
std::make_pair(blockOp.getResult(0), tiledLoops[0]));
|
|
||||||
loopRefToLoop.emplace_back(
|
|
||||||
std::make_pair(blockOp.getResult(1), tiledLoops[1]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
opsToErase.insert(op);
|
||||||
|
return success();
|
||||||
|
} else if (auto iterateOp = dyn_cast_or_null<KrnlIterateOp>(op)) {
|
||||||
|
// If an iterateOp has no unoptimized loop references, then we need to lower
|
||||||
|
// it manually.
|
||||||
|
if (opsToErase.count(op) == 0) {
|
||||||
|
lowerIterateOp(iterateOp, builder, loopRefToOp);
|
||||||
|
opsToErase.insert(iterateOp);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
} else if (auto blockOp = dyn_cast_or_null<KrnlBlockOp>(op)) {
|
||||||
|
SmallVector<AffineForOp, 2> tiledLoops;
|
||||||
|
SmallVector<AffineForOp, 1> loopsToTile = {loopRefToOp[blockOp.loop()]};
|
||||||
|
if (failed(tilePerfectlyNested(
|
||||||
|
loopsToTile, blockOp.tile_sizeAttr().getInt(), &tiledLoops))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
assert(tiledLoops.size() == 2);
|
||||||
|
assert(blockOp.getNumResults() == 2);
|
||||||
|
|
||||||
|
// Record the tiled loop references, and their corresponding tiled
|
||||||
|
// for loops in loopRefToOp.
|
||||||
|
loopRefToOp[blockOp.getResult(0)] = tiledLoops[0];
|
||||||
|
loopRefToOp[blockOp.getResult(1)] = tiledLoops[1];
|
||||||
|
|
||||||
|
opsToErase.insert(op);
|
||||||
|
return success();
|
||||||
|
} else if (auto permuteOp = dyn_cast_or_null<KrnlPermuteOp>(op)) {
|
||||||
|
// Collect loops to permute.
|
||||||
|
SmallVector<AffineForOp, 4> loopsToPermute;
|
||||||
|
std::transform(permuteOp.operand_begin(), permuteOp.operand_end(),
|
||||||
|
std::back_inserter(loopsToPermute),
|
||||||
|
[&](const Value &val) { return loopRefToOp[val]; });
|
||||||
|
|
||||||
|
// Construct permutation map from integer array attribute.
|
||||||
|
SmallVector<unsigned int, 4> permuteMap;
|
||||||
|
for (const auto &attr : permuteOp.map().getAsRange<IntegerAttr>())
|
||||||
|
permuteMap.emplace_back(attr.getValue().getSExtValue());
|
||||||
|
|
||||||
|
// Perform loop permutation.
|
||||||
|
permuteLoops(loopsToPermute, permuteMap);
|
||||||
|
|
||||||
|
opsToErase.insert(op);
|
||||||
|
return success();
|
||||||
|
} else if (auto unrollOp = dyn_cast_or_null<KrnlUnrollOp>(op)) {
|
||||||
|
// Unroll the affine for loop fully.
|
||||||
|
auto loopRef = unrollOp.loop();
|
||||||
|
loopUnrollFull(loopRefToOp[loopRef]);
|
||||||
|
|
||||||
|
opsToErase.insert(op);
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// KrnlIterateOp should be all gone by now.
|
return success();
|
||||||
target.addIllegalOp<KrnlIterateOp>();
|
|
||||||
|
|
||||||
// Remove/lower schedule related operations.
|
|
||||||
target.addIllegalOp<KrnlDefineLoopsOp>();
|
|
||||||
target.addIllegalOp<KrnlBlockOp>();
|
|
||||||
if (failed(applyPartialConversion(function, target, patterns)))
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
|
OpBuilder builder(&getContext());
|
||||||
|
mlir::Operation *funcOp = getFunction();
|
||||||
|
|
||||||
|
// Interpret krnl dialect operations while looping recursively through
|
||||||
|
// operations within the current function, note that erasing operations while
|
||||||
|
// iterating is tricky because it can invalidate the iterator, so we collect
|
||||||
|
// the operations to be erased in a small ptr set `opsToErase`, and only erase
|
||||||
|
// after iteration completes.
|
||||||
|
llvm::SmallDenseMap<Value, AffineForOp, 4> loopRefToOp;
|
||||||
|
llvm::SmallPtrSet<Operation *, 4> opsToErase;
|
||||||
|
if (failed(interpretOperation(funcOp, builder, loopRefToOp, opsToErase))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Erase interpreted operations.
|
||||||
|
for (const auto &op : opsToErase)
|
||||||
|
op->erase();
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addIllegalOp<KrnlTerminatorOp>();
|
||||||
|
target.addLegalOp<AffineTerminatorOp>();
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<KrnlTerminatorLowering>(&getContext());
|
||||||
|
DenseSet<Operation *> unconverted;
|
||||||
|
if (failed(applyPartialConversion(
|
||||||
|
getFunction(), target, patterns, &unconverted)))
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
|
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
// RUN: onnx-mlir-opt --lower-krnl %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
func @simple_permute() {
|
||||||
|
%ii, %jj = krnl.define_loops 2
|
||||||
|
krnl.permute(%ii, %jj) [1, 0] : !krnl.loop, !krnl.loop
|
||||||
|
krnl.iterate(%ii, %jj) with (%ii -> %i = 0 to 10, %jj -> %j = 0 to 20) {
|
||||||
|
%foo = addi %i, %i : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: simple_permute
|
||||||
|
// CHECK-NEXT: affine.for [[OUTER_LOOP_IV:%.+]] = 0 to 20 {
|
||||||
|
// CHECK-NEXT: affine.for [[INNER_LOOP_IV:%.+]] = 0 to 10 {
|
||||||
|
// CHECK-NEXT: [[ADD:%.+]] = addi [[INNER_LOOP_IV]], [[INNER_LOOP_IV]] : index
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @tiling() {
|
||||||
|
%ii, %ij = krnl.define_loops 2
|
||||||
|
%ib, %il = krnl.block %ii 5 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
%jb, %jl = krnl.block %ij 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
krnl.permute(%ib, %il, %jb, %jl) [0, 2, 1, 3] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
|
||||||
|
krnl.iterate(%ib, %jb, %il, %jl) with (%ii -> %i = 0 to 10, %ij -> %j = 0 to 20) {
|
||||||
|
%foo = addi %i, %i : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: tiling
|
||||||
|
// CHECK-NEXT: affine.for [[I_BLOCK_IV:%.+]] = 0 to 10 step 5 {
|
||||||
|
// CHECK-NEXT: affine.for [[J_BLOCK_IV:%.+]] = 0 to 20 step 4 {
|
||||||
|
// CHECK-NEXT: affine.for [[I_LOCAL_IV:%.+]] = #map{{.*}}([[I_BLOCK_IV]]) to #map{{.*}}([[I_BLOCK_IV]]) {
|
||||||
|
// CHECK-NEXT: affine.for [[J_LOCAL_IV:%.+]] = #map{{.*}}([[J_BLOCK_IV]]) to #map{{.*}}([[J_BLOCK_IV]]) {
|
||||||
|
// CHECK-NEXT: %0 = addi %arg2, %arg2 : index
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func @tiling3d() {
|
||||||
|
%ii, %jj, %kk = krnl.define_loops 3
|
||||||
|
// Blocking each loop by a factor of 4.
|
||||||
|
%ib, %il = krnl.block %ii 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
%jb, %jl = krnl.block %jj 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
%kb, %kl = krnl.block %kk 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
|
||||||
|
// Move iteration over tile coordinates to be the outer loops and iteration over
|
||||||
|
// the intra-tile elements to be the inner loops.
|
||||||
|
krnl.permute(%ib, %il, %jb, %jl, %kb, %kl) [0, 3, 1, 4, 2, 5] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
|
||||||
|
krnl.iterate(%ib, %il, %jb, %jl, %kb, %kl) with (%ii -> %i = 0 to 1024, %jj -> %j = 0 to 2048, %kk -> %k = 0 to 4096) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: tiling3d
|
||||||
|
// CHECK-NEXT: affine.for [[I_BLOCK_IV:%.+]] = 0 to 1024 step 4 {
|
||||||
|
// CHECK-NEXT: affine.for [[J_BLOCK_IV:%.+]] = 0 to 2048 step 4 {
|
||||||
|
// CHECK-NEXT: affine.for [[K_BLOCK_IV:%.+]] = 0 to 4096 step 4 {
|
||||||
|
// CHECK-NEXT: affine.for [[I_INNER_IV:%.+]] = #map0([[I_BLOCK_IV]]) to #map1([[I_BLOCK_IV]]) {
|
||||||
|
// CHECK-NEXT: affine.for [[J_INNER_IV:%.+]] = #map0([[J_BLOCK_IV]]) to #map1([[J_BLOCK_IV]]) {
|
||||||
|
// CHECK-NEXT: affine.for [[K_INNER_IV:%.+]] = #map0([[K_BLOCK_IV]]) to #map1([[K_BLOCK_IV]]) {
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
// RUN: onnx-mlir-opt --lower-krnl %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
func @simple_unroll() {
|
||||||
|
%ii = krnl.define_loops 1
|
||||||
|
krnl.unroll %ii : !krnl.loop
|
||||||
|
krnl.iterate(%ii) with (%ii -> %i = 0 to 4) {
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%foo = addi %i, %c1 : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: simple_unroll
|
||||||
|
// CHECK-NEXT: [[CONST_IV_INIT:%.+]] = constant 0 : index
|
||||||
|
// CHECK-NEXT: [[CONST_ONE_0:%.+]] = constant 1 : index
|
||||||
|
// CHECK-NEXT: [[FIRST_RES:%.+]] = addi [[CONST_IV_INIT]], [[CONST_ONE_0]] : index
|
||||||
|
//CHECK-NEST: [[IV_TWO:%.+]] = affine.apply #map{{.+}}([[CONST_IV_INIT]])
|
||||||
|
//CHECK-NEST: [[CONST_ONE_1:%.+]] = constant 1 : index
|
||||||
|
//CHECK-NEST: %2 = addi %1, [[CONST_ONE_1]] : index
|
||||||
|
//CHECK-NEST: [[IV_THREE:%.+]] = affine.apply #map{{.+}}([[CONST_IV_INIT]])
|
||||||
|
//CHECK-NEST: [[CONST_ONE_2:%.+]] = constant 1 : index
|
||||||
|
//CHECK-NEST: %4 = addi %3, [[CONST_ONE_2]] : index
|
||||||
|
//CHECK-NEST: [[IV_FOUR:%.+]] = affine.apply #map{{.+}}([[CONST_IV_INIT]])
|
||||||
|
//CHECK-NEST: [[CONST_ONE_3:%.+]] = constant 1 : index
|
||||||
|
//CHECK-NEST: %6 = addi %5, [[CONST_ONE_3]] : index
|
||||||
|
return
|
||||||
|
}
|
Loading…
Reference in New Issue