Introduce helper class to generate KRNL code and apply it to Convolution (#93)

* helper to gen krnl code, applied to conv

* suggested changes, name, removed set insertion point

* format

* suggested changes

* added comments and made a small name change
This commit is contained in:
Alexandre Eichenberger 2020-02-24 17:20:15 -05:00 committed by GitHub
parent 9c398c0121
commit fcb5f35993
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 303 additions and 131 deletions

View File

@ -1,2 +1,3 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes
AlignAfterOpenBracket: DontAlign

View File

@ -12,9 +12,8 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
ONNXConvNoBiasOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin());
@ -25,12 +24,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operands[0]});
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operands[0]});
auto resultShape = memRefType.getShape();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
auto &inputOperand = operands[0];
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto &kernelOperand = operands[1];
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
// R = ConvNoBias(D, K)
//
@ -91,123 +92,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
loc, FloatAttr::get(memRefType.getElementType(), 0));
Value subchannels;
if (kernelShape[1] < 0) {
subchannels =
rewriter.create<DimOp>(loc, operands[1], 1).getResult();
subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
} else {
subchannels = rewriter.create<ConstantIndexOp>(
loc, kernelShape[1]);
subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
}
// 1. Define outer loops and emit empty optimization block:
int64_t nOuterLoops = (group > 1) ? 3 : 2;
std::vector<Value> outerLoops;
std::vector<Value> optimizedOuterLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
optimizedOuterLoops, nOuterLoops);
// Prepare iteration arguments over outer loop nest.
KrnlIterateOperandPack pack(
rewriter, outerLoops, optimizedOuterLoops);
BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
outerLoops.createDefineAndOptimizeOp();
// for n = 0 .. N:
pack.pushConstantBound(0);
if (inputShape[0] < 0)
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], 0).getResult());
else
pack.pushConstantBound(inputShape[0]);
int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
// for g = 0 .. N:
if (group > 1) {
pack.pushConstantBound(0);
pack.pushConstantBound(group);
}
int gIndex = -1;
if (group > 1)
gIndex = outerLoops.pushBounds(0, group);
// for m = 0 .. kernelsPerGroup:
pack.pushConstantBound(0);
pack.pushConstantBound(kernelsPerGroup);
// Outer loop iteration.
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &outerIterationBlock = iterateOp.bodyRegion().front();
// Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
rewriter.setInsertionPointToStart(&outerIterationBlock);
int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
// Outer loop iteration
outerLoops.createIterateOp();
rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
{
// 2. Emit the body of the outer loop nest.
// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
// If group is not set then the value of the kernel ID is
// identical to that of the loop over kernels.
Value kernel = outerIterationBlock.getArguments()[1];
Value kernel = outerLoops.getInductionVar(mIndex);
if (group > 1) {
// Middle loop is over groups and third loop is over the
// kernel identifiers in the current group.
auto kernelsOffset = rewriter.create<MulIOp>(loc,
outerIterationBlock.getArguments()[1],
kernelsPerGroupValue);
kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
outerIterationBlock.getArguments()[2]);
auto kernelsOffset = rewriter.create<MulIOp>(
loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
kernel = rewriter.create<AddIOp>(
loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
}
// 2.2 Define spatial loops
int64_t nSpatialLoops = resultShape.size() - 2;
std::vector<Value> spatialLoops;
std::vector<Value> optimizedSpatialLoops;
Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
optimizedSpatialLoops, nSpatialLoops);
// 2.3 Prepare iteration arguments for spatial loop nest.
KrnlIterateOperandPack spatialPack(
rewriter, spatialLoops, optimizedSpatialLoops);
BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
spatialLoops.createDefineAndOptimizeOp();
for (int i = 2; i < resultShape.size(); ++i)
addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
spatialLoops.pushBounds(0, alloc, i);
// 2.4 Emit loop nest over output spatial dimensions.
// for rX = 0 .. RX
auto spatialIterateOp =
rewriter.create<KrnlIterateOp>(loc, spatialPack);
Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
// 2.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
rewriter.setInsertionPointToStart(&spatialIterationBlock);
spatialLoops.createIterateOp();
rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());
{
// 3. Emit the body of the spatial loop nest.
// 3.1 Emit: R[n][kernel][r1][r2] = 0;
SmallVector<Value, 4> resultIndices;
// n
resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// kernel
resultIndices.emplace_back(kernel);
// rX
for (auto arg : spatialIterationBlock.getArguments())
for (auto arg : spatialLoops.getIterateBlock()->getArguments())
resultIndices.emplace_back(arg);
// Store initializer value into output location.
rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
// 3.2 Define inner loops.
int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
std::vector<Value> innerLoops;
std::vector<Value> optimizedInnerLoops;
Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
optimizedInnerLoops, nInnerLoops);
// 3.3 Prepare iteration arguments for inner loop nest.
KrnlIterateOperandPack innerPack(
rewriter, innerLoops, optimizedInnerLoops);
BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
innerLoops.createDefineAndOptimizeOp();
// for c = 0 .. C/group
innerPack.pushConstantBound(0);
innerPack.pushConstantBound(kernelShape[1]);
int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
// for Kx = 0 .. KX
for (int i = 2; i < kernelShape.size(); ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
innerLoops.pushBounds(0, kernelOperand, i);
// 3.4 Emit inner loop nest.
auto innerIterateOp =
rewriter.create<KrnlIterateOp>(loc, innerPack);
Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
// 3.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optInnerLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
rewriter.setInsertionPointToStart(&innerIterationBlock);
innerLoops.createIterateOp();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
{
// 4. Emit inner loop body
// R[n][kernel][r1][r2] =
@ -217,13 +177,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// 4.1 Prepare indices for accesing the data tensor.
SmallVector<Value, 4> dataIndices;
// n
dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// g * (C / group) + c
Value channelDepth = innerIterationBlock.getArguments()[0];
Value channelDepth = innerLoops.getInductionVar(cIndex);
if (group > 1)
channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
rewriter.create<MulIOp>(loc, subchannels,
outerIterationBlock.getArguments()[1]));
rewriter.create<MulIOp>(
loc, subchannels, outerLoops.getInductionVar(gIndex)));
dataIndices.emplace_back(channelDepth);
// sX * rX + kX
auto stridesAttribute = convOp.stridesAttr();
@ -233,15 +193,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
for (auto stride : stridesAttribute.getValue())
strides.emplace_back(stride.cast<IntegerAttr>().getInt());
for (int i = 0; i < kernelShape.size() - 2; ++i) {
Value spatialIndex = spatialIterationBlock.getArguments()[i];
Value spatialIndex = spatialLoops.getInductionVar(i);
// If strides are present then emit the correct access index.
if (stridesAttribute && strides[i] > 1)
spatialIndex = rewriter.create<MulIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, strides[i]),
spatialIterationBlock.getArguments()[i]);
dataIndices.emplace_back(
rewriter.create<AddIOp>(loc, spatialIndex,
innerIterationBlock.getArguments()[i+1]));
spatialLoops.getInductionVar(i));
dataIndices.emplace_back(rewriter.create<AddIOp>(
loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
}
// 4.2 Prepare indices for accessing the kernel tensor.
@ -249,17 +208,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// kernel
kernelIndices.emplace_back(kernel);
// c
kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
// kX
for (int i = 0; i < kernelShape.size() - 2; ++i)
kernelIndices.emplace_back(
innerIterationBlock.getArguments()[i+1]);
kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));
// 4.3 Compute convolution.
auto loadData =
rewriter.create<LoadOp>(loc, operands[0], dataIndices);
rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
auto loadKernel =
rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
auto loadPartialSum =
rewriter.create<LoadOp>(loc, alloc, resultIndices);
Value result = rewriter.create<AddFOp>(loc, loadPartialSum,

View File

@ -1,4 +1,5 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "src/dialect/krnl/krnl_ops.hpp"
@ -9,9 +10,8 @@ namespace onnf {
using namespace mlir;
ParseResult
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
Value &operand) {
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type &operandType, Value &operand) {
// If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) {
// Parse operand types:
@ -19,7 +19,7 @@ KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
_parser.parseOperandList(operand_refs);
// Record operands:
for (auto& operand_ref : operand_refs)
for (auto &operand_ref : operand_refs)
_operandRefQueue.emplace(operand_ref);
}
@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
return success();
}
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
Value &operand) {
ParseResult KrnlDialectOperandParser::ParseOperand(
const Type &operandType, Value &operand) {
if (ParseOptionalOperand(operandType, operand))
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
return success();
}
void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
unsigned numSymbols, OpAsmPrinter& p) {
void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
unsigned numSymbols, OpAsmPrinter &p) {
p << '(';
p.printOperands(begin, begin + numDims);
p << ')';
@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
}
void printBound(AffineMapAttr boundMap,
Operation::operand_iterator& boundOperandsBeg, const char* prefix,
OpAsmPrinter& p) {
Operation::operand_iterator &boundOperandsBeg, const char *prefix,
OpAsmPrinter &p) {
AffineMap map = boundMap.getValue();
// Check if this bound should be printed using custom assembly form.
@ -120,9 +120,10 @@ void printBound(AffineMapAttr boundMap,
printDimAndSymbolList(
boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p);
}
} // namespace onnf
} // namespace onnf
namespace mlir {
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
@ -137,4 +138,125 @@ void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
boundMaps.emplace_back(AffineMapAttr::get(map));
_operands.emplace_back(operand);
}
} // namespace mlir
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, int loopNum)
: rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
createdIterateOp(false) {
if (originalLoopNum <= 0)
emitError(loc, "expected positive number of original loops");
}
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
: BuildKrnlLoop(rewriter, loc,
memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
BuildKrnlLoop::~BuildKrnlLoop() {
if (!createdDefineOp)
emitError(loc, "expected to create define op");
if (!createdIterateOp)
emitError(loc, "expected to create iteration op");
if (pack)
free(pack);
}
void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
// insert define loop op
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
originalLoops.reserve(originalLoopNum);
for (auto result : loopsOp.getResults())
originalLoops.push_back(result);
// inserte optimize loop op.
auto optimizedLoopsOp =
rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
optLoops.reserve(originalLoopNum);
// Emit empty optimizations
if (withEmptyOptimization) {
for (auto result : optimizedLoopsOp.getResults())
optLoops.push_back(result);
optBlock = &optimizedLoopsOp.region().front();
auto ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToEnd(optBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.restoreInsertionPoint(ip);
}
// prepare data structure to push bounds
pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
createdOptimizeOp = true;
}
// push bounds (lower and upper) and return index for loop info
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushConstantBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
pack->pushConstantBound(lowerBound);
// process upperBound as a dimension of mem ref, possibly non-constant
auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
if (shape[upperBoundMemRefIndex] < 0) {
if (upperBoundMustBeConstant)
emitError(loc, "bound expected to be constant");
pack->pushOperandBound(
rewriter
.create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
.getResult());
} else
pack->pushConstantBound(shape[upperBoundMemRefIndex]);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
pack->pushOperandBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
// create iter
void BuildKrnlLoop::createIterateOp() {
if (!createdDefineOp)
emitError(loc, "must create define op before iterate op");
// Tight now, optimize (possibly empty) is mandatory. This may change
if (!createdOptimizeOp)
emitError(loc, "must create optimize op before iterate op");
// have to have defined all bounds
if (pushCount != originalLoopNum) {
printf(" push count %d, original loop %d\n", pushCount, originalLoopNum);
emitError(loc, "must push bounds for all original loops");
}
// create iterate op
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
iterBlock = &iterateOp.bodyRegion().front();
createdIterateOp = true;
}
void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization) {
int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
if (originalLoopNum != loopNum)
emitError(loc, "mismatch in loop numbers from constructor and define");
createDefineAndOptimizeOp(withEmptyOptimization);
for (int i = 0; i < originalLoopNum; ++i)
pushBounds(0, memRefOperand, i);
createIterateOp();
}
// get induction variable to be use within iter
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
emitError(loc, "original loop index is out of bound");
return iterBlock->getArguments()[originalLoopIndex];
}
} // namespace mlir

