2020-03-19 16:48:09 +08:00
|
|
|
//===-------------- LowerKrnl.cpp - Krnl Dialect Lowering -----------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019-2020 The IBM Research Authors.
|
|
|
|
//
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
//
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-04-02 00:38:34 +08:00
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
2019-11-28 11:56:34 +08:00
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2020-06-27 22:35:01 +08:00
|
|
|
#include "mlir/Transforms/LoopUtils.h"
|
2019-11-28 11:56:34 +08:00
|
|
|
|
2020-03-19 16:48:09 +08:00
|
|
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
|
|
|
#include "src/Pass/Passes.hpp"
|
2019-11-28 11:56:34 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
2020-06-27 22:35:01 +08:00
|
|
|
// Lower a single krnl.iterate operation into a perfectly nested band of
// affine.for operations, one per (lower bound, upper bound) map pair stored
// in the iterate's bounds attribute. Newly created (loop reference,
// affine.for) pairs are appended to `nestedForOps` so the caller can apply
// scheduling primitives (e.g. krnl.block) to them afterwards. The iterateOp
// itself is erased at the end.
void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &rewriter,
    SmallVector<std::pair<Value, AffineForOp>, 4> &nestedForOps) {
  rewriter.setInsertionPointAfter(iterateOp);
  SmallVector<std::pair<Value, AffineForOp>, 4> currentNestedForOps;
  auto boundMapAttrs =
      iterateOp.getAttrOfType<ArrayAttr>(KrnlIterateOp::getBoundsAttrName())
          .getValue();
  // Skip the optimized-loop operands; the remaining operands are the
  // unoptimized loop references followed by the bound-map operands.
  auto operandItr =
      iterateOp.operand_begin() + iterateOp.getNumOptimizedLoops();
  // Bound maps come in (lb, ub) pairs; each map consumes getNumInputs()
  // operands from the iterate's operand list.
  for (size_t boundIdx = 0; boundIdx < boundMapAttrs.size(); boundIdx += 2) {
    // Consume input loop operand, currently do not do anything with it.
    auto unoptimizedLoopRef = *(operandItr++);

    // Organize operands into lower/upper bounds in affine.for ready formats.
    llvm::SmallVector<Value, 4> lbOperands, ubOperands;
    AffineMap lbMap, ubMap;
    for (int boundType = 0; boundType < 2; boundType++) {
      auto &operands = boundType == 0 ? lbOperands : ubOperands;
      auto &map = boundType == 0 ? lbMap : ubMap;
      map =
          boundMapAttrs[boundIdx + boundType].cast<AffineMapAttr>().getValue();
      operands.insert(
          operands.end(), operandItr, operandItr + map.getNumInputs());
      std::advance(operandItr, map.getNumInputs());
    }
    currentNestedForOps.emplace_back(std::make_pair(
        unoptimizedLoopRef, rewriter.create<AffineForOp>(iterateOp.getLoc(),
                                lbOperands, lbMap, ubOperands, ubMap)));

    // Each subsequent affine.for is created at the start of the body of the
    // previous one, yielding a perfect loop nest.
    rewriter.setInsertionPoint(currentNestedForOps.back().second.getBody(),
        currentNestedForOps.back().second.getBody()->begin());
  }

  // Replace induction variable references from those introduced by a
  // single krnl.iterate to those introduced by multiple affine.for
  // operations.
  // NOTE: the last induction variable is deliberately not processed here —
  // it is handled by the region inlining/splicing further below.
  for (int64_t i = 0; i < (int64_t)currentNestedForOps.size() - 1; i++) {
    auto iterateIV = iterateOp.bodyRegion().front().getArgument(0);
    auto forIV = currentNestedForOps[i].second.getBody()->getArgument(0);
    iterateIV.replaceAllUsesWith(forIV);
    iterateOp.bodyRegion().front().eraseArgument(0);
  }

  // Pop krnl.iterate body region block arguments, leave the last one
  // for convenience (it'll be taken care of by region inlining).
  while (iterateOp.bodyRegion().front().getNumArguments() > 1)
    iterateOp.bodyRegion().front().eraseArgument(0);

  if (currentNestedForOps.empty()) {
    // If no loops are involved, simply move operations from within iterateOp
    // body region to the parent region of iterateOp.
    rewriter.setInsertionPointAfter(iterateOp);
    iterateOp.bodyRegion().walk([&](Operation *op) {
      if (!op->isKnownTerminator())
        op->replaceAllUsesWith(rewriter.clone(*op));
    });
  } else {
    // Transfer krnl.iterate region to innermost for op.
    auto innermostForOp = currentNestedForOps.back().second;
    innermostForOp.region().getBlocks().clear();
    auto &innerMostRegion = innermostForOp.region();
    innerMostRegion.getBlocks().splice(
        innerMostRegion.end(), iterateOp.bodyRegion().getBlocks());
  }

  iterateOp.erase();
  nestedForOps.insert(nestedForOps.end(), currentNestedForOps.begin(),
      currentNestedForOps.end());
}
|
2019-11-28 11:56:34 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Krnl to Affine Rewrite Patterns: KrnlTerminator operation.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Rewrites krnl.terminate into affine.terminator, the terminator expected
// inside affine.for bodies after iterate lowering.
class KrnlTerminatorLowering : public OpRewritePattern<KrnlTerminatorOp> {
public:
  using OpRewritePattern<KrnlTerminatorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      KrnlTerminatorOp op, PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<AffineTerminatorOp>(op);
    return success();
  }
};
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Krnl to Affine Rewrite Patterns: KrnlDefineLoops operation.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Erases krnl.define_loops: loop definitions carry no runtime semantics
// once the iterate operations referencing them have been lowered.
class KrnlDefineLoopsLowering : public OpRewritePattern<KrnlDefineLoopsOp> {
public:
  using OpRewritePattern<KrnlDefineLoopsOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      KrnlDefineLoopsOp op, PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Krnl to Affine Rewrite Patterns: KrnlOptimizeLoops operation.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Erases krnl.optimize_loops: scheduling directives are consumed manually
// in runOnFunction, so the op itself is simply removed.
class KrnlOptimizeLoopsLowering : public OpRewritePattern<KrnlOptimizeLoopsOp> {
public:
  using OpRewritePattern<KrnlOptimizeLoopsOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      KrnlOptimizeLoopsOp op, PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
|
2020-06-27 22:35:01 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Krnl to Affine Rewrite Patterns: KrnlBlock operation.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Erases krnl.block after its tiling effect has been applied manually in
// runOnFunction (see the tilePerfectlyNested call there).
class KrnlBlockOpLowering : public OpRewritePattern<KrnlBlockOp> {
public:
  using OpRewritePattern<KrnlBlockOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      KrnlBlockOp op, PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Krnl to Affine Rewrite Patterns: KrnlReturnLoops operation.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Erases krnl.return_loops: loop-returning bookkeeping is not needed once
// scheduling has been resolved.
class KrnlReturnLoopOpLowering : public OpRewritePattern<KrnlReturnLoopsOp> {
public:
  using OpRewritePattern<KrnlReturnLoopsOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      KrnlReturnLoopsOp op, PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
|
|
|
|
|
2019-11-28 11:56:34 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// KrnlToAffineLoweringPass
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// This is a partial lowering to affine loops of the krnl dialect operations.
|
|
|
|
/// At this stage the dialect will contain standard operations as well like
|
|
|
|
/// add and multiply, this pass will leave these operations intact.
|
|
|
|
namespace {
|
|
|
|
// Function pass lowering Krnl dialect loop constructs to the Affine dialect;
// the actual work happens in runOnFunction below.
struct KrnlToAffineLoweringPass
    : public PassWrapper<KrnlToAffineLoweringPass, FunctionPass> {
  void runOnFunction() final;
};
|
2020-06-27 22:35:01 +08:00
|
|
|
|
|
|
|
// Helper function to test if KrnlIterateOp is nested under another
|
|
|
|
// KrnlIterateOp.
|
|
|
|
// Helper function to test if KrnlIterateOp is nested under another
// KrnlIterateOp.
//
// Returns true iff any transitive parent operation of `iterateOp` is itself
// a krnl.iterate.
bool isIterateOpNested(KrnlIterateOp iterateOp) {
  // krnl.iterate is dynamically legal, if and only if it is enclosed by
  // another krnl.iterate.
  Operation *op = iterateOp;
  // Walk up the parent chain; isa<> suffices since the parent op value is
  // never used (the original dyn_cast bound an unused variable).
  while ((op = op->getParentOp()))
    if (isa<KrnlIterateOp>(op))
      return true;
  return false;
}
|
|
|
|
|
|
|
|
// Finds an outermost (non-nested) krnl.iterate remaining in `function`.
// Returns llvm::None when the function no longer contains any such op;
// otherwise returns the last outermost iterate encountered by the walk.
Optional<KrnlIterateOp> nextIterateOp(FuncOp function) {
  Optional<KrnlIterateOp> result;
  function.walk([&result](KrnlIterateOp op) {
    // Nested iterate ops are lowered together with their enclosing root,
    // so only outermost ones are candidates.
    if (!isIterateOpNested(op))
      result = op;
  });
  return result;
}
|
|
|
|
|
|
|
|
// Returns true iff the body of `op` holds exactly one nested krnl.iterate
// immediately followed by the block terminator — i.e. the nest is perfect.
bool hasOnePerfectlyNestedIterateOp(KrnlIterateOp op) {
  auto bodyOps = op.bodyRegion().getOps();
  auto it = bodyOps.begin();
  // The first operation in the body must be a krnl.iterate.
  if (it == bodyOps.end())
    return false;
  if (!isa<KrnlIterateOp>(*it))
    return false;
  // The second operation must be the terminator; anything else means the
  // body contains extra work and the nest is not perfect.
  ++it;
  if (it == bodyOps.end())
    return false;
  return (*it).isKnownTerminator();
}
|
2019-12-20 02:27:15 +08:00
|
|
|
} // end anonymous namespace.
|
2019-11-28 11:56:34 +08:00
|
|
|
|
|
|
|
void KrnlToAffineLoweringPass::runOnFunction() {
|
|
|
|
auto function = getFunction();
|
|
|
|
ConversionTarget target(getContext());
|
|
|
|
|
2020-04-02 00:38:34 +08:00
|
|
|
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
|
2019-11-28 11:56:34 +08:00
|
|
|
// We expect IR to be free of Krnl Dialect Ops.
|
|
|
|
target.addIllegalDialect<KrnlOpsDialect>();
|
2020-06-27 22:35:01 +08:00
|
|
|
|
|
|
|
// Operations that should be converted to LLVM IRs directly.
|
2019-12-14 04:28:56 +08:00
|
|
|
target.addLegalOp<KrnlMemcpyOp>();
|
2019-12-22 13:25:02 +08:00
|
|
|
target.addLegalOp<KrnlEntryPointOp>();
|
2020-04-02 01:51:06 +08:00
|
|
|
target.addLegalOp<KrnlGlobalOp>();
|
2020-06-10 04:48:33 +08:00
|
|
|
target.addLegalOp<KrnlGetRefOp>();
|
2020-06-27 22:35:01 +08:00
|
|
|
target.addLegalOp<KrnlIterateOp>();
|
2019-11-28 11:56:34 +08:00
|
|
|
|
|
|
|
OwningRewritePatternList patterns;
|
2020-06-27 22:35:01 +08:00
|
|
|
patterns.insert<KrnlTerminatorLowering, KrnlDefineLoopsLowering,
|
|
|
|
KrnlOptimizeLoopsLowering, KrnlBlockOpLowering, KrnlReturnLoopOpLowering>(
|
|
|
|
&getContext());
|
|
|
|
|
|
|
|
// Do not lower operations that pertain to schedules just yet.
|
|
|
|
target.addLegalOp<KrnlBlockOp>();
|
|
|
|
target.addLegalOp<KrnlDefineLoopsOp>();
|
|
|
|
target.addLegalOp<KrnlOptimizeLoopsOp>();
|
|
|
|
target.addLegalOp<KrnlReturnLoopsOp>();
|
|
|
|
if (failed(applyPartialConversion(function, target, patterns)))
|
|
|
|
return signalPassFailure();
|
|
|
|
|
|
|
|
OpBuilder builder(&getContext());
|
|
|
|
while (auto iterateOp = nextIterateOp(function)) {
|
|
|
|
// Collect a maximal set of loop band to lower. They must be a perfectly
|
|
|
|
// nested sequence of for loops (this limitation follows from the
|
|
|
|
// precondition of current loop manupulation utility libraries).
|
|
|
|
auto rootOp = iterateOp;
|
|
|
|
SmallVector<KrnlIterateOp, 4> loopBand = {*rootOp};
|
|
|
|
while (hasOnePerfectlyNestedIterateOp(*rootOp)) {
|
|
|
|
auto nestedIterateOp =
|
|
|
|
*rootOp->bodyRegion().getOps<KrnlIterateOp>().begin();
|
|
|
|
loopBand.emplace_back(nestedIterateOp);
|
|
|
|
rootOp = nestedIterateOp;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Lower the band of iterateOps, initialize loopRefToLoop to be the list of
|
|
|
|
// loop reference and the for loop being referenced.
|
|
|
|
SmallVector<std::pair<Value, AffineForOp>, 4> loopRefToLoop;
|
|
|
|
for (auto op : loopBand)
|
|
|
|
lowerIterateOp(op, builder, loopRefToLoop);
|
2019-11-28 11:56:34 +08:00
|
|
|
|
2020-06-27 22:35:01 +08:00
|
|
|
// Manually lower schedule ops.
|
|
|
|
while (!loopRefToLoop.empty()) {
|
|
|
|
Value loopRef;
|
|
|
|
AffineForOp forOp;
|
|
|
|
std::tie(loopRef, forOp) = loopRefToLoop.pop_back_val();
|
|
|
|
|
|
|
|
// Ensure that loop references are single-use during the scheduling phase.
|
|
|
|
auto loopRefUsers = loopRef.getUsers();
|
|
|
|
SmallVector<Operation *, 4> unfilteredUsers(
|
|
|
|
loopRefUsers.begin(), loopRefUsers.end()),
|
|
|
|
users;
|
|
|
|
std::copy_if(unfilteredUsers.begin(), unfilteredUsers.end(),
|
|
|
|
std::back_inserter(users),
|
|
|
|
[](Operation *op) { return !isa<KrnlIterateOp>(op); });
|
|
|
|
assert(std::distance(users.begin(), users.end()) <= 1 &&
|
|
|
|
"Loop reference used more than once.");
|
|
|
|
|
|
|
|
// No schedule primitives associated with this loop reference, move on.
|
|
|
|
if (users.empty())
|
|
|
|
continue;
|
|
|
|
|
|
|
|
// Scheduling operations detected, transform loops as directed, while
|
|
|
|
// keeping the loopRefToLoop mapping up-to-date.
|
|
|
|
auto user = users.front();
|
|
|
|
if (isa<KrnlBlockOp>(user)) {
|
|
|
|
auto blockOp = cast<KrnlBlockOp>(user);
|
|
|
|
SmallVector<AffineForOp, 2> tiledLoops;
|
|
|
|
SmallVector<AffineForOp, 1> loopsToTile = {forOp};
|
|
|
|
if (failed(tilePerfectlyNested(loopsToTile,
|
|
|
|
cast<KrnlBlockOp>(user).tile_sizeAttr().getInt(),
|
|
|
|
&tiledLoops))) {
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
assert(tiledLoops.size() == 2);
|
|
|
|
assert(blockOp.getNumResults() == 2);
|
|
|
|
// Record the tiled loop references, and their corresponding tiled for
|
|
|
|
// loops in loopRefToLoop.
|
|
|
|
loopRefToLoop.emplace_back(
|
|
|
|
std::make_pair(blockOp.getResult(0), tiledLoops[0]));
|
|
|
|
loopRefToLoop.emplace_back(
|
|
|
|
std::make_pair(blockOp.getResult(1), tiledLoops[1]));
|
|
|
|
}
|
|
|
|
}
|
2019-12-22 13:25:02 +08:00
|
|
|
}
|
2020-06-27 22:35:01 +08:00
|
|
|
|
|
|
|
// KrnlIterateOp should be all gone by now.
|
|
|
|
target.addIllegalOp<KrnlIterateOp>();
|
|
|
|
|
|
|
|
// Remove/lower schedule related operations.
|
|
|
|
target.addIllegalOp<KrnlDefineLoopsOp>();
|
|
|
|
target.addIllegalOp<KrnlBlockOp>();
|
|
|
|
target.addIllegalOp<KrnlOptimizeLoopsOp>();
|
|
|
|
target.addIllegalOp<KrnlReturnLoopsOp>();
|
|
|
|
if (failed(applyPartialConversion(function, target, patterns)))
|
|
|
|
return signalPassFailure();
|
2019-11-28 11:56:34 +08:00
|
|
|
}
|
|
|
|
|
2019-12-20 02:27:15 +08:00
|
|
|
} // namespace
|
2019-11-28 11:56:34 +08:00
|
|
|
|
|
|
|
/// Factory for the Krnl-to-Affine lowering pass registered with the pass
/// pipeline (declared in src/Pass/Passes.hpp).
std::unique_ptr<Pass> mlir::createLowerKrnlPass() {
  return std::make_unique<KrnlToAffineLoweringPass>();
}
|