Introduce helper class to generate KRNL code and apply it to Convolution (#93)
* helper to gen krnl code, applied to conv
* suggested changes, name, removed set insertion point
* format
* suggested changes
* added comments and made a small name change
parent 9c398c0121
commit fcb5f35993
@@ -1,2 +1,3 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes
AlignAfterOpenBracket: DontAlign

@@ -12,9 +12,8 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
ONNXConvNoBiasOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}

PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin());
@@ -25,12 +24,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operands[0]});
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operands[0]});

auto resultShape = memRefType.getShape();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
auto &inputOperand = operands[0];
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto &kernelOperand = operands[1];
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();

// R = ConvNoBias(D, K)
//
@@ -91,123 +92,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
loc, FloatAttr::get(memRefType.getElementType(), 0));
Value subchannels;
if (kernelShape[1] < 0) {
subchannels =
rewriter.create<DimOp>(loc, operands[1], 1).getResult();
subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
} else {
subchannels = rewriter.create<ConstantIndexOp>(
loc, kernelShape[1]);
subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
}

// 1. Define outer loops and emit empty optimization block:
int64_t nOuterLoops = (group > 1) ? 3 : 2;
std::vector<Value> outerLoops;
std::vector<Value> optimizedOuterLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
optimizedOuterLoops, nOuterLoops);

// Prepare iteration arguments over outer loop nest.
KrnlIterateOperandPack pack(
rewriter, outerLoops, optimizedOuterLoops);
BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
outerLoops.createDefineAndOptimizeOp();
// for n = 0 .. N:
pack.pushConstantBound(0);
if (inputShape[0] < 0)
pack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], 0).getResult());
else
pack.pushConstantBound(inputShape[0]);
int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
// for g = 0 .. N:
if (group > 1) {
pack.pushConstantBound(0);
pack.pushConstantBound(group);
}
int gIndex = -1;
if (group > 1)
gIndex = outerLoops.pushBounds(0, group);
// for m = 0 .. kernelsPerGroup:
pack.pushConstantBound(0);
pack.pushConstantBound(kernelsPerGroup);
// Outer loop iteration.
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &outerIterationBlock = iterateOp.bodyRegion().front();
// Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
rewriter.setInsertionPointToStart(&outerIterationBlock);
int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
// Outer loop iteration
outerLoops.createIterateOp();
rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
{
// 2. Emit the body of the outer loop nest.

// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
// If group is not set then the value of the kernel ID is
// identical to that of the loop over kernels.
Value kernel = outerIterationBlock.getArguments()[1];
Value kernel = outerLoops.getInductionVar(mIndex);
if (group > 1) {
// Middle loop is over groups and third loop is over the
// kernel identifiers in the current group.
auto kernelsOffset = rewriter.create<MulIOp>(loc,
outerIterationBlock.getArguments()[1],
kernelsPerGroupValue);
kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
outerIterationBlock.getArguments()[2]);
auto kernelsOffset = rewriter.create<MulIOp>(
loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
kernel = rewriter.create<AddIOp>(
loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
}

// 2.2 Define spatial loops
int64_t nSpatialLoops = resultShape.size() - 2;
std::vector<Value> spatialLoops;
std::vector<Value> optimizedSpatialLoops;
Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
optimizedSpatialLoops, nSpatialLoops);

// 2.3 Prepare iteration arguments for spatial loop nest.
KrnlIterateOperandPack spatialPack(
rewriter, spatialLoops, optimizedSpatialLoops);
BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
spatialLoops.createDefineAndOptimizeOp();
for (int i = 2; i < resultShape.size(); ++i)
addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
spatialLoops.pushBounds(0, alloc, i);

// 2.4 Emit loop nest over output spatial dimensions.
// for rX = 0 .. RX
auto spatialIterateOp =
rewriter.create<KrnlIterateOp>(loc, spatialPack);
Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
// 2.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
rewriter.setInsertionPointToStart(&spatialIterationBlock);
spatialLoops.createIterateOp();
rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());

{
// 3. Emit the body of the spatial loop nest.
// 3.1 Emit: R[n][kernel][r1][r2] = 0;
SmallVector<Value, 4> resultIndices;
// n
resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// kernel
resultIndices.emplace_back(kernel);
// rX
for (auto arg : spatialIterationBlock.getArguments())
for (auto arg : spatialLoops.getIterateBlock()->getArguments())
resultIndices.emplace_back(arg);
// Store initializer value into output location.
rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);

// 3.2 Define inner loops.
int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
std::vector<Value> innerLoops;
std::vector<Value> optimizedInnerLoops;
Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
optimizedInnerLoops, nInnerLoops);

// 3.3 Prepare iteration arguments for inner loop nest.
KrnlIterateOperandPack innerPack(
rewriter, innerLoops, optimizedInnerLoops);
BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
innerLoops.createDefineAndOptimizeOp();
// for c = 0 .. C/group
innerPack.pushConstantBound(0);
innerPack.pushConstantBound(kernelShape[1]);
int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
// for Kx = 0 .. KX
for (int i = 2; i < kernelShape.size(); ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
innerLoops.pushBounds(0, kernelOperand, i);

// 3.4 Emit inner loop nest.
auto innerIterateOp =
rewriter.create<KrnlIterateOp>(loc, innerPack);
Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
// 3.5 Emit optimizations for outer loops:
rewriter.setInsertionPointToEnd(optInnerLoopBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
rewriter.setInsertionPointToStart(&innerIterationBlock);
innerLoops.createIterateOp();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());

{
// 4. Emit inner loop body
// R[n][kernel][r1][r2] =
@@ -217,13 +177,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// 4.1 Prepare indices for accessing the data tensor.
SmallVector<Value, 4> dataIndices;
// n
dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
// g * (C / group) + c
Value channelDepth = innerIterationBlock.getArguments()[0];
Value channelDepth = innerLoops.getInductionVar(cIndex);
if (group > 1)
channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
rewriter.create<MulIOp>(loc, subchannels,
outerIterationBlock.getArguments()[1]));
rewriter.create<MulIOp>(
loc, subchannels, outerLoops.getInductionVar(gIndex)));
dataIndices.emplace_back(channelDepth);
// sX * rX + kX
auto stridesAttribute = convOp.stridesAttr();
@@ -233,15 +193,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
for (auto stride : stridesAttribute.getValue())
strides.emplace_back(stride.cast<IntegerAttr>().getInt());
for (int i = 0; i < kernelShape.size() - 2; ++i) {
Value spatialIndex = spatialIterationBlock.getArguments()[i];
Value spatialIndex = spatialLoops.getInductionVar(i);
// If strides are present then emit the correct access index.
if (stridesAttribute && strides[i] > 1)
spatialIndex = rewriter.create<MulIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, strides[i]),
spatialIterationBlock.getArguments()[i]);
dataIndices.emplace_back(
rewriter.create<AddIOp>(loc, spatialIndex,
innerIterationBlock.getArguments()[i+1]));
spatialLoops.getInductionVar(i));
dataIndices.emplace_back(rewriter.create<AddIOp>(
loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
}

// 4.2 Prepare indices for accessing the kernel tensor.
@@ -249,17 +208,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// kernel
kernelIndices.emplace_back(kernel);
// c
kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
// kX
for (int i = 0; i < kernelShape.size() - 2; ++i)
kernelIndices.emplace_back(
innerIterationBlock.getArguments()[i+1]);
kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));

// 4.3 Compute convolution.
auto loadData =
rewriter.create<LoadOp>(loc, operands[0], dataIndices);
rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
auto loadKernel =
rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
auto loadPartialSum =
rewriter.create<LoadOp>(loc, alloc, resultIndices);
Value result = rewriter.create<AddFOp>(loc, loadPartialSum,

@@ -1,4 +1,5 @@
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"

#include "src/dialect/krnl/krnl_ops.hpp"
@@ -9,9 +10,8 @@ namespace onnf {

using namespace mlir;

ParseResult
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
Value &operand) {
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
const Type &operandType, Value &operand) {
// If operand queue is empty, parse more operands and cache them.
if (_operandRefQueue.empty()) {
// Parse operand types:
@@ -19,7 +19,7 @@ KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
_parser.parseOperandList(operand_refs);

// Record operands:
for (auto& operand_ref : operand_refs)
for (auto &operand_ref : operand_refs)
_operandRefQueue.emplace(operand_ref);
}
@@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
return success();
}

ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
Value &operand) {
ParseResult KrnlDialectOperandParser::ParseOperand(
const Type &operandType, Value &operand) {
if (ParseOptionalOperand(operandType, operand))
return _parser.emitError(
_parser.getCurrentLocation(), "Expecting an operand.");
@@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
return success();
}

void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
unsigned numSymbols, OpAsmPrinter& p) {
void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
unsigned numSymbols, OpAsmPrinter &p) {
p << '(';
p.printOperands(begin, begin + numDims);
p << ')';
@@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
}

void printBound(AffineMapAttr boundMap,
Operation::operand_iterator& boundOperandsBeg, const char* prefix,
OpAsmPrinter& p) {
Operation::operand_iterator &boundOperandsBeg, const char *prefix,
OpAsmPrinter &p) {
AffineMap map = boundMap.getValue();

// Check if this bound should be printed using custom assembly form.
@@ -120,9 +120,10 @@ void printBound(AffineMapAttr boundMap,
printDimAndSymbolList(
boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p);
}
} // namespace onnf
} // namespace onnf

namespace mlir {

void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
@@ -137,4 +138,125 @@ void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
boundMaps.emplace_back(AffineMapAttr::get(map));
_operands.emplace_back(operand);
}
} // namespace mlir

BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, int loopNum)
: rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
createdIterateOp(false) {
if (originalLoopNum <= 0)
emitError(loc, "expected positive number of original loops");
}

BuildKrnlLoop::BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
: BuildKrnlLoop(rewriter, loc,
memRefOperand.getType().cast<MemRefType>().getShape().size()) {}

BuildKrnlLoop::~BuildKrnlLoop() {
if (!createdDefineOp)
emitError(loc, "expected to create define op");
if (!createdIterateOp)
emitError(loc, "expected to create iteration op");
if (pack)
free(pack);
}

void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
// Insert define loop op.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
originalLoops.reserve(originalLoopNum);
for (auto result : loopsOp.getResults())
originalLoops.push_back(result);
createdDefineOp = true;
// Insert optimize loop op.
auto optimizedLoopsOp =
rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
optLoops.reserve(originalLoopNum);
// Emit empty optimizations.
if (withEmptyOptimization) {
for (auto result : optimizedLoopsOp.getResults())
optLoops.push_back(result);
optBlock = &optimizedLoopsOp.region().front();
auto ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToEnd(optBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.restoreInsertionPoint(ip);
}
// Prepare data structure to push bounds.
pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
createdOptimizeOp = true;
}

// Push bounds (lower and upper) and return the index for the loop info.
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushConstantBound(upperBound);
return pushCount++;
}

int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
pack->pushConstantBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}

int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
pack->pushConstantBound(lowerBound);
// Process the upper bound as a dimension of the mem ref, possibly non-constant.
auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
if (shape[upperBoundMemRefIndex] < 0) {
if (upperBoundMustBeConstant)
emitError(loc, "bound expected to be constant");
pack->pushOperandBound(
rewriter
.create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
.getResult());
} else
pack->pushConstantBound(shape[upperBoundMemRefIndex]);
return pushCount++;
}

int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
pack->pushOperandBound(lowerBound);
pack->pushOperandBound(upperBound);
return pushCount++;
}

// Create the iterate op.
void BuildKrnlLoop::createIterateOp() {
if (!createdDefineOp)
emitError(loc, "must create define op before iterate op");
// Right now, optimize (possibly empty) is mandatory. This may change.
if (!createdOptimizeOp)
emitError(loc, "must create optimize op before iterate op");
// Have to have defined all bounds.
if (pushCount != originalLoopNum) {
printf(" push count %d, original loop %d\n", pushCount, originalLoopNum);
emitError(loc, "must push bounds for all original loops");
}
// Create iterate op.
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
iterBlock = &iterateOp.bodyRegion().front();
createdIterateOp = true;
}

void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization) {
int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
if (originalLoopNum != loopNum)
emitError(loc, "mismatch in loop numbers from constructor and define");
createDefineAndOptimizeOp(withEmptyOptimization);
for (int i = 0; i < originalLoopNum; ++i)
pushBounds(0, memRefOperand, i);
createIterateOp();
}

// Get the induction variable to be used within the iterate op.
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
emitError(loc, "original loop index is out of bound");
return iterBlock->getArguments()[originalLoopIndex];
}

} // namespace mlir

@@ -8,39 +8,38 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"

namespace onnf {

class KrnlDialectOperandParser {
public:
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
public:
explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser)
: _parser(parser), _builder(parser.getBuilder()){};

// Parse an optional operand.
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
mlir::Value &operand);
mlir::ParseResult ParseOptionalOperand(
const mlir::Type &operandType, mlir::Value &operand);

// Parse an optional operand and push it to an operand list.
mlir::ParseResult
ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);

// Parse a required operand.
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
mlir::Value &operand);
mlir::ParseResult ParseOperand(
const mlir::Type &operandType, mlir::Value &operand);

// Parse a required operand and push it to an operand list.
mlir::ParseResult
ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
llvm::SmallVectorImpl<mlir::Value> &operandList);

// Do we have more operands to parse?
bool hasOperandLeft() { return !_operandRefQueue.empty(); }

private:
mlir::OpAsmParser& _parser;
private:
mlir::OpAsmParser &_parser;

mlir::Builder& _builder;
mlir::Builder &_builder;

// A queue storing the parsed SSA id references.
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
@@ -50,24 +49,24 @@ class KrnlDialectOperandParser {
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
// Main difference is that it advances the iterator `begin` as it consumes
// dimension and symbol operands.
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
void printDimAndSymbolList(mlir::Operation::operand_iterator &begin,
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p);

// Adapted from:
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
// Main difference is that it advances the iterator `boundOperandsBeg` as it
// prints bound.
void printBound(mlir::AffineMapAttr boundMap,
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
mlir::OpAsmPrinter& p);
} // namespace onnf
mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix,
mlir::OpAsmPrinter &p);
} // namespace onnf

namespace mlir {

struct KrnlIterateOperandPack {
KrnlIterateOperandPack(mlir::Builder &builder,
llvm::ArrayRef<mlir::Value> inputLoops,
llvm::ArrayRef<mlir::Value> optimizedLoops)
llvm::ArrayRef<mlir::Value> inputLoops,
llvm::ArrayRef<mlir::Value> optimizedLoops)
: builder(builder), inputLoops(inputLoops),
optimizedLoops(optimizedLoops) {
_operands.insert(
@@ -88,7 +87,7 @@ struct KrnlIterateOperandPack {

size_t getNumInputLoops() const { return inputLoops.size(); }

private:
private:
int _boundIdx = 0;

llvm::SmallVector<mlir::Value, 8> _operands;
@@ -97,7 +96,99 @@ struct KrnlIterateOperandPack {

llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;

mlir::Builder& builder;
mlir::Builder &builder;
};

} // namespace mlir
// Helper class to write kernel loops. This class will let us build a single
// define/optimize/iterate operation combo. We can then insert optimizations in
// the body of the optimization operation, and operations in the body of the
// iterate operation.
//
// The sequence is as follows:
//
// 1) Create an object giving the rewriter, location, and number of loops in
// the original (non-optimized) loop nest.
//
// 2) Create define & optimize ops (currently paired). Optimizations can then
// be added to the inner block of the optimize operation. Make sure to set the
// insertion point to that block for optimizations to go in the right place.
//
// 3) Push the bounds for each of the original loops. Bounds are pushed in
// pairs (lower & upper bounds). There are a few methods to do it depending on
// the type of the bounds. When pushing bounds, the method returns a number
// that represents the index associated with that iteration (induction variable
// and bounds). That index can be used later to extract the induction variable
// for reference in computation and/or index calculations of mem refs.
//
// 4) Once all the bounds are pushed, create the iterate operation. Once this
// is done, we can add operations within the iterate block by setting the
// insertion point to it. Values of the induction variables can be retrieved
// using the proper index (determined when pushing the bounds).
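//
// A minimal usage sketch (hypothetical example; it assumes `rewriter`, `loc`,
// and a 2-D memref-typed Value `memRef` are already in scope, mirroring the
// Convolution lowering above):
//
//   BuildKrnlLoop loops(rewriter, loc, 2);
//   loops.createDefineAndOptimizeOp();
//   int i0 = loops.pushBounds(0, memRef, 0);  // 0 .. dim 0 of memRef
//   int i1 = loops.pushBounds(0, memRef, 1);  // 0 .. dim 1 of memRef
//   loops.createIterateOp();
//   rewriter.setInsertionPointToStart(loops.getIterateBlock());
//   Value iv0 = loops.getInductionVar(i0);    // induction variable of loop i0
//   Value iv1 = loops.getInductionVar(i1);    // induction variable of loop i1
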
class BuildKrnlLoop {
public:
// Create a build kernel loop for the given location and loop number.
BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
// Do the same, but where the loop number corresponds to the dimensionality of
// the mem ref operand.
BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
~BuildKrnlLoop();

// Create define and optimize ops with loopNum original loops. If
// withEmptyOptimization, the optimization is simply the identity function (no
// optimizations).
void createDefineAndOptimizeOp(bool withEmptyOptimization = true);

// Push bounds (lower and upper) for each of the loops, in order. It returns
// the index associated with the loop iteration. This index is in the range
// from zero to original loop number -1, and is monotonically increasing from
// call to call. This index is later used in the getInductionVar call.
int pushBounds(int64_t lowerBound, int64_t upperBound);
int pushBounds(int64_t lowerBound, Value upperBound);
int pushBounds(Value lowerBound, Value upperBound);
// Same, where the lower bound is an integer and the upper bound is given by
// the size of the mem ref operand along the upperBoundMemRefIndex dimension.
int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);

// Create an iterate op.
void createIterateOp();
// Create a define, optimize, and iterate op, with the same loop number as
// the rank of the memRefOperand. The lower bound of each loop is zero, and
// the upper bound of each loop is the corresponding dimension of the mem ref.
void createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization = true);
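// For instance (hypothetical sketch; assumes a memref-typed Value `memRef` is
// in scope), the define/optimize, bound-pushing, and iterate steps above
// collapse into:
//   BuildKrnlLoop loop(rewriter, loc, memRef);
//   loop.createDefineOptimizeAndIterateOp(memRef);
//   rewriter.setInsertionPointToStart(loop.getIterateBlock());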

// Get the (original loop) induction variable associated with the given index.
// Use the index returned when pushing the bounds.
BlockArgument &getInductionVar(int originalLoopIndex);

// Get blocks. This allows us to set the insertion point to the inner block of
// the optimize and the iterate operations.
Block *getOptimizationBlock() { return optBlock; }
Block *getIterateBlock() { return iterBlock; }

// get original or optimized loops
std::vector<Value> &getOriginalLoops() { return originalLoops; }
std::vector<Value> &getOptimizedLoops() { return optLoops; }

private:
// inputs
ConversionPatternRewriter &rewriter;
Location loc;
int originalLoopNum;
// track loops and bounds
std::vector<Value> originalLoops;
std::vector<Value> optLoops;
KrnlIterateOperandPack *pack;
int pushCount;
bool createdDefineOp;
bool createdOptimizeOp;
bool createdIterateOp;
// insertion points (opt block, iterate)
Block *optBlock;
Block *iterBlock;
};

} // namespace mlir