View File

@ -8,39 +8,38 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnf {
class KrnlDialectOperandParser {
public:
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
public:
explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser)
: _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
mlir::Value &operand);
mlir::ParseResult ParseOptionalOperand(
const mlir::Type &operandType, mlir::Value &operand);
// Parse an optional operand and push it to an operand list.
mlir::ParseResult
ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Parse a required operand.
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
mlir::Value &operand);
mlir::ParseResult ParseOperand(
const mlir::Type &operandType, mlir::Value &operand);
// Parse a required operand and push it to an operand list.
mlir::ParseResult
ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
// Do we have more operands to parse?
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
private:
mlir::OpAsmParser& _parser;
private:
mlir::OpAsmParser &_parser;
mlir::Builder& _builder;
mlir::Builder &_builder;
// A queue storing the parsed SSA id references.
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
@ -50,24 +49,24 @@ class KrnlDialectOperandParser {
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
// Main difference is that it advances the iterator `begin` as it consumes
// dimension and symbol operands.
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
void printDimAndSymbolList(mlir::Operation::operand_iterator &begin,
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p);
// Adapted from:
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
// Main difference is that it advances the iterator `boundOperandsBeg` as it
// prints bound.
void printBound(mlir::AffineMapAttr boundMap,
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
mlir::OpAsmPrinter& p);
} // namespace onnf
mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix,
mlir::OpAsmPrinter &p);
} // namespace onnf
namespace mlir {
struct KrnlIterateOperandPack {
KrnlIterateOperandPack(mlir::Builder &builder,
llvm::ArrayRef<mlir::Value> inputLoops,
llvm::ArrayRef<mlir::Value> optimizedLoops)
llvm::ArrayRef<mlir::Value> inputLoops,
llvm::ArrayRef<mlir::Value> optimizedLoops)
: builder(builder), inputLoops(inputLoops),
optimizedLoops(optimizedLoops) {
_operands.insert(
@ -88,7 +87,7 @@ struct KrnlIterateOperandPack {
size_t getNumInputLoops() const { return inputLoops.size(); }
private:
private:
int _boundIdx = 0;
llvm::SmallVector<mlir::Value, 8> _operands;
@ -97,7 +96,99 @@ struct KrnlIterateOperandPack {
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
mlir::Builder& builder;
mlir::Builder &builder;
};
} // namespace mlir
// Helper function to write kernel loops. This class will let us build a single
// define/optimize/iterate operation combo. We can then insert optimizations in
// the body of the optimization operation, and operations in the body of the
// iterate operation.
//
// The sequence is as follow:
//
// 1) Create a object giving the rewriter, location, and number of loop in the
// original (non optimized) loop.
//
// 2) Create define & optimize ops (currently paired). Optimizations can then
// be added to the inner block of the optimize operation. Make sure to set the
// insertion point to that block for optimizations to go in the right place.
//
// 3) Push the bounds for each of the original loops. Bounds are pushed in
// pairs (lower & upper bounds). THere are a few methods to do it depending on
// the type of the bounds. When pushing bounds, the method returns a number
// that represent the index associated with that iteration (induction variable
// and bounds). That index can be used later to extract the induction variable
// for reference in computation and/or index calculations of mem refs.
//
// 4) Once all the bounds are pushed, create the iterate operation. Once this
// is done, we can add operations within the iterate blocks by setting the
// insertion point to it. Value of the induction variables can be retrieved
// using the proper index (determined when pushin the bounds).
class BuildKrnlLoop {
public:
// Create a build kernel loop for the given location and loop number.
BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
// Do the same, but where the loop number corresponds to the dimensionality of
// the mem ref operand.
BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
~BuildKrnlLoop();
// Create define and optimize loop with loopNum original loops. If
// withEmptyOptimization, the optimization is simply the identity function (no
// optimizations).
void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
// Push bounds (lower and upper) for each of the loops, in order. It returns
// the index associated with the loop iteration. This index is in the range
// from zero to original loop number -1, and is monotonally increasing from
// call to call. This index is later used in the getInductionVar call.
int pushBounds(int64_t lowerBound, int64_t upperBound);
int pushBounds(int64_t lowerBound, Value upperBound);
int pushBounds(Value lowerBound, Value upperBound);
// same, where the lower bound is an integer, and the uppoer bound is given by
// the size of the mem ref operand along the upperBoundMemRefIndex dimension.
int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
// Create an iterate op.
void createIterateOp();
// Create an define, optimize and iterate op, with the same loop nummber as
// the rank of the memRefOperand. The lower bound of each loops is zero, and
// the upper bound of each loops is the dimension given by the mem refs
void createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization = true);
// Get the (original loop) induction variable associated with the given index.
// Use the index returned when pushing the bounds.
BlockArgument &getInductionVar(int originalLoopIndex);
// Get blocks. This allow us to set the insertion point to the inner block of
// the optimize and the iterate Operation
Block *getOptimizationBlock() { return optBlock; }
Block *getIterateBlock() { return iterBlock; }
// get original or optimized loops
std::vector<Value> &getOriginalLoops() { return originalLoops; }
std::vector<Value> &getOptimizedLoops() { return optLoops; }
private:
// inputs
ConversionPatternRewriter &rewriter;
Location loc;
int originalLoopNum;
// track loops and bounds
std::vector<Value> originalLoops;
std::vector<Value> optLoops;
KrnlIterateOperandPack *pack;
int pushCount;
bool createdDefineOp;
bool createdOptimizeOp;
bool createdIterateOp;
// insertion points (opt block, iterate)
Block *optBlock;
Block *iterBlock;
};
} // namespace mlir