From fcb5f35993ba5b7fce51a33f7faafebc62cedd74 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Mon, 24 Feb 2020 17:20:15 -0500 Subject: [PATCH] Introduce helper class to generate KRNL code and apply it to Convolution (#93) * helper to gen krnl code, applied to conv * suggested changes, name, removed set insertion point * format * suggested changes * added comments and made a small name change --- .clang-format | 1 + .../onnx_to_krnl/rewrite_patterns/nn/conv.inc | 146 +++++++----------- src/dialect/krnl/krnl_helper.cpp | 146 ++++++++++++++++-- src/dialect/krnl/krnl_helper.hpp | 141 ++++++++++++++--- 4 files changed, 303 insertions(+), 131 deletions(-) diff --git a/.clang-format b/.clang-format index a74fda4..b3276c6 100644 --- a/.clang-format +++ b/.clang-format @@ -1,2 +1,3 @@ BasedOnStyle: LLVM AlwaysBreakTemplateDeclarations: Yes +AlignAfterOpenBracket: DontAlign diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc index 20ac5e8..6e3afe1 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc @@ -12,9 +12,8 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { ONNXConvNoBiasOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertToMemRefType(*op->result_type_begin()); @@ -25,12 +24,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - {operands[0]}); + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, {operands[0]}); auto resultShape = memRefType.getShape(); - auto inputShape = operands[0].getType().cast().getShape(); - auto kernelShape = operands[1].getType().cast().getShape(); + auto &inputOperand = operands[0]; + auto inputShape = inputOperand.getType().cast().getShape(); + auto &kernelOperand = operands[1]; + auto kernelShape = kernelOperand.getType().cast().getShape(); // R = ConvNoBias(D, K) // @@ -91,123 +92,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { loc, FloatAttr::get(memRefType.getElementType(), 0)); Value subchannels; if (kernelShape[1] < 0) { - subchannels = - rewriter.create(loc, operands[1], 1).getResult(); + subchannels = rewriter.create(loc, kernelOperand, 1).getResult(); } else { - subchannels = rewriter.create( - loc, kernelShape[1]); + subchannels = rewriter.create(loc, kernelShape[1]); } // 1. Define outer loops and emit empty optimization block: int64_t nOuterLoops = (group > 1) ? 3 : 2; - std::vector outerLoops; - std::vector optimizedOuterLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops, - optimizedOuterLoops, nOuterLoops); - - // Prepare iteration arguments over outer loop nest. - KrnlIterateOperandPack pack( - rewriter, outerLoops, optimizedOuterLoops); + BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops); + outerLoops.createDefineAndOptimizeOp(); // for n = 0 .. N: - pack.pushConstantBound(0); - if (inputShape[0] < 0) - pack.pushOperandBound( - rewriter.create(loc, operands[0], 0).getResult()); - else - pack.pushConstantBound(inputShape[0]); + int nIndex = outerLoops.pushBounds(0, inputOperand, 0); // for g = 0 .. N: - if (group > 1) { - pack.pushConstantBound(0); - pack.pushConstantBound(group); - } + int gIndex = -1; + if (group > 1) + gIndex = outerLoops.pushBounds(0, group); // for m = 0 .. kernelsPerGroup: - pack.pushConstantBound(0); - pack.pushConstantBound(kernelsPerGroup); - // Outer loop iteration. - auto iterateOp = rewriter.create(loc, pack); - Block &outerIterationBlock = iterateOp.bodyRegion().front(); - // Emit optimizations for outer loops: - rewriter.setInsertionPointToEnd(optimizationBlock); - rewriter.create(loc, outerLoops); - rewriter.setInsertionPointToStart(&outerIterationBlock); + int mIndex = outerLoops.pushBounds(0, kernelsPerGroup); + // Outer loop iteration + outerLoops.createIterateOp(); + rewriter.setInsertionPointToStart(outerLoops.getIterateBlock()); { // 2. Emit the body of the outer loop nest. // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m; // If group is not set then the value of the kernel ID is // identical to that of the loop over kernels. - Value kernel = outerIterationBlock.getArguments()[1]; + Value kernel = outerLoops.getInductionVar(mIndex); if (group > 1) { // Middle loop is over groups and third loop is over the // kernel identifiers in the current group. - auto kernelsOffset = rewriter.create(loc, - outerIterationBlock.getArguments()[1], - kernelsPerGroupValue); - kernel = rewriter.create(loc, kernelsOffset, - outerIterationBlock.getArguments()[2]); + auto kernelsOffset = rewriter.create( + loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue); + kernel = rewriter.create( + loc, kernelsOffset, outerLoops.getInductionVar(mIndex)); } // 2.2 Define spatial loops int64_t nSpatialLoops = resultShape.size() - 2; - std::vector spatialLoops; - std::vector optimizedSpatialLoops; - Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops, - optimizedSpatialLoops, nSpatialLoops); - - // 2.3 Prepare iteration arguments for spatial loop nest. - KrnlIterateOperandPack spatialPack( - rewriter, spatialLoops, optimizedSpatialLoops); + BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops); + spatialLoops.createDefineAndOptimizeOp(); for (int i = 2; i < resultShape.size(); ++i) - addDimensionToPack(rewriter, loc, spatialPack, alloc, i); + spatialLoops.pushBounds(0, alloc, i); // 2.4 Emit loop nest over output spatial dimensions. // for rX = 0 .. RX - auto spatialIterateOp = - rewriter.create(loc, spatialPack); - Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front(); - // 2.5 Emit optimizations for outer loops: - rewriter.setInsertionPointToEnd(optSpatialLoopBlock); - rewriter.create(loc, spatialLoops); - rewriter.setInsertionPointToStart(&spatialIterationBlock); + spatialLoops.createIterateOp(); + rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock()); + { // 3. Emit the body of the spatial loop nest. // 3.1 Emit: R[n][kernel][r1][r2] = 0; SmallVector resultIndices; // n - resultIndices.emplace_back(outerIterationBlock.getArguments()[0]); + resultIndices.emplace_back(outerLoops.getInductionVar(nIndex)); // kernel resultIndices.emplace_back(kernel); // rX - for (auto arg : spatialIterationBlock.getArguments()) + for (auto arg : spatialLoops.getIterateBlock()->getArguments()) resultIndices.emplace_back(arg); // Store initializer value into output location. rewriter.create(loc, zero, alloc, resultIndices); // 3.2 Define inner loops. int64_t nInnerLoops = 1 + (kernelShape.size() - 2); - std::vector innerLoops; - std::vector optimizedInnerLoops; - Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops, - optimizedInnerLoops, nInnerLoops); - - // 3.3 Prepare iteration arguments for inner loop nest. - KrnlIterateOperandPack innerPack( - rewriter, innerLoops, optimizedInnerLoops); + BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops); + innerLoops.createDefineAndOptimizeOp(); // for c = 0 .. C/group - innerPack.pushConstantBound(0); - innerPack.pushConstantBound(kernelShape[1]); + int cIndex = innerLoops.pushBounds(0, kernelShape[1]); // for Kx = 0 .. KX for (int i = 2; i < kernelShape.size(); ++i) - addDimensionToPack(rewriter, loc, innerPack, operands[1], i); + innerLoops.pushBounds(0, kernelOperand, i); // 3.4 Emit inner loop nest. - auto innerIterateOp = - rewriter.create(loc, innerPack); - Block &innerIterationBlock = innerIterateOp.bodyRegion().front(); - // 3.5 Emit optimizations for outer loops: - rewriter.setInsertionPointToEnd(optInnerLoopBlock); - rewriter.create(loc, innerLoops); - rewriter.setInsertionPointToStart(&innerIterationBlock); + innerLoops.createIterateOp(); + rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); + { // 4. Emit inner loop body // R[n][kernel][r1][r2] = @@ -217,13 +177,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { // 4.1 Prepare indices for accesing the data tensor. SmallVector dataIndices; // n - dataIndices.emplace_back(outerIterationBlock.getArguments()[0]); + dataIndices.emplace_back(outerLoops.getInductionVar(nIndex)); // g * (C / group) + c - Value channelDepth = innerIterationBlock.getArguments()[0]; + Value channelDepth = innerLoops.getInductionVar(cIndex); if (group > 1) channelDepth = rewriter.create(loc, channelDepth, - rewriter.create(loc, subchannels, - outerIterationBlock.getArguments()[1])); + rewriter.create( + loc, subchannels, outerLoops.getInductionVar(gIndex))); dataIndices.emplace_back(channelDepth); // sX * rX + kX auto stridesAttribute = convOp.stridesAttr(); @@ -233,15 +193,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { for (auto stride : stridesAttribute.getValue()) strides.emplace_back(stride.cast().getInt()); for (int i = 0; i < kernelShape.size() - 2; ++i) { - Value spatialIndex = spatialIterationBlock.getArguments()[i]; + Value spatialIndex = spatialLoops.getInductionVar(i); // If strides are present then emit the correct access index. if (stridesAttribute && strides[i] > 1) spatialIndex = rewriter.create(loc, rewriter.create(loc, strides[i]), - spatialIterationBlock.getArguments()[i]); - dataIndices.emplace_back( - rewriter.create(loc, spatialIndex, - innerIterationBlock.getArguments()[i+1])); + spatialLoops.getInductionVar(i)); + dataIndices.emplace_back(rewriter.create( + loc, spatialIndex, innerLoops.getInductionVar(i + 1))); } // 4.2 Prepare indices for accessing the kernel tensor. @@ -249,17 +208,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { // kernel kernelIndices.emplace_back(kernel); // c - kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]); + kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex)); // kX for (int i = 0; i < kernelShape.size() - 2; ++i) - kernelIndices.emplace_back( - innerIterationBlock.getArguments()[i+1]); + kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1)); // 4.3 Compute convolution. auto loadData = - rewriter.create(loc, operands[0], dataIndices); + rewriter.create(loc, inputOperand, dataIndices); auto loadKernel = - rewriter.create(loc, operands[1], kernelIndices); + rewriter.create(loc, kernelOperand, kernelIndices); auto loadPartialSum = rewriter.create(loc, alloc, resultIndices); Value result = rewriter.create(loc, loadPartialSum, diff --git a/src/dialect/krnl/krnl_helper.cpp b/src/dialect/krnl/krnl_helper.cpp index 72edb92..4f75a43 100644 --- a/src/dialect/krnl/krnl_helper.cpp +++ b/src/dialect/krnl/krnl_helper.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "src/dialect/krnl/krnl_ops.hpp" @@ -9,9 +10,8 @@ namespace onnf { using namespace mlir; -ParseResult -KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType, - Value &operand) { +ParseResult KrnlDialectOperandParser::ParseOptionalOperand( + const Type &operandType, Value &operand) { // If operand queue is empty, parse more operands and cache them. if (_operandRefQueue.empty()) { // Parse operand types: @@ -19,7 +19,7 @@ KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType, _parser.parseOperandList(operand_refs); // Record operands: - for (auto& operand_ref : operand_refs) + for (auto &operand_ref : operand_refs) _operandRefQueue.emplace(operand_ref); } @@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand( return success(); } -ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType, - Value &operand) { +ParseResult KrnlDialectOperandParser::ParseOperand( + const Type &operandType, Value &operand) { if (ParseOptionalOperand(operandType, operand)) return _parser.emitError( _parser.getCurrentLocation(), "Expecting an operand."); @@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand( return success(); } -void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims, - unsigned numSymbols, OpAsmPrinter& p) { +void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims, + unsigned numSymbols, OpAsmPrinter &p) { p << '('; p.printOperands(begin, begin + numDims); p << ')'; @@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims, } void printBound(AffineMapAttr boundMap, - Operation::operand_iterator& boundOperandsBeg, const char* prefix, - OpAsmPrinter& p) { + Operation::operand_iterator &boundOperandsBeg, const char *prefix, + OpAsmPrinter &p) { AffineMap map = boundMap.getValue(); // Check if this bound should be printed using custom assembly form. @@ -120,9 +120,10 @@ void printBound(AffineMapAttr boundMap, printDimAndSymbolList( boundOperandsBeg, map.getNumDims(), map.getNumSymbols(), p); } -} // namespace onnf +} // namespace onnf namespace mlir { + void KrnlIterateOperandPack::pushConstantBound(int64_t bound) { if (boundMaps.size() % 2 == 0) _operands.emplace_back(inputLoops[boundMaps.size() / 2]); @@ -137,4 +138,125 @@ void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) { boundMaps.emplace_back(AffineMapAttr::get(map)); _operands.emplace_back(operand); } -} // namespace mlir + +BuildKrnlLoop::BuildKrnlLoop( + ConversionPatternRewriter &rewriter, Location loc, int loopNum) + : rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL), + pushCount(0), createdDefineOp(false), createdOptimizeOp(false), + createdIterateOp(false) { + if (originalLoopNum <= 0) + emitError(loc, "expected positive number of original loops"); +} + +BuildKrnlLoop::BuildKrnlLoop( + ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand) + : BuildKrnlLoop(rewriter, loc, + memRefOperand.getType().cast().getShape().size()) {} + +BuildKrnlLoop::~BuildKrnlLoop() { + if (!createdDefineOp) + emitError(loc, "expected to create define op"); + if (!createdIterateOp) + emitError(loc, "expected to create iteration op"); + if (pack) + free(pack); +} + +void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) { + // insert define loop op + auto loopsOp = rewriter.create(loc, originalLoopNum); + originalLoops.reserve(originalLoopNum); + for (auto result : loopsOp.getResults()) + originalLoops.push_back(result); + // inserte optimize loop op. + auto optimizedLoopsOp = + rewriter.create(loc, originalLoopNum); + optLoops.reserve(originalLoopNum); + // Emit empty optimizations + if (withEmptyOptimization) { + for (auto result : optimizedLoopsOp.getResults()) + optLoops.push_back(result); + optBlock = &optimizedLoopsOp.region().front(); + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToEnd(optBlock); + rewriter.create(loc, originalLoops); + rewriter.restoreInsertionPoint(ip); + } + // prepare data structure to push bounds + pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops); + createdOptimizeOp = true; +} + +// push bounds (lower and upper) and return index for loop info +int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) { + pack->pushConstantBound(lowerBound); + pack->pushConstantBound(upperBound); + return pushCount++; +} + +int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) { + pack->pushConstantBound(lowerBound); + pack->pushOperandBound(upperBound); + return pushCount++; +} + +int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand, + int upperBoundMemRefIndex, bool upperBoundMustBeConstant) { + pack->pushConstantBound(lowerBound); + // process upperBound as a dimension of mem ref, possibly non-constant + auto shape = upperBoundMemRefOperand.getType().cast().getShape(); + if (shape[upperBoundMemRefIndex] < 0) { + if (upperBoundMustBeConstant) + emitError(loc, "bound expected to be constant"); + pack->pushOperandBound( + rewriter + .create(loc, upperBoundMemRefOperand, upperBoundMemRefIndex) + .getResult()); + } else + pack->pushConstantBound(shape[upperBoundMemRefIndex]); + return pushCount++; +} + +int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) { + pack->pushOperandBound(lowerBound); + pack->pushOperandBound(upperBound); + return pushCount++; +} + +// create iter +void BuildKrnlLoop::createIterateOp() { + if (!createdDefineOp) + emitError(loc, "must create define op before iterate op"); + // Tight now, optimize (possibly empty) is mandatory. This may change + if (!createdOptimizeOp) + emitError(loc, "must create optimize op before iterate op"); + // have to have defined all bounds + if (pushCount != originalLoopNum) { + printf(" push count %d, original loop %d\n", pushCount, originalLoopNum); + emitError(loc, "must push bounds for all original loops"); + } + // create iterate op + auto iterateOp = rewriter.create(loc, *pack); + iterBlock = &iterateOp.bodyRegion().front(); + createdIterateOp = true; +} + +void BuildKrnlLoop::createDefineOptimizeAndIterateOp( + Value memRefOperand, bool withEmptyOptimization) { + int loopNum = memRefOperand.getType().cast().getShape().size(); + if (originalLoopNum != loopNum) + emitError(loc, "mismatch in loop numbers from constructor and define"); + createDefineAndOptimizeOp(withEmptyOptimization); + for (int i = 0; i < originalLoopNum; ++i) + pushBounds(0, memRefOperand, i); + createIterateOp(); +} + +// get induction variable to be use within iter +BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) { + if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum) + emitError(loc, "original loop index is out of bound"); + return iterBlock->getArguments()[originalLoopIndex]; +} + +} // namespace mlir diff --git a/src/dialect/krnl/krnl_helper.hpp b/src/dialect/krnl/krnl_helper.hpp index 41a141b..cfe1787 100644 --- a/src/dialect/krnl/krnl_helper.hpp +++ b/src/dialect/krnl/krnl_helper.hpp @@ -8,39 +8,38 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/DialectConversion.h" namespace onnf { class KrnlDialectOperandParser { - public: - explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser) +public: + explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser) : _parser(parser), _builder(parser.getBuilder()){}; // Parse an optional operand. - mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType, - mlir::Value &operand); + mlir::ParseResult ParseOptionalOperand( + const mlir::Type &operandType, mlir::Value &operand); // Parse an optional operand and push it to an operand list. - mlir::ParseResult - ParseOptionalOperand(const mlir::Type &operandType, - llvm::SmallVectorImpl &operandList); + mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType, + llvm::SmallVectorImpl &operandList); // Parse a required operand. - mlir::ParseResult ParseOperand(const mlir::Type &operandType, - mlir::Value &operand); + mlir::ParseResult ParseOperand( + const mlir::Type &operandType, mlir::Value &operand); // Parse a required operand and push it to an operand list. - mlir::ParseResult - ParseOperand(const mlir::Type &operandType, - llvm::SmallVectorImpl &operandList); + mlir::ParseResult ParseOperand(const mlir::Type &operandType, + llvm::SmallVectorImpl &operandList); // Do we have more operands to parse? bool hasOperandLeft() { return !_operandRefQueue.empty(); } - private: - mlir::OpAsmParser& _parser; +private: + mlir::OpAsmParser &_parser; - mlir::Builder& _builder; + mlir::Builder &_builder; // A queue storing the parsed SSA id references. std::queue _operandRefQueue; @@ -50,24 +49,24 @@ class KrnlDialectOperandParser { // https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197 // Main difference is that it advances the iterator `begin` as it consumes // dimension and symbol operands. -void printDimAndSymbolList(mlir::Operation::operand_iterator& begin, - unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p); +void printDimAndSymbolList(mlir::Operation::operand_iterator &begin, + unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p); // Adapted from: // https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272 // Main difference is that it advances the iterator `boundOperandsBeg` as it // prints bound. void printBound(mlir::AffineMapAttr boundMap, - mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix, - mlir::OpAsmPrinter& p); -} // namespace onnf + mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix, + mlir::OpAsmPrinter &p); +} // namespace onnf namespace mlir { struct KrnlIterateOperandPack { KrnlIterateOperandPack(mlir::Builder &builder, - llvm::ArrayRef inputLoops, - llvm::ArrayRef optimizedLoops) + llvm::ArrayRef inputLoops, + llvm::ArrayRef optimizedLoops) : builder(builder), inputLoops(inputLoops), optimizedLoops(optimizedLoops) { _operands.insert( @@ -88,7 +87,7 @@ struct KrnlIterateOperandPack { size_t getNumInputLoops() const { return inputLoops.size(); } - private: +private: int _boundIdx = 0; llvm::SmallVector _operands; @@ -97,7 +96,99 @@ struct KrnlIterateOperandPack { llvm::ArrayRef inputLoops, optimizedLoops; - mlir::Builder& builder; + mlir::Builder &builder; }; -} // namespace mlir +// Helper function to write kernel loops. This class will let us build a single +// define/optimize/iterate operation combo. We can then insert optimizations in +// the body of the optimization operation, and operations in the body of the +// iterate operation. +// +// The sequence is as follow: +// +// 1) Create a object giving the rewriter, location, and number of loop in the +// original (non optimized) loop. +// +// 2) Create define & optimize ops (currently paired). Optimizations can then +// be added to the inner block of the optimize operation. Make sure to set the +// insertion point to that block for optimizations to go in the right place. +// +// 3) Push the bounds for each of the original loops. Bounds are pushed in +// pairs (lower & upper bounds). THere are a few methods to do it depending on +// the type of the bounds. When pushing bounds, the method returns a number +// that represent the index associated with that iteration (induction variable +// and bounds). That index can be used later to extract the induction variable +// for reference in computation and/or index calculations of mem refs. +// +// 4) Once all the bounds are pushed, create the iterate operation. Once this +// is done, we can add operations within the iterate blocks by setting the +// insertion point to it. Value of the induction variables can be retrieved +// using the proper index (determined when pushin the bounds). + +class BuildKrnlLoop { +public: + // Create a build kernel loop for the given location and loop number. + BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum); + // Do the same, but where the loop number corresponds to the dimensionality of + // the mem ref operand. + BuildKrnlLoop( + ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand); + ~BuildKrnlLoop(); + + // Create define and optimize loop with loopNum original loops. If + // withEmptyOptimization, the optimization is simply the identity function (no + // optimizations). + void createDefineAndOptimizeOp(bool withEmptyOptimization = true); + + // Push bounds (lower and upper) for each of the loops, in order. It returns + // the index associated with the loop iteration. This index is in the range + // from zero to original loop number -1, and is monotonally increasing from + // call to call. This index is later used in the getInductionVar call. + int pushBounds(int64_t lowerBound, int64_t upperBound); + int pushBounds(int64_t lowerBound, Value upperBound); + int pushBounds(Value lowerBound, Value upperBound); + // same, where the lower bound is an integer, and the uppoer bound is given by + // the size of the mem ref operand along the upperBoundMemRefIndex dimension. + int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand, + int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false); + + // Create an iterate op. + void createIterateOp(); + // Create an define, optimize and iterate op, with the same loop nummber as + // the rank of the memRefOperand. The lower bound of each loops is zero, and + // the upper bound of each loops is the dimension given by the mem refs + void createDefineOptimizeAndIterateOp( + Value memRefOperand, bool withEmptyOptimization = true); + + // Get the (original loop) induction variable associated with the given index. + // Use the index returned when pushing the bounds. + BlockArgument &getInductionVar(int originalLoopIndex); + + // Get blocks. This allow us to set the insertion point to the inner block of + // the optimize and the iterate Operation + Block *getOptimizationBlock() { return optBlock; } + Block *getIterateBlock() { return iterBlock; } + + // get original or optimized loops + std::vector &getOriginalLoops() { return originalLoops; } + std::vector &getOptimizedLoops() { return optLoops; } + +private: + // inputs + ConversionPatternRewriter &rewriter; + Location loc; + int originalLoopNum; + // track loops and bounds + std::vector originalLoops; + std::vector optLoops; + KrnlIterateOperandPack *pack; + int pushCount; + bool createdDefineOp; + bool createdOptimizeOp; + bool createdIterateOp; + // insertion points (opt block, iterate) + Block *optBlock; + Block *iterBlock; +}; + +} // namespace mlir