Introduce helper class to generate KRNL code and apply it to Convolution (#93)

* helper to gen krnl code, applied to conv

* suggested changes, name, removed set insertion point

* format

* suggested changes

* added comments and made a small name change
This commit is contained in:
Alexandre Eichenberger 2020-02-24 17:20:15 -05:00 committed by GitHub
parent 9c398c0121
commit fcb5f35993
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 303 additions and 131 deletions

View File

@ -1,2 +1,3 @@
BasedOnStyle: LLVM BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes AlwaysBreakTemplateDeclarations: Yes
AlignAfterOpenBracket: DontAlign

View File

@ -12,8 +12,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
ONNXConvNoBiasOpLowering(MLIRContext *ctx) ONNXConvNoBiasOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
@ -25,12 +24,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, alloc = insertAllocAndDealloc(
{operands[0]}); memRefType, loc, rewriter, insertDealloc, {operands[0]});
auto resultShape = memRefType.getShape(); auto resultShape = memRefType.getShape();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape(); auto &inputOperand = operands[0];
auto kernelShape = operands[1].getType().cast<MemRefType>().getShape(); auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto &kernelOperand = operands[1];
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
// R = ConvNoBias(D, K) // R = ConvNoBias(D, K)
// //
@ -91,123 +92,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
loc, FloatAttr::get(memRefType.getElementType(), 0)); loc, FloatAttr::get(memRefType.getElementType(), 0));
Value subchannels; Value subchannels;
if (kernelShape[1] < 0) { if (kernelShape[1] < 0) {
subchannels = subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
rewriter.create<DimOp>(loc, operands[1], 1).getResult();
} else { } else {
subchannels = rewriter.create<ConstantIndexOp>( subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
loc, kernelShape[1]);
} }
// 1. Define outer loops and emit empty optimization block: // 1. Define outer loops and emit empty optimization block:
int64_t nOuterLoops = (group > 1) ? 3 : 2; int64_t nOuterLoops = (group > 1) ? 3 : 2;
std::vector<Value> outerLoops; BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
std::vector<Value> optimizedOuterLoops; outerLoops.createDefineAndOptimizeOp();
Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
optimizedOuterLoops, nOuterLoops);
// Prepare iteration arguments over outer loop nest.
KrnlIterateOperandPack pack(
rewriter, outerLoops, optimizedOuterLoops);
// for n = 0 .. N: // for n = 0 .. N:
pack.pushConstantBound(0); int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
if (inputShape[0] < 0)
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], 0).getResult());
else
pack.pushConstantBound(inputShape[0]);
// for g = 0 .. N: // for g = 0 .. N:
if (group > 1) { int gIndex = -1;
pack.pushConstantBound(0); if (group > 1)
pack.pushConstantBound(group); gIndex = outerLoops.pushBounds(0, group);
}
// for m = 0 .. kernelsPerGroup: // for m = 0 .. kernelsPerGroup:
pack.pushConstantBound(0); int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
pack.pushConstantBound(kernelsPerGroup); // Outer loop iteration
// Outer loop iteration. outerLoops.createIterateOp();
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
Block &outerIterationBlock = iterateOp.bodyRegion().front();
// Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
rewriter.setInsertionPointToStart(&outerIterationBlock);
{ {
// 2. Emit the body of the outer loop nest. // 2. Emit the body of the outer loop nest.
// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m; // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
// If group is not set then the value of the kernel ID is // If group is not set then the value of the kernel ID is
// identical to that of the loop over kernels. // identical to that of the loop over kernels.
Value kernel = outerIterationBlock.getArguments()[1]; Value kernel = outerLoops.getInductionVar(mIndex);
if (group > 1) { if (group > 1) {
// Middle loop is over groups and third loop is over the // Middle loop is over groups and third loop is over the
// kernel identifiers in the current group. // kernel identifiers in the current group.
auto kernelsOffset = rewriter.create<MulIOp>(loc, auto kernelsOffset = rewriter.create<MulIOp>(
outerIterationBlock.getArguments()[1], loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
kernelsPerGroupValue); kernel = rewriter.create<AddIOp>(
kernel = rewriter.create<AddIOp>(loc, kernelsOffset, loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
outerIterationBlock.getArguments()[2]);
} }
// 2.2 Define spatial loops // 2.2 Define spatial loops
int64_t nSpatialLoops = resultShape.size() - 2; int64_t nSpatialLoops = resultShape.size() - 2;
std::vector<Value> spatialLoops; BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
std::vector<Value> optimizedSpatialLoops; spatialLoops.createDefineAndOptimizeOp();
Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
optimizedSpatialLoops, nSpatialLoops);
// 2.3 Prepare iteration arguments for spatial loop nest.
KrnlIterateOperandPack spatialPack(
rewriter, spatialLoops, optimizedSpatialLoops);
for (int i = 2; i < resultShape.size(); ++i) for (int i = 2; i < resultShape.size(); ++i)
addDimensionToPack(rewriter, loc, spatialPack, alloc, i); spatialLoops.pushBounds(0, alloc, i);
// 2.4 Emit loop nest over output spatial dimensions. // 2.4 Emit loop nest over output spatial dimensions.
// for rX = 0 .. RX // for rX = 0 .. RX
auto spatialIterateOp = spatialLoops.createIterateOp();
rewriter.create<KrnlIterateOp>(loc, spatialPack); rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());
Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
// 2.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
rewriter.setInsertionPointToStart(&spatialIterationBlock);
{ {
// 3. Emit the body of the spatial loop nest. // 3. Emit the body of the spatial loop nest.
// 3.1 Emit: R[n][kernel][r1][r2] = 0; // 3.1 Emit: R[n][kernel][r1][r2] = 0;
SmallVector<Value, 4> resultIndices; SmallVector<Value, 4> resultIndices;
// n // n
resultIndices.emplace_back(outerIterationBlock.getArguments()[0]); resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// kernel // kernel
resultIndices.emplace_back(kernel); resultIndices.emplace_back(kernel);
// rX // rX
for (auto arg : spatialIterationBlock.getArguments()) for (auto arg : spatialLoops.getIterateBlock()->getArguments())
resultIndices.emplace_back(arg); resultIndices.emplace_back(arg);
// Store initializer value into output location. // Store initializer value into output location.
rewriter.create<StoreOp>(loc, zero, alloc, resultIndices); rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
// 3.2 Define inner loops. // 3.2 Define inner loops.
int64_t nInnerLoops = 1 + (kernelShape.size() - 2); int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
std::vector<Value> innerLoops; BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
std::vector<Value> optimizedInnerLoops; innerLoops.createDefineAndOptimizeOp();
Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
optimizedInnerLoops, nInnerLoops);
// 3.3 Prepare iteration arguments for inner loop nest.
KrnlIterateOperandPack innerPack(
rewriter, innerLoops, optimizedInnerLoops);
// for c = 0 .. C/group // for c = 0 .. C/group
innerPack.pushConstantBound(0); int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
innerPack.pushConstantBound(kernelShape[1]);
// for Kx = 0 .. KX // for Kx = 0 .. KX
for (int i = 2; i < kernelShape.size(); ++i) for (int i = 2; i < kernelShape.size(); ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[1], i); innerLoops.pushBounds(0, kernelOperand, i);
// 3.4 Emit inner loop nest. // 3.4 Emit inner loop nest.
auto innerIterateOp = innerLoops.createIterateOp();
rewriter.create<KrnlIterateOp>(loc, innerPack); rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
// 3.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optInnerLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
rewriter.setInsertionPointToStart(&innerIterationBlock);
{ {
// 4. Emit inner loop body // 4. Emit inner loop body
// R[n][kernel][r1][r2] = // R[n][kernel][r1][r2] =
@ -217,13 +177,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// 4.1 Prepare indices for accesing the data tensor. // 4.1 Prepare indices for accesing the data tensor.
SmallVector<Value, 4> dataIndices; SmallVector<Value, 4> dataIndices;
// n // n
dataIndices.emplace_back(outerIterationBlock.getArguments()[0]); dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// g * (C / group) + c // g * (C / group) + c
Value channelDepth = innerIterationBlock.getArguments()[0]; Value channelDepth = innerLoops.getInductionVar(cIndex);
if (group > 1) if (group > 1)
channelDepth = rewriter.create<AddIOp>(loc, channelDepth, channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
rewriter.create<MulIOp>(loc, subchannels, rewriter.create<MulIOp>(
outerIterationBlock.getArguments()[1])); loc, subchannels, outerLoops.getInductionVar(gIndex)));
dataIndices.emplace_back(channelDepth); dataIndices.emplace_back(channelDepth);
// sX * rX + kX // sX * rX + kX
auto stridesAttribute = convOp.stridesAttr(); auto stridesAttribute = convOp.stridesAttr();
@ -233,15 +193,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
for (auto stride : stridesAttribute.getValue()) for (auto stride : stridesAttribute.getValue())
strides.emplace_back(stride.cast<IntegerAttr>().getInt()); strides.emplace_back(stride.cast<IntegerAttr>().getInt());
for (int i = 0; i < kernelShape.size() - 2; ++i) { for (int i = 0; i < kernelShape.size() - 2; ++i) {
Value spatialIndex = spatialIterationBlock.getArguments()[i]; Value spatialIndex = spatialLoops.getInductionVar(i);
// If strides are present then emit the correct access index. // If strides are present then emit the correct access index.
if (stridesAttribute && strides[i] > 1) if (stridesAttribute && strides[i] > 1)
spatialIndex = rewriter.create<MulIOp>(loc, spatialIndex = rewriter.create<MulIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, strides[i]), rewriter.create<ConstantIndexOp>(loc, strides[i]),
spatialIterationBlock.getArguments()[i]); spatialLoops.getInductionVar(i));
dataIndices.emplace_back( dataIndices.emplace_back(rewriter.create<AddIOp>(
rewriter.create<AddIOp>(loc, spatialIndex, loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
innerIterationBlock.getArguments()[i+1]));
} }
// 4.2 Prepare indices for accessing the kernel tensor. // 4.2 Prepare indices for accessing the kernel tensor.
@ -249,17 +208,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// kernel // kernel
kernelIndices.emplace_back(kernel); kernelIndices.emplace_back(kernel);
// c // c
kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]); kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
// kX // kX
for (int i = 0; i < kernelShape.size() - 2; ++i) for (int i = 0; i < kernelShape.size() - 2; ++i)
kernelIndices.emplace_back( kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));
innerIterationBlock.getArguments()[i+1]);
// 4.3 Compute convolution. // 4.3 Compute convolution.
auto loadData = auto loadData =
rewriter.create<LoadOp>(loc, operands[0], dataIndices); rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
auto loadKernel = auto loadKernel =
rewriter.create<LoadOp>(loc, operands[1], kernelIndices); rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
auto loadPartialSum = auto loadPartialSum =
rewriter.create<LoadOp>(loc, alloc, resultIndices); rewriter.create<LoadOp>(loc, alloc, resultIndices);
Value result = rewriter.create<AddFOp>(loc, loadPartialSum, Value result = rewriter.create<AddFOp>(loc, loadPartialSum,

View File

@ -1,4 +1,5 @@
#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExpr.h"
#include "src/dialect/krnl/krnl_ops.hpp" #include "src/dialect/krnl/krnl_ops.hpp"
@ -9,9 +10,8 @@ namespace onnf {
using namespace mlir; using namespace mlir;
ParseResult ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType, const Type &operandType, Value &operand) {
Value &operand) {
// If operand queue is empty, parse more operands and cache them. // If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) { if (_operandRefQueue.empty()) {
// Parse operand types: // Parse operand types:
@ -19,7 +19,7 @@ KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
_parser.parseOperandList(operand_refs); _parser.parseOperandList(operand_refs);
// Record operands: // Record operands:
for (auto& operand_ref : operand_refs) for (auto &operand_ref : operand_refs)
_operandRefQueue.emplace(operand_ref); _operandRefQueue.emplace(operand_ref);
} }
@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
return success(); return success();
} }
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType, ParseResult KrnlDialectOperandParser::ParseOperand(
Value &operand) { const Type &operandType, Value &operand) {
if (ParseOptionalOperand(operandType, operand)) if (ParseOptionalOperand(operandType, operand))
return _parser.emitError( return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand."); _parser.getCurrentLocation(), "Expecting an operand.");
@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
return success(); return success();
} }
void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims, void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
unsigned numSymbols, OpAsmPrinter& p) { unsigned numSymbols, OpAsmPrinter &p) {
p << '('; p << '(';
p.printOperands(begin, begin + numDims); p.printOperands(begin, begin + numDims);
p << ')'; p << ')';
@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
} }
void printBound(AffineMapAttr boundMap, void printBound(AffineMapAttr boundMap,
Operation::operand_iterator& boundOperandsBeg, const char* prefix, Operation::operand_iterator &boundOperandsBeg, const char *prefix,
OpAsmPrinter& p) { OpAsmPrinter &p) {
AffineMap map = boundMap.getValue(); AffineMap map = boundMap.getValue();
// Check if this bound should be printed using custom assembly form. // Check if this bound should be printed using custom assembly form.
@ -123,6 +123,7 @@ void printBound(AffineMapAttr boundMap,
} // namespace onnf } // namespace onnf
namespace mlir { namespace mlir {
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
if (boundMaps.size() % 2 == 0) if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]); _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
@ -137,4 +138,125 @@ void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
_operands.emplace_back(operand); _operands.emplace_back(operand);
} }
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, int loopNum)
: rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
createdIterateOp(false) {
if (originalLoopNum <= 0)
emitError(loc, "expected positive number of original loops");
}
BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
: BuildKrnlLoop(rewriter, loc,
memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
BuildKrnlLoop::~BuildKrnlLoop() {
if (!createdDefineOp)
emitError(loc, "expected to create define op");
if (!createdIterateOp)
emitError(loc, "expected to create iteration op");
if (pack)
free(pack);
}
void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
// insert define loop op
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
originalLoops.reserve(originalLoopNum);
for (auto result : loopsOp.getResults())
originalLoops.push_back(result);
// inserte optimize loop op.
auto optimizedLoopsOp =
rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
optLoops.reserve(originalLoopNum);
// Emit empty optimizations
if (withEmptyOptimization) {
for (auto result : optimizedLoopsOp.getResults())
optLoops.push_back(result);
optBlock = &optimizedLoopsOp.region().front();
auto ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToEnd(optBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.restoreInsertionPoint(ip);
}
// prepare data structure to push bounds
pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
createdOptimizeOp = true;
}
// push bounds (lower and upper) and return index for loop info
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushConstantBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
pack->pushConstantBound(lowerBound);
// process upperBound as a dimension of mem ref, possibly non-constant
auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
if (shape[upperBoundMemRefIndex] < 0) {
if (upperBoundMustBeConstant)
emitError(loc, "bound expected to be constant");
pack->pushOperandBound(
rewriter
.create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
.getResult());
} else
pack->pushConstantBound(shape[upperBoundMemRefIndex]);
return pushCount++;
}
int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
pack->pushOperandBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}
// create iter
void BuildKrnlLoop::createIterateOp() {
if (!createdDefineOp)
emitError(loc, "must create define op before iterate op");
// Tight now, optimize (possibly empty) is mandatory. This may change
if (!createdOptimizeOp)
emitError(loc, "must create optimize op before iterate op");
// have to have defined all bounds
if (pushCount != originalLoopNum) {
printf(" push count %d, original loop %d\n", pushCount, originalLoopNum);
emitError(loc, "must push bounds for all original loops");
}
// create iterate op
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
iterBlock = &iterateOp.bodyRegion().front();
createdIterateOp = true;
}
void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization) {
int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
if (originalLoopNum != loopNum)
emitError(loc, "mismatch in loop numbers from constructor and define");
createDefineAndOptimizeOp(withEmptyOptimization);
for (int i = 0; i < originalLoopNum; ++i)
pushBounds(0, memRefOperand, i);
createIterateOp();
}
// get induction variable to be use within iter
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
emitError(loc, "original loop index is out of bound");
return iterBlock->getArguments()[originalLoopIndex];
}
} // namespace mlir } // namespace mlir

