Clean-up code. (#98)

This commit is contained in:
Gheorghe-Teodor Bercea 2020-02-25 09:54:29 -05:00 committed by GitHub
parent 0d307d1183
commit 32f08bcf0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 90 additions and 57 deletions

View File

@ -131,7 +131,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
boundMaps.emplace_back(AffineMapAttr::get(map)); boundMaps.emplace_back(AffineMapAttr::get(map));
} }
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) { void KrnlIterateOperandPack::pushOperandBound(Value operand) {
if (boundMaps.size() % 2 == 0) if (boundMaps.size() % 2 == 0)
_operands.emplace_back(inputLoops[boundMaps.size() / 2]); _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
AffineMap map = builder.getSymbolIdentityMap(); AffineMap map = builder.getSymbolIdentityMap();
@ -145,7 +145,7 @@ BuildKrnlLoop::BuildKrnlLoop(
pushCount(0), createdDefineOp(false), createdOptimizeOp(false), pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
createdIterateOp(false) { createdIterateOp(false) {
if (originalLoopNum <= 0) if (originalLoopNum <= 0)
emitError(loc, "expected positive number of original loops"); emitError(loc, "Expected positive number of original loops.");
} }
BuildKrnlLoop::BuildKrnlLoop( BuildKrnlLoop::BuildKrnlLoop(
@ -154,27 +154,24 @@ BuildKrnlLoop::BuildKrnlLoop(
memRefOperand.getType().cast<MemRefType>().getShape().size()) {} memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
BuildKrnlLoop::~BuildKrnlLoop() { BuildKrnlLoop::~BuildKrnlLoop() {
if (!createdDefineOp)
emitError(loc, "expected to create define op");
if (!createdIterateOp)
emitError(loc, "expected to create iteration op");
if (pack) if (pack)
free(pack); free(pack);
} }
void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) { void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
// insert define loop op // Insert define loop operation.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum); auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
originalLoops.reserve(originalLoopNum); originalLoops.reserve(originalLoopNum);
for (auto result : loopsOp.getResults()) for (auto result : loopsOp.getResults())
originalLoops.push_back(result); originalLoops.push_back(result);
createdDefineOp = true; createdDefineOp = true;
// inserte optimize loop op. // Insert optimize loop operation.
auto optimizedLoopsOp = auto optimizedLoopsOp =
rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum); rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
optLoops.reserve(originalLoopNum); optLoops.reserve(originalLoopNum);
// Emit empty optimizations
// Emit empty optimizations if flag is set.
if (withEmptyOptimization) { if (withEmptyOptimization) {
for (auto result : optimizedLoopsOp.getResults()) for (auto result : optimizedLoopsOp.getResults())
optLoops.push_back(result); optLoops.push_back(result);
@ -190,7 +187,6 @@ void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops); pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
} }
// push bounds (lower and upper) and return index for loop info
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) { int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
pack->pushConstantBound(lowerBound); pack->pushConstantBound(lowerBound);
pack->pushConstantBound(upperBound); pack->pushConstantBound(upperBound);
@ -206,17 +202,20 @@ int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand, int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant) { int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
pack->pushConstantBound(lowerBound); pack->pushConstantBound(lowerBound);
// process upperBound as a dimension of mem ref, possibly non-constant
// Process upperBound as a dimension of the MemRef. Non-constant dimensions
// are supported.
auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape(); auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
if (shape[upperBoundMemRefIndex] < 0) { if (shape[upperBoundMemRefIndex] < 0) {
if (upperBoundMustBeConstant) if (upperBoundMustBeConstant)
emitError(loc, "bound expected to be constant"); emitError(loc, "Bound expected to be constant.");
pack->pushOperandBound( pack->pushOperandBound(
rewriter rewriter
.create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex) .create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
.getResult()); .getResult());
} else } else
pack->pushConstantBound(shape[upperBoundMemRefIndex]); pack->pushConstantBound(shape[upperBoundMemRefIndex]);
return pushCount++; return pushCount++;
} }
@ -226,19 +225,20 @@ int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
return pushCount++; return pushCount++;
} }
// create iter
void BuildKrnlLoop::createIterateOp() { void BuildKrnlLoop::createIterateOp() {
// Loop definition operation is mandatory.
if (!createdDefineOp) if (!createdDefineOp)
emitError(loc, "must create define op before iterate op"); emitError(loc, "Must create define op before iterate op.");
// Tight now, optimize (possibly empty) is mandatory. This may change
// Loop optimization operation is mandatory (for now).
if (!createdOptimizeOp) if (!createdOptimizeOp)
emitError(loc, "must create optimize op before iterate op"); emitError(loc, "Must create optimize op before iterate op.");
// have to have defined all bounds
if (pushCount != originalLoopNum) { // Check if all bounds have been defined.
printf(" push count %d, original loop %d\n", pushCount, originalLoopNum); if (pushCount != originalLoopNum)
emitError(loc, "must push bounds for all original loops"); emitError(loc, "Must push bounds for all original loops.");
}
// create iterate op // Emit iteration operation.
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack); auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
iterBlock = &iterateOp.bodyRegion().front(); iterBlock = &iterateOp.bodyRegion().front();
createdIterateOp = true; createdIterateOp = true;
@ -246,19 +246,27 @@ void BuildKrnlLoop::createIterateOp() {
void BuildKrnlLoop::createDefineOptimizeAndIterateOp( void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization) { Value memRefOperand, bool withEmptyOptimization) {
// Rank of the MemRef operand. We will emit a loop for each dimension.
int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size(); int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
if (originalLoopNum != loopNum) if (originalLoopNum != loopNum)
emitError(loc, "mismatch in loop numbers from constructor and define"); emitError(loc, "Mismatch in loop numbers from constructor and define.");
// Emit the definition and the optimization operations for the loop nest.
createDefineAndOptimizeOp(withEmptyOptimization); createDefineAndOptimizeOp(withEmptyOptimization);
// Push a lower-upper bound pair for each dimension of the MemRef operand.
// The lower bound in this case is always zero.
for (int i = 0; i < originalLoopNum; ++i) for (int i = 0; i < originalLoopNum; ++i)
pushBounds(0, memRefOperand, i); pushBounds(0, memRefOperand, i);
// Emit the iteration operation over the current loop nest.
createIterateOp(); createIterateOp();
} }
// get induction variable to be use within iter
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) { BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
// Check if loop iteration variable is within bounds.
if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum) if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
emitError(loc, "original loop index is out of bound"); emitError(loc, "Original loop index is out of bounds.");
return iterBlock->getArguments()[originalLoopIndex]; return iterBlock->getArguments()[originalLoopIndex];
} }

