Merge branch 'master' into shapeinference-pad
This commit is contained in:
commit
1ad7989fc5
|
@ -1,2 +1,3 @@
|
||||||
BasedOnStyle: LLVM
|
BasedOnStyle: LLVM
|
||||||
AlwaysBreakTemplateDeclarations: Yes
|
AlwaysBreakTemplateDeclarations: Yes
|
||||||
|
AlignAfterOpenBracket: DontAlign
|
||||||
|
|
|
@ -12,8 +12,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
||||||
ONNXConvNoBiasOpLowering(MLIRContext *ctx)
|
ONNXConvNoBiasOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
|
@ -25,12 +24,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
alloc = insertAllocAndDealloc(
|
||||||
{operands[0]});
|
memRefType, loc, rewriter, insertDealloc, {operands[0]});
|
||||||
|
|
||||||
auto resultShape = memRefType.getShape();
|
auto resultShape = memRefType.getShape();
|
||||||
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
|
auto &inputOperand = operands[0];
|
||||||
auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
|
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
|
||||||
|
auto &kernelOperand = operands[1];
|
||||||
|
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
|
||||||
|
|
||||||
// R = ConvNoBias(D, K)
|
// R = ConvNoBias(D, K)
|
||||||
//
|
//
|
||||||
|
@ -91,123 +92,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
||||||
loc, FloatAttr::get(memRefType.getElementType(), 0));
|
loc, FloatAttr::get(memRefType.getElementType(), 0));
|
||||||
Value subchannels;
|
Value subchannels;
|
||||||
if (kernelShape[1] < 0) {
|
if (kernelShape[1] < 0) {
|
||||||
subchannels =
|
subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
|
||||||
rewriter.create<DimOp>(loc, operands[1], 1).getResult();
|
|
||||||
} else {
|
} else {
|
||||||
subchannels = rewriter.create<ConstantIndexOp>(
|
subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
|
||||||
loc, kernelShape[1]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. Define outer loops and emit empty optimization block:
|
// 1. Define outer loops and emit empty optimization block:
|
||||||
int64_t nOuterLoops = (group > 1) ? 3 : 2;
|
int64_t nOuterLoops = (group > 1) ? 3 : 2;
|
||||||
std::vector<Value> outerLoops;
|
BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
|
||||||
std::vector<Value> optimizedOuterLoops;
|
outerLoops.createDefineAndOptimizeOp();
|
||||||
Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
|
|
||||||
optimizedOuterLoops, nOuterLoops);
|
|
||||||
|
|
||||||
// Prepare iteration arguments over outer loop nest.
|
|
||||||
KrnlIterateOperandPack pack(
|
|
||||||
rewriter, outerLoops, optimizedOuterLoops);
|
|
||||||
// for n = 0 .. N:
|
// for n = 0 .. N:
|
||||||
pack.pushConstantBound(0);
|
int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
|
||||||
if (inputShape[0] < 0)
|
|
||||||
pack.pushOperandBound(
|
|
||||||
rewriter.create<DimOp>(loc, operands[0], 0).getResult());
|
|
||||||
else
|
|
||||||
pack.pushConstantBound(inputShape[0]);
|
|
||||||
// for g = 0 .. N:
|
// for g = 0 .. N:
|
||||||
if (group > 1) {
|
int gIndex = -1;
|
||||||
pack.pushConstantBound(0);
|
if (group > 1)
|
||||||
pack.pushConstantBound(group);
|
gIndex = outerLoops.pushBounds(0, group);
|
||||||
}
|
|
||||||
// for m = 0 .. kernelsPerGroup:
|
// for m = 0 .. kernelsPerGroup:
|
||||||
pack.pushConstantBound(0);
|
int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
|
||||||
pack.pushConstantBound(kernelsPerGroup);
|
// Outer loop iteration
|
||||||
// Outer loop iteration.
|
outerLoops.createIterateOp();
|
||||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
|
||||||
Block &outerIterationBlock = iterateOp.bodyRegion().front();
|
|
||||||
// Emit optimizations for outer loops:
|
|
||||||
rewriter.setInsertionPointToEnd(optimizationBlock);
|
|
||||||
rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
|
|
||||||
rewriter.setInsertionPointToStart(&outerIterationBlock);
|
|
||||||
{
|
{
|
||||||
// 2. Emit the body of the outer loop nest.
|
// 2. Emit the body of the outer loop nest.
|
||||||
|
|
||||||
// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
|
// 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
|
||||||
// If group is not set then the value of the kernel ID is
|
// If group is not set then the value of the kernel ID is
|
||||||
// identical to that of the loop over kernels.
|
// identical to that of the loop over kernels.
|
||||||
Value kernel = outerIterationBlock.getArguments()[1];
|
Value kernel = outerLoops.getInductionVar(mIndex);
|
||||||
if (group > 1) {
|
if (group > 1) {
|
||||||
// Middle loop is over groups and third loop is over the
|
// Middle loop is over groups and third loop is over the
|
||||||
// kernel identifiers in the current group.
|
// kernel identifiers in the current group.
|
||||||
auto kernelsOffset = rewriter.create<MulIOp>(loc,
|
auto kernelsOffset = rewriter.create<MulIOp>(
|
||||||
outerIterationBlock.getArguments()[1],
|
loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
|
||||||
kernelsPerGroupValue);
|
kernel = rewriter.create<AddIOp>(
|
||||||
kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
|
loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
|
||||||
outerIterationBlock.getArguments()[2]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2.2 Define spatial loops
|
// 2.2 Define spatial loops
|
||||||
int64_t nSpatialLoops = resultShape.size() - 2;
|
int64_t nSpatialLoops = resultShape.size() - 2;
|
||||||
std::vector<Value> spatialLoops;
|
BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
|
||||||
std::vector<Value> optimizedSpatialLoops;
|
spatialLoops.createDefineAndOptimizeOp();
|
||||||
Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
|
|
||||||
optimizedSpatialLoops, nSpatialLoops);
|
|
||||||
|
|
||||||
// 2.3 Prepare iteration arguments for spatial loop nest.
|
|
||||||
KrnlIterateOperandPack spatialPack(
|
|
||||||
rewriter, spatialLoops, optimizedSpatialLoops);
|
|
||||||
for (int i = 2; i < resultShape.size(); ++i)
|
for (int i = 2; i < resultShape.size(); ++i)
|
||||||
addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
|
spatialLoops.pushBounds(0, alloc, i);
|
||||||
|
|
||||||
// 2.4 Emit loop nest over output spatial dimensions.
|
// 2.4 Emit loop nest over output spatial dimensions.
|
||||||
// for rX = 0 .. RX
|
// for rX = 0 .. RX
|
||||||
auto spatialIterateOp =
|
spatialLoops.createIterateOp();
|
||||||
rewriter.create<KrnlIterateOp>(loc, spatialPack);
|
rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());
|
||||||
Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
|
|
||||||
// 2.5 Emit optimizations for outer loops:
|
|
||||||
rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
|
|
||||||
rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
|
|
||||||
rewriter.setInsertionPointToStart(&spatialIterationBlock);
|
|
||||||
{
|
{
|
||||||
// 3. Emit the body of the spatial loop nest.
|
// 3. Emit the body of the spatial loop nest.
|
||||||
// 3.1 Emit: R[n][kernel][r1][r2] = 0;
|
// 3.1 Emit: R[n][kernel][r1][r2] = 0;
|
||||||
SmallVector<Value, 4> resultIndices;
|
SmallVector<Value, 4> resultIndices;
|
||||||
// n
|
// n
|
||||||
resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
|
resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
|
||||||
// kernel
|
// kernel
|
||||||
resultIndices.emplace_back(kernel);
|
resultIndices.emplace_back(kernel);
|
||||||
// rX
|
// rX
|
||||||
for (auto arg : spatialIterationBlock.getArguments())
|
for (auto arg : spatialLoops.getIterateBlock()->getArguments())
|
||||||
resultIndices.emplace_back(arg);
|
resultIndices.emplace_back(arg);
|
||||||
// Store initializer value into output location.
|
// Store initializer value into output location.
|
||||||
rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
|
rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
|
||||||
|
|
||||||
// 3.2 Define inner loops.
|
// 3.2 Define inner loops.
|
||||||
int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
|
int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
|
||||||
std::vector<Value> innerLoops;
|
BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
|
||||||
std::vector<Value> optimizedInnerLoops;
|
innerLoops.createDefineAndOptimizeOp();
|
||||||
Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
|
|
||||||
optimizedInnerLoops, nInnerLoops);
|
|
||||||
|
|
||||||
// 3.3 Prepare iteration arguments for inner loop nest.
|
|
||||||
KrnlIterateOperandPack innerPack(
|
|
||||||
rewriter, innerLoops, optimizedInnerLoops);
|
|
||||||
// for c = 0 .. C/group
|
// for c = 0 .. C/group
|
||||||
innerPack.pushConstantBound(0);
|
int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
|
||||||
innerPack.pushConstantBound(kernelShape[1]);
|
|
||||||
// for Kx = 0 .. KX
|
// for Kx = 0 .. KX
|
||||||
for (int i = 2; i < kernelShape.size(); ++i)
|
for (int i = 2; i < kernelShape.size(); ++i)
|
||||||
addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
|
innerLoops.pushBounds(0, kernelOperand, i);
|
||||||
|
|
||||||
// 3.4 Emit inner loop nest.
|
// 3.4 Emit inner loop nest.
|
||||||
auto innerIterateOp =
|
innerLoops.createIterateOp();
|
||||||
rewriter.create<KrnlIterateOp>(loc, innerPack);
|
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
|
||||||
Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
|
|
||||||
// 3.5 Emit optimizations for outer loops:
|
|
||||||
rewriter.setInsertionPointToEnd(optInnerLoopBlock);
|
|
||||||
rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
|
|
||||||
rewriter.setInsertionPointToStart(&innerIterationBlock);
|
|
||||||
{
|
{
|
||||||
// 4. Emit inner loop body
|
// 4. Emit inner loop body
|
||||||
// R[n][kernel][r1][r2] =
|
// R[n][kernel][r1][r2] =
|
||||||
|
@ -217,13 +177,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
||||||
// 4.1 Prepare indices for accesing the data tensor.
|
// 4.1 Prepare indices for accesing the data tensor.
|
||||||
SmallVector<Value, 4> dataIndices;
|
SmallVector<Value, 4> dataIndices;
|
||||||
// n
|
// n
|
||||||
dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
|
dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
|
||||||
// g * (C / group) + c
|
// g * (C / group) + c
|
||||||
Value channelDepth = innerIterationBlock.getArguments()[0];
|
Value channelDepth = innerLoops.getInductionVar(cIndex);
|
||||||
if (group > 1)
|
if (group > 1)
|
||||||
channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
|
channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
|
||||||
rewriter.create<MulIOp>(loc, subchannels,
|
rewriter.create<MulIOp>(
|
||||||
outerIterationBlock.getArguments()[1]));
|
loc, subchannels, outerLoops.getInductionVar(gIndex)));
|
||||||
dataIndices.emplace_back(channelDepth);
|
dataIndices.emplace_back(channelDepth);
|
||||||
// sX * rX + kX
|
// sX * rX + kX
|
||||||
auto stridesAttribute = convOp.stridesAttr();
|
auto stridesAttribute = convOp.stridesAttr();
|
||||||
|
@ -233,15 +193,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
||||||
for (auto stride : stridesAttribute.getValue())
|
for (auto stride : stridesAttribute.getValue())
|
||||||
strides.emplace_back(stride.cast<IntegerAttr>().getInt());
|
strides.emplace_back(stride.cast<IntegerAttr>().getInt());
|
||||||
for (int i = 0; i < kernelShape.size() - 2; ++i) {
|
for (int i = 0; i < kernelShape.size() - 2; ++i) {
|
||||||
Value spatialIndex = spatialIterationBlock.getArguments()[i];
|
Value spatialIndex = spatialLoops.getInductionVar(i);
|
||||||
// If strides are present then emit the correct access index.
|
// If strides are present then emit the correct access index.
|
||||||
if (stridesAttribute && strides[i] > 1)
|
if (stridesAttribute && strides[i] > 1)
|
||||||
spatialIndex = rewriter.create<MulIOp>(loc,
|
spatialIndex = rewriter.create<MulIOp>(loc,
|
||||||
rewriter.create<ConstantIndexOp>(loc, strides[i]),
|
rewriter.create<ConstantIndexOp>(loc, strides[i]),
|
||||||
spatialIterationBlock.getArguments()[i]);
|
spatialLoops.getInductionVar(i));
|
||||||
dataIndices.emplace_back(
|
dataIndices.emplace_back(rewriter.create<AddIOp>(
|
||||||
rewriter.create<AddIOp>(loc, spatialIndex,
|
loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
|
||||||
innerIterationBlock.getArguments()[i+1]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4.2 Prepare indices for accessing the kernel tensor.
|
// 4.2 Prepare indices for accessing the kernel tensor.
|
||||||
|
@ -249,17 +208,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
||||||
// kernel
|
// kernel
|
||||||
kernelIndices.emplace_back(kernel);
|
kernelIndices.emplace_back(kernel);
|
||||||
// c
|
// c
|
||||||
kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
|
kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
|
||||||
// kX
|
// kX
|
||||||
for (int i = 0; i < kernelShape.size() - 2; ++i)
|
for (int i = 0; i < kernelShape.size() - 2; ++i)
|
||||||
kernelIndices.emplace_back(
|
kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));
|
||||||
innerIterationBlock.getArguments()[i+1]);
|
|
||||||
|
|
||||||
// 4.3 Compute convolution.
|
// 4.3 Compute convolution.
|
||||||
auto loadData =
|
auto loadData =
|
||||||
rewriter.create<LoadOp>(loc, operands[0], dataIndices);
|
rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
|
||||||
auto loadKernel =
|
auto loadKernel =
|
||||||
rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
|
rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
|
||||||
auto loadPartialSum =
|
auto loadPartialSum =
|
||||||
rewriter.create<LoadOp>(loc, alloc, resultIndices);
|
rewriter.create<LoadOp>(loc, alloc, resultIndices);
|
||||||
Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
|
Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
|
||||||
#include "src/dialect/krnl/krnl_ops.hpp"
|
#include "src/dialect/krnl/krnl_ops.hpp"
|
||||||
|
@ -9,9 +10,8 @@ namespace onnf {
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
ParseResult
|
ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
|
const Type &operandType, Value &operand) {
|
||||||
Value &operand) {
|
|
||||||
// If operand queue is empty, parse more operands and cache them.
|
// If operand queue is empty, parse more operands and cache them.
|
||||||
if (_operandRefQueue.empty()) {
|
if (_operandRefQueue.empty()) {
|
||||||
// Parse operand types:
|
// Parse operand types:
|
||||||
|
@ -19,7 +19,7 @@ KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
|
||||||
_parser.parseOperandList(operand_refs);
|
_parser.parseOperandList(operand_refs);
|
||||||
|
|
||||||
// Record operands:
|
// Record operands:
|
||||||
for (auto& operand_ref : operand_refs)
|
for (auto &operand_ref : operand_refs)
|
||||||
_operandRefQueue.emplace(operand_ref);
|
_operandRefQueue.emplace(operand_ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
|
ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||||
Value &operand) {
|
const Type &operandType, Value &operand) {
|
||||||
if (ParseOptionalOperand(operandType, operand))
|
if (ParseOptionalOperand(operandType, operand))
|
||||||
return _parser.emitError(
|
return _parser.emitError(
|
||||||
_parser.getCurrentLocation(), "Expecting an operand.");
|
_parser.getCurrentLocation(), "Expecting an operand.");
|
||||||
|
@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
|
void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
|
||||||
unsigned numSymbols, OpAsmPrinter& p) {
|
unsigned numSymbols, OpAsmPrinter &p) {
|
||||||
p << '(';
|
p << '(';
|
||||||
p.printOperands(begin, begin + numDims);
|
p.printOperands(begin, begin + numDims);
|
||||||
p << ')';
|
p << ')';
|
||||||
|
@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
|
||||||
}
|
}
|
||||||
|
|
||||||
void printBound(AffineMapAttr boundMap,
|
void printBound(AffineMapAttr boundMap,
|
||||||
Operation::operand_iterator& boundOperandsBeg, const char* prefix,
|
Operation::operand_iterator &boundOperandsBeg, const char *prefix,
|
||||||
OpAsmPrinter& p) {
|
OpAsmPrinter &p) {
|
||||||
AffineMap map = boundMap.getValue();
|
AffineMap map = boundMap.getValue();
|
||||||
|
|
||||||
// Check if this bound should be printed using custom assembly form.
|
// Check if this bound should be printed using custom assembly form.
|
||||||
|
@ -123,6 +123,7 @@ void printBound(AffineMapAttr boundMap,
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
|
void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
|
||||||
if (boundMaps.size() % 2 == 0)
|
if (boundMaps.size() % 2 == 0)
|
||||||
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
_operands.emplace_back(inputLoops[boundMaps.size() / 2]);
|
||||||
|
@ -137,4 +138,125 @@ void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
|
||||||
boundMaps.emplace_back(AffineMapAttr::get(map));
|
boundMaps.emplace_back(AffineMapAttr::get(map));
|
||||||
_operands.emplace_back(operand);
|
_operands.emplace_back(operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BuildKrnlLoop::BuildKrnlLoop(
|
||||||
|
ConversionPatternRewriter &rewriter, Location loc, int loopNum)
|
||||||
|
: rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
|
||||||
|
pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
|
||||||
|
createdIterateOp(false) {
|
||||||
|
if (originalLoopNum <= 0)
|
||||||
|
emitError(loc, "expected positive number of original loops");
|
||||||
|
}
|
||||||
|
|
||||||
|
BuildKrnlLoop::BuildKrnlLoop(
|
||||||
|
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
|
||||||
|
: BuildKrnlLoop(rewriter, loc,
|
||||||
|
memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
|
||||||
|
|
||||||
|
BuildKrnlLoop::~BuildKrnlLoop() {
|
||||||
|
if (!createdDefineOp)
|
||||||
|
emitError(loc, "expected to create define op");
|
||||||
|
if (!createdIterateOp)
|
||||||
|
emitError(loc, "expected to create iteration op");
|
||||||
|
if (pack)
|
||||||
|
free(pack);
|
||||||
|
}
|
||||||
|
|
||||||
|
void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
|
||||||
|
// insert define loop op
|
||||||
|
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
|
||||||
|
originalLoops.reserve(originalLoopNum);
|
||||||
|
for (auto result : loopsOp.getResults())
|
||||||
|
originalLoops.push_back(result);
|
||||||
|
// inserte optimize loop op.
|
||||||
|
auto optimizedLoopsOp =
|
||||||
|
rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
|
||||||
|
optLoops.reserve(originalLoopNum);
|
||||||
|
// Emit empty optimizations
|
||||||
|
if (withEmptyOptimization) {
|
||||||
|
for (auto result : optimizedLoopsOp.getResults())
|
||||||
|
optLoops.push_back(result);
|
||||||
|
optBlock = &optimizedLoopsOp.region().front();
|
||||||
|
auto ip = rewriter.saveInsertionPoint();
|
||||||
|
rewriter.setInsertionPointToEnd(optBlock);
|
||||||
|
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||||
|
rewriter.restoreInsertionPoint(ip);
|
||||||
|
}
|
||||||
|
// prepare data structure to push bounds
|
||||||
|
pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
|
||||||
|
createdOptimizeOp = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// push bounds (lower and upper) and return index for loop info
|
||||||
|
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
|
||||||
|
pack->pushConstantBound(lowerBound);
|
||||||
|
pack->pushConstantBound(upperBound);
|
||||||
|
return pushCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
|
||||||
|
pack->pushConstantBound(lowerBound);
|
||||||
|
pack->pushOperandBound(upperBound);
|
||||||
|
return pushCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
|
||||||
|
int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
|
||||||
|
pack->pushConstantBound(lowerBound);
|
||||||
|
// process upperBound as a dimension of mem ref, possibly non-constant
|
||||||
|
auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
|
||||||
|
if (shape[upperBoundMemRefIndex] < 0) {
|
||||||
|
if (upperBoundMustBeConstant)
|
||||||
|
emitError(loc, "bound expected to be constant");
|
||||||
|
pack->pushOperandBound(
|
||||||
|
rewriter
|
||||||
|
.create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
|
||||||
|
.getResult());
|
||||||
|
} else
|
||||||
|
pack->pushConstantBound(shape[upperBoundMemRefIndex]);
|
||||||
|
return pushCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
|
||||||
|
pack->pushOperandBound(lowerBound);
|
||||||
|
pack->pushOperandBound(upperBound);
|
||||||
|
return pushCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// create iter
|
||||||
|
void BuildKrnlLoop::createIterateOp() {
|
||||||
|
if (!createdDefineOp)
|
||||||
|
emitError(loc, "must create define op before iterate op");
|
||||||
|
// Tight now, optimize (possibly empty) is mandatory. This may change
|
||||||
|
if (!createdOptimizeOp)
|
||||||
|
emitError(loc, "must create optimize op before iterate op");
|
||||||
|
// have to have defined all bounds
|
||||||
|
if (pushCount != originalLoopNum) {
|
||||||
|
printf(" push count %d, original loop %d\n", pushCount, originalLoopNum);
|
||||||
|
emitError(loc, "must push bounds for all original loops");
|
||||||
|
}
|
||||||
|
// create iterate op
|
||||||
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
|
||||||
|
iterBlock = &iterateOp.bodyRegion().front();
|
||||||
|
createdIterateOp = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
|
||||||
|
Value memRefOperand, bool withEmptyOptimization) {
|
||||||
|
int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
|
||||||
|
if (originalLoopNum != loopNum)
|
||||||
|
emitError(loc, "mismatch in loop numbers from constructor and define");
|
||||||
|
createDefineAndOptimizeOp(withEmptyOptimization);
|
||||||
|
for (int i = 0; i < originalLoopNum; ++i)
|
||||||
|
pushBounds(0, memRefOperand, i);
|
||||||
|
createIterateOp();
|
||||||
|
}
|
||||||
|
|
||||||
|
// get induction variable to be use within iter
|
||||||
|
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
|
||||||
|
if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
|
||||||
|
emitError(loc, "original loop index is out of bound");
|
||||||
|
return iterBlock->getArguments()[originalLoopIndex];
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -8,39 +8,38 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
|
|
||||||
class KrnlDialectOperandParser {
|
class KrnlDialectOperandParser {
|
||||||
public:
|
public:
|
||||||
explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
|
explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser)
|
||||||
: _parser(parser), _builder(parser.getBuilder()){};
|
: _parser(parser), _builder(parser.getBuilder()){};
|
||||||
|
|
||||||
// Parse an optional operand.
|
// Parse an optional operand.
|
||||||
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
|
mlir::ParseResult ParseOptionalOperand(
|
||||||
mlir::Value &operand);
|
const mlir::Type &operandType, mlir::Value &operand);
|
||||||
|
|
||||||
// Parse an optional operand and push it to an operand list.
|
// Parse an optional operand and push it to an operand list.
|
||||||
mlir::ParseResult
|
mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
|
||||||
ParseOptionalOperand(const mlir::Type &operandType,
|
|
||||||
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
||||||
|
|
||||||
// Parse a required operand.
|
// Parse a required operand.
|
||||||
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
|
mlir::ParseResult ParseOperand(
|
||||||
mlir::Value &operand);
|
const mlir::Type &operandType, mlir::Value &operand);
|
||||||
|
|
||||||
// Parse a required operand and push it to an operand list.
|
// Parse a required operand and push it to an operand list.
|
||||||
mlir::ParseResult
|
mlir::ParseResult ParseOperand(const mlir::Type &operandType,
|
||||||
ParseOperand(const mlir::Type &operandType,
|
|
||||||
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
llvm::SmallVectorImpl<mlir::Value> &operandList);
|
||||||
|
|
||||||
// Do we have more operands to parse?
|
// Do we have more operands to parse?
|
||||||
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
|
bool hasOperandLeft() { return !_operandRefQueue.empty(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mlir::OpAsmParser& _parser;
|
mlir::OpAsmParser &_parser;
|
||||||
|
|
||||||
mlir::Builder& _builder;
|
mlir::Builder &_builder;
|
||||||
|
|
||||||
// A queue storing the parsed SSA id references.
|
// A queue storing the parsed SSA id references.
|
||||||
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
|
std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
|
||||||
|
@ -50,16 +49,16 @@ class KrnlDialectOperandParser {
|
||||||
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
|
// https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
|
||||||
// Main difference is that it advances the iterator `begin` as it consumes
|
// Main difference is that it advances the iterator `begin` as it consumes
|
||||||
// dimension and symbol operands.
|
// dimension and symbol operands.
|
||||||
void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
|
void printDimAndSymbolList(mlir::Operation::operand_iterator &begin,
|
||||||
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
|
unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p);
|
||||||
|
|
||||||
// Adapted from:
|
// Adapted from:
|
||||||
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
|
// https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
|
||||||
// Main difference is that it advances the iterator `boundOperandsBeg` as it
|
// Main difference is that it advances the iterator `boundOperandsBeg` as it
|
||||||
// prints bound.
|
// prints bound.
|
||||||
void printBound(mlir::AffineMapAttr boundMap,
|
void printBound(mlir::AffineMapAttr boundMap,
|
||||||
mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
|
mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix,
|
||||||
mlir::OpAsmPrinter& p);
|
mlir::OpAsmPrinter &p);
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -88,7 +87,7 @@ struct KrnlIterateOperandPack {
|
||||||
|
|
||||||
size_t getNumInputLoops() const { return inputLoops.size(); }
|
size_t getNumInputLoops() const { return inputLoops.size(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int _boundIdx = 0;
|
int _boundIdx = 0;
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value, 8> _operands;
|
llvm::SmallVector<mlir::Value, 8> _operands;
|
||||||
|
@ -97,7 +96,99 @@ struct KrnlIterateOperandPack {
|
||||||
|
|
||||||
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
|
llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;
|
||||||
|
|
||||||
mlir::Builder& builder;
|
mlir::Builder &builder;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper function to write kernel loops. This class will let us build a single
|
||||||
|
// define/optimize/iterate operation combo. We can then insert optimizations in
|
||||||
|
// the body of the optimization operation, and operations in the body of the
|
||||||
|
// iterate operation.
|
||||||
|
//
|
||||||
|
// The sequence is as follow:
|
||||||
|
//
|
||||||
|
// 1) Create a object giving the rewriter, location, and number of loop in the
|
||||||
|
// original (non optimized) loop.
|
||||||
|
//
|
||||||
|
// 2) Create define & optimize ops (currently paired). Optimizations can then
|
||||||
|
// be added to the inner block of the optimize operation. Make sure to set the
|
||||||
|
// insertion point to that block for optimizations to go in the right place.
|
||||||
|
//
|
||||||
|
// 3) Push the bounds for each of the original loops. Bounds are pushed in
|
||||||
|
// pairs (lower & upper bounds). THere are a few methods to do it depending on
|
||||||
|
// the type of the bounds. When pushing bounds, the method returns a number
|
||||||
|
// that represent the index associated with that iteration (induction variable
|
||||||
|
// and bounds). That index can be used later to extract the induction variable
|
||||||
|
// for reference in computation and/or index calculations of mem refs.
|
||||||
|
//
|
||||||
|
// 4) Once all the bounds are pushed, create the iterate operation. Once this
|
||||||
|
// is done, we can add operations within the iterate blocks by setting the
|
||||||
|
// insertion point to it. Value of the induction variables can be retrieved
|
||||||
|
// using the proper index (determined when pushin the bounds).
|
||||||
|
|
||||||
|
class BuildKrnlLoop {
|
||||||
|
public:
|
||||||
|
// Create a build kernel loop for the given location and loop number.
|
||||||
|
BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
|
||||||
|
// Do the same, but where the loop number corresponds to the dimensionality of
|
||||||
|
// the mem ref operand.
|
||||||
|
BuildKrnlLoop(
|
||||||
|
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
|
||||||
|
~BuildKrnlLoop();
|
||||||
|
|
||||||
|
// Create define and optimize loop with loopNum original loops. If
|
||||||
|
// withEmptyOptimization, the optimization is simply the identity function (no
|
||||||
|
// optimizations).
|
||||||
|
void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
|
||||||
|
|
||||||
|
// Push bounds (lower and upper) for each of the loops, in order. It returns
|
||||||
|
// the index associated with the loop iteration. This index is in the range
|
||||||
|
// from zero to original loop number -1, and is monotonally increasing from
|
||||||
|
// call to call. This index is later used in the getInductionVar call.
|
||||||
|
int pushBounds(int64_t lowerBound, int64_t upperBound);
|
||||||
|
int pushBounds(int64_t lowerBound, Value upperBound);
|
||||||
|
int pushBounds(Value lowerBound, Value upperBound);
|
||||||
|
// same, where the lower bound is an integer, and the uppoer bound is given by
|
||||||
|
// the size of the mem ref operand along the upperBoundMemRefIndex dimension.
|
||||||
|
int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
|
||||||
|
int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
|
||||||
|
|
||||||
|
// Create an iterate op.
|
||||||
|
void createIterateOp();
|
||||||
|
// Create an define, optimize and iterate op, with the same loop nummber as
|
||||||
|
// the rank of the memRefOperand. The lower bound of each loops is zero, and
|
||||||
|
// the upper bound of each loops is the dimension given by the mem refs
|
||||||
|
void createDefineOptimizeAndIterateOp(
|
||||||
|
Value memRefOperand, bool withEmptyOptimization = true);
|
||||||
|
|
||||||
|
// Get the (original loop) induction variable associated with the given index.
|
||||||
|
// Use the index returned when pushing the bounds.
|
||||||
|
BlockArgument &getInductionVar(int originalLoopIndex);
|
||||||
|
|
||||||
|
// Get blocks. This allow us to set the insertion point to the inner block of
|
||||||
|
// the optimize and the iterate Operation
|
||||||
|
Block *getOptimizationBlock() { return optBlock; }
|
||||||
|
Block *getIterateBlock() { return iterBlock; }
|
||||||
|
|
||||||
|
// get original or optimized loops
|
||||||
|
std::vector<Value> &getOriginalLoops() { return originalLoops; }
|
||||||
|
std::vector<Value> &getOptimizedLoops() { return optLoops; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// inputs
|
||||||
|
ConversionPatternRewriter &rewriter;
|
||||||
|
Location loc;
|
||||||
|
int originalLoopNum;
|
||||||
|
// track loops and bounds
|
||||||
|
std::vector<Value> originalLoops;
|
||||||
|
std::vector<Value> optLoops;
|
||||||
|
KrnlIterateOperandPack *pack;
|
||||||
|
int pushCount;
|
||||||
|
bool createdDefineOp;
|
||||||
|
bool createdOptimizeOp;
|
||||||
|
bool createdIterateOp;
|
||||||
|
// insertion points (opt block, iterate)
|
||||||
|
Block *optBlock;
|
||||||
|
Block *iterBlock;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
Loading…
Reference in New Issue