View File

@ -8,39 +8,38 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnf { namespace onnf {
class KrnlDialectOperandParser { class KrnlDialectOperandParser {
public: public:
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser) explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser)
: _parser(parser), _builder(parser.getBuilder()){}; : _parser(parser), _builder(parser.getBuilder()){};
// Parse an optional operand. // Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType, mlir::ParseResult ParseOptionalOperand(
mlir::Value &operand); const mlir::Type &operandType, mlir::Value &operand);
// Parse an optional operand and push it to an operand list. // Parse an optional operand and push it to an operand list.
mlir::ParseResult mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList); llvm::SmallVectorImpl<mlir::Value> &operandList);
// Parse a required operand. // Parse a required operand.
mlir::ParseResult ParseOperand(const mlir::Type &operandType, mlir::ParseResult ParseOperand(
mlir::Value &operand); const mlir::Type &operandType, mlir::Value &operand);
// Parse a required operand and push it to an operand list. // Parse a required operand and push it to an operand list.
mlir::ParseResult mlir::ParseResult ParseOperand(const mlir::Type &operandType,
ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList); llvm::SmallVectorImpl<mlir::Value> &operandList);
// Do we have more operands to parse? // Do we have more operands to parse?
bool hasOperandLeft() { return !_operandRefQueue.empty(); } bool hasOperandLeft() { return !_operandRefQueue.empty(); }
private: private:
mlir::OpAsmParser& _parser; mlir::OpAsmParser &_parser;
mlir::Builder& _builder; mlir::Builder &_builder;
// A queue storing the parsed SSA id references. // A queue storing the parsed SSA id references.
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue; std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
@ -50,16 +49,16 @@ class KrnlDialectOperandParser {
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197 // https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
// Main difference is that it advances the iterator `begin` as it consumes // Main difference is that it advances the iterator `begin` as it consumes
// dimension and symbol operands. // dimension and symbol operands.
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin, void printDimAndSymbolList(mlir::Operation::operand_iterator &begin,
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p); unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p);
// Adapted from: // Adapted from:
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272 // https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
// Main difference is that it advances the iterator `boundOperandsBeg` as it // Main difference is that it advances the iterator `boundOperandsBeg` as it
// prints bound. // prints bound.
void printBound(mlir::AffineMapAttr boundMap, void printBound(mlir::AffineMapAttr boundMap,
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix, mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix,
mlir::OpAsmPrinter& p); mlir::OpAsmPrinter &p);
} // namespace onnf } // namespace onnf
namespace mlir { namespace mlir {
@ -88,7 +87,7 @@ struct KrnlIterateOperandPack {
size_t getNumInputLoops() const { return inputLoops.size(); } size_t getNumInputLoops() const { return inputLoops.size(); }
private: private:
int _boundIdx = 0; int _boundIdx = 0;
llvm::SmallVector<mlir::Value, 8> _operands; llvm::SmallVector<mlir::Value, 8> _operands;
@ -97,7 +96,99 @@ struct KrnlIterateOperandPack {
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops; llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
mlir::Builder& builder; mlir::Builder &builder;
};
// Helper function to write kernel loops. This class will let us build a single
// define/optimize/iterate operation combo. We can then insert optimizations in
// the body of the optimization operation, and operations in the body of the
// iterate operation.
//
// The sequence is as follow:
//
// 1) Create a object giving the rewriter, location, and number of loop in the
// original (non optimized) loop.
//
// 2) Create define & optimize ops (currently paired). Optimizations can then
// be added to the inner block of the optimize operation. Make sure to set the
// insertion point to that block for optimizations to go in the right place.
//
// 3) Push the bounds for each of the original loops. Bounds are pushed in
// pairs (lower & upper bounds). THere are a few methods to do it depending on
// the type of the bounds. When pushing bounds, the method returns a number
// that represent the index associated with that iteration (induction variable
// and bounds). That index can be used later to extract the induction variable
// for reference in computation and/or index calculations of mem refs.
//
// 4) Once all the bounds are pushed, create the iterate operation. Once this
// is done, we can add operations within the iterate blocks by setting the
// insertion point to it. Value of the induction variables can be retrieved
// using the proper index (determined when pushin the bounds).
class BuildKrnlLoop {
public:
// Create a build kernel loop for the given location and loop number.
BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
// Do the same, but where the loop number corresponds to the dimensionality of
// the mem ref operand.
BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
~BuildKrnlLoop();
// Create define and optimize loop with loopNum original loops. If
// withEmptyOptimization, the optimization is simply the identity function (no
// optimizations).
void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
// Push bounds (lower and upper) for each of the loops, in order. It returns
// the index associated with the loop iteration. This index is in the range
// from zero to original loop number -1, and is monotonally increasing from
// call to call. This index is later used in the getInductionVar call.
int pushBounds(int64_t lowerBound, int64_t upperBound);
int pushBounds(int64_t lowerBound, Value upperBound);
int pushBounds(Value lowerBound, Value upperBound);
// same, where the lower bound is an integer, and the uppoer bound is given by
// the size of the mem ref operand along the upperBoundMemRefIndex dimension.
int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
// Create an iterate op.
void createIterateOp();
// Create an define, optimize and iterate op, with the same loop nummber as
// the rank of the memRefOperand. The lower bound of each loops is zero, and
// the upper bound of each loops is the dimension given by the mem refs
void createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization = true);
// Get the (original loop) induction variable associated with the given index.
// Use the index returned when pushing the bounds.
BlockArgument &getInductionVar(int originalLoopIndex);
// Get blocks. This allow us to set the insertion point to the inner block of
// the optimize and the iterate Operation
Block *getOptimizationBlock() { return optBlock; }
Block *getIterateBlock() { return iterBlock; }
// get original or optimized loops
std::vector<Value> &getOriginalLoops() { return originalLoops; }
std::vector<Value> &getOptimizedLoops() { return optLoops; }
private:
// inputs
ConversionPatternRewriter &rewriter;
Location loc;
int originalLoopNum;
// track loops and bounds
std::vector<Value> originalLoops;
std::vector<Value> optLoops;
KrnlIterateOperandPack *pack;
int pushCount;
bool createdDefineOp;
bool createdOptimizeOp;
bool createdIterateOp;
// insertion points (opt block, iterate)
Block *optBlock;
Block *iterBlock;
}; };
} // namespace mlir } // namespace mlir