View File

@ -106,19 +106,21 @@ private:
// //
// The sequence is as follow: // The sequence is as follow:
// //
// 1) Create a object giving the rewriter, location, and number of loop in the // 1) Create an object giving the rewriter, location, and number of loop in
// original (non optimized) loop. // the original (non optimized) loop.
// //
// 2) Create define & optimize ops (currently paired). Optimizations can then // 2) Create define & optimize ops (currently paired). Optimizations can then
// be added to the inner block of the optimize operation. Make sure to set the // be added to the inner block of the optimize operation. Make sure to set
// insertion point to that block for optimizations to go in the right place. // the insertion point to that block for optimizations to go in the right
// place.
// //
// 3) Push the bounds for each of the original loops. Bounds are pushed in // 3) Push the bounds for each of the original loops. Bounds are pushed in
// pairs (lower & upper bounds). THere are a few methods to do it depending on // pairs (lower & upper bounds). There are a few methods to do it depending
// the type of the bounds. When pushing bounds, the method returns a number // on the type of the bounds. When pushing bounds, the method returns a
// that represent the index associated with that iteration (induction variable // number that represent the index associated with that iteration (induction
// and bounds). That index can be used later to extract the induction variable // variable and bounds). That index can be used later to extract the
// for reference in computation and/or index calculations of mem refs. // induction variable for reference in computation and/or index calculations
// of mem refs.
// //
// 4) Once all the bounds are pushed, create the iterate operation. Once this // 4) Once all the bounds are pushed, create the iterate operation. Once this
// is done, we can add operations within the iterate blocks by setting the // is done, we can add operations within the iterate blocks by setting the
@ -127,67 +129,90 @@ private:
class BuildKrnlLoop { class BuildKrnlLoop {
public: public:
// Create a build kernel loop for the given location and loop number. // Create kernel loop builder for a loop nest of depth loopNum.
BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum); BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
// Do the same, but where the loop number corresponds to the dimensionality of
// the mem ref operand. // Create kernel loop builder for a loop nest of depth equal to the
// dimensionality of the operand. An operand of MemRef type is requied.
BuildKrnlLoop( BuildKrnlLoop(
ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand); ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
~BuildKrnlLoop(); ~BuildKrnlLoop();
// Create define and optimize loop with loopNum original loops. If // Create define and optimize loop with loopNum original loops. If
// withEmptyOptimization, the optimization is simply the identity function (no // withEmptyOptimization is true, the optimization is simply the identity
// optimizations). // function (no optimizations).
void createDefineAndOptimizeOp(bool withEmptyOptimization = true); void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
// Push bounds (lower and upper) for each of the loops, in order. It returns // Push bounds (lower and upper) for each of the loops (order matters).
// the index associated with the loop iteration. This index is in the range // The function returns the order number associated with the loop iteration.
// from zero to original loop number -1, and is monotonally increasing from // This index is used by the getInductionVar call. Non-constant operands
// call to call. This index is later used in the getInductionVar call. // must be of MemRef type.
int pushBounds(int64_t lowerBound, int64_t upperBound); int pushBounds(int64_t lowerBound, int64_t upperBound);
int pushBounds(int64_t lowerBound, Value upperBound); int pushBounds(int64_t lowerBound, Value upperBound);
int pushBounds(Value lowerBound, Value upperBound); int pushBounds(Value lowerBound, Value upperBound);
// same, where the lower bound is an integer, and the uppoer bound is given by
// the size of the mem ref operand along the upperBoundMemRefIndex dimension.
int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand, int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false); int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
// Create an iterate op. // Create the KrnlIterateOp assiciated with this loop nest. The loops
// iteration will be created if the definition and the optimization
// operations associated with this loop nest have been emitted already.
void createIterateOp(); void createIterateOp();
// Create an define, optimize and iterate op, with the same loop nummber as
// the rank of the memRefOperand. The lower bound of each loops is zero, and // Create the loop nest definition, optimization and iteration operations
// the upper bound of each loops is the dimension given by the mem refs // for a given operand of MemRef type. The loop nest has a depth equal to the
// rank of the MemRef operand. The lower bound of each loop is zero. The
// upper bound of each loop is given by the corresponding dimension of the
// MemRef operand.
void createDefineOptimizeAndIterateOp( void createDefineOptimizeAndIterateOp(
Value memRefOperand, bool withEmptyOptimization = true); Value memRefOperand, bool withEmptyOptimization = true);
// Get the (original loop) induction variable associated with the given index. // Get the (original loop) induction variable associated with the given
// Use the index returned when pushing the bounds. // index. Use the index returned when pushing the bounds.
BlockArgument &getInductionVar(int originalLoopIndex); BlockArgument &getInductionVar(int originalLoopIndex);
// Get blocks. This allow us to set the insertion point to the inner block of // Get a reference to the code region of the optimization operation.
// the optimize and the iterate Operation // This allows us to set the insertion point to the inner block of the
// loop nest optimization operation.
Block *getOptimizationBlock() { return optBlock; } Block *getOptimizationBlock() { return optBlock; }
// Get a reference to the code region of the iteration operation.
// This allows us to set the insertion point to the inner block of the
// loop nest iteration operation.
Block *getIterateBlock() { return iterBlock; } Block *getIterateBlock() { return iterBlock; }
// get original or optimized loops // Get original loop nest.
std::vector<Value> &getOriginalLoops() { return originalLoops; } std::vector<Value> &getOriginalLoops() { return originalLoops; }
// Get optimized loop nest.
std::vector<Value> &getOptimizedLoops() { return optLoops; } std::vector<Value> &getOptimizedLoops() { return optLoops; }
private: private:
// inputs // Required for emitting operations.
ConversionPatternRewriter &rewriter; ConversionPatternRewriter &rewriter;
Location loc; Location loc;
int originalLoopNum; int originalLoopNum;
// track loops and bounds
// List of original, un-optimized loops.
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
// List of optimized loops.
std::vector<Value> optLoops; std::vector<Value> optLoops;
// List of lower-upper bound pairs needed by the KrnlIterateOp.
KrnlIterateOperandPack *pack; KrnlIterateOperandPack *pack;
// Number of lower-upper bound pairs pushed.
int pushCount; int pushCount;
// Flags that keep track of emitted operations.
bool createdDefineOp; bool createdDefineOp;
bool createdOptimizeOp; bool createdOptimizeOp;
bool createdIterateOp; bool createdIterateOp;
// insertion points (opt block, iterate)
// Saved insertion point in the code region of the KrnlOptimizeLoopsOp.
Block *optBlock; Block *optBlock;
// Saved insertion point in the code region of the KrnlIterateOp.
Block *iterBlock; Block *iterBlock;
}; };