CuPBoP/compilation/KernelTranslation/lib/insert_warp_loop.cpp

849 lines
30 KiB
C++
Raw Normal View History

#include "insert_warp_loop.h"
#include "handle_sync.h"
#include "tool.h"
#include <assert.h>
#include <iostream>
#include <set>
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/InitializePasses.h"
#include "llvm/PassInfo.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <map>
#include <set>
#include <sstream>
#include <tuple>
#include <vector>
using namespace llvm;
struct ParallelRegion {
std::set<llvm::BasicBlock *> wrapped_block;
llvm::BasicBlock *successor_block;
llvm::BasicBlock *start_block;
llvm::BasicBlock *end_block;
bool inst_in_region(llvm::Instruction *inst) {
for (auto bb : wrapped_block) {
if (inst->getParent()->getName().str() == bb->getName().str())
return true;
}
return false;
}
bool inst_used_in_region(llvm::Instruction *inst) {
for (auto ui = inst->use_begin(); ui != inst->use_end(); ++ui) {
auto *user = dyn_cast<Instruction>(ui->getUser());
if (user == NULL)
continue;
if (inst_in_region(user)) {
return 1;
}
}
return 0;
}
};
std::map<llvm::Instruction *, unsigned> tempInstructionIds;
std::map<std::string, llvm::Instruction *> contextArrays;
int tempInstructionIndex = 0;
int need_nested_loop;
bool ShouldNotBeContextSaved(llvm::Instruction *instr) {
if (isa<BranchInst>(instr))
return true;
llvm::Module *M = instr->getParent()->getParent()->getParent();
llvm::LoadInst *load = dyn_cast<llvm::LoadInst>(instr);
if (load != NULL) {
auto load_addr = load->getPointerOperand();
if (load_addr == M->getGlobalVariable("intra_warp_index"))
return true;
if (load_addr == M->getGlobalVariable("inter_warp_index"))
return true;
if (load_addr == M->getGlobalVariable("warp_vote"))
return true;
}
// TODO: we should further analyze whether the local variable
// is same among all threads within a wrap
return false;
}
// generate countpart alloc in the beginning of the Function
llvm::Instruction *GetContextArray(llvm::Instruction *instruction,
bool intra_warp_loop) {
std::ostringstream var;
if (std::string(instruction->getName().str()) != "") {
var << instruction->getName().str();
} else if (tempInstructionIds.find(instruction) != tempInstructionIds.end()) {
var << tempInstructionIds[instruction];
} else {
tempInstructionIds[instruction] = tempInstructionIndex++;
var << tempInstructionIds[instruction];
}
if (intra_warp_loop)
var << "_intra_warp_";
else
var << "_inter_warp_";
std::string varName = var.str();
if (contextArrays.find(varName) != contextArrays.end())
return contextArrays[varName];
BasicBlock &bb = instruction->getParent()->getParent()->getEntryBlock();
IRBuilder<> builder(&*(bb.getFirstInsertionPt()));
Function *FF = instruction->getParent()->getParent();
Module *M = instruction->getParent()->getParent()->getParent();
LLVMContext &C = M->getContext();
const llvm::DataLayout &Layout = M->getDataLayout();
llvm::Type *elementType;
if (isa<AllocaInst>(instruction)) {
elementType =
dyn_cast<AllocaInst>(instruction)->getType()->getElementType();
} else {
elementType = instruction->getType();
}
Type *AllocType = elementType;
AllocaInst *InstCast = dyn_cast<AllocaInst>(instruction);
if (InstCast) {
unsigned Alignment = InstCast->getAlignment();
uint64_t StoreSize = Layout.getTypeStoreSize(InstCast->getAllocatedType());
if ((Alignment > 1) && (StoreSize & (Alignment - 1))) {
uint64_t AlignedSize = (StoreSize & (~(Alignment - 1))) + Alignment;
assert(AlignedSize > StoreSize);
uint64_t RequiredExtraBytes = AlignedSize - StoreSize;
if (isa<ArrayType>(elementType)) {
ArrayType *StructPadding = ArrayType::get(
Type::getInt8Ty(M->getContext()), RequiredExtraBytes);
std::vector<Type *> PaddedStructElements;
PaddedStructElements.push_back(elementType);
PaddedStructElements.push_back(StructPadding);
const ArrayRef<Type *> NewStructElements(PaddedStructElements);
AllocType = StructType::get(M->getContext(), NewStructElements, true);
uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType);
assert(NewStoreSize == AlignedSize);
} else if (isa<StructType>(elementType)) {
StructType *OldStruct = dyn_cast<StructType>(elementType);
ArrayType *StructPadding = ArrayType::get(
Type::getInt8Ty(M->getContext()), RequiredExtraBytes);
std::vector<Type *> PaddedStructElements;
for (unsigned j = 0; j < OldStruct->getNumElements(); j++)
PaddedStructElements.push_back(OldStruct->getElementType(j));
PaddedStructElements.push_back(StructPadding);
const ArrayRef<Type *> NewStructElements(PaddedStructElements);
AllocType = StructType::get(OldStruct->getContext(), NewStructElements,
OldStruct->isPacked());
uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType);
assert(NewStoreSize == AlignedSize);
}
}
}
llvm::Value *ItemSize = nullptr;
llvm::AllocaInst *Alloca = nullptr;
auto block_size_addr = M->getGlobalVariable("block_size");
auto block_size = builder.CreateLoad(block_size_addr);
Alloca = builder.CreateAlloca(AllocType, block_size, varName);
contextArrays[varName] = Alloca;
return Alloca;
}
// save the local variable into replicated array
llvm::Instruction *AddContextSave(llvm::Instruction *instruction,
llvm::Instruction *alloca,
bool intra_warp_loop) {
if (isa<AllocaInst>(instruction)) {
return NULL;
}
llvm::Module *M = instruction->getParent()->getParent()->getParent();
LLVMContext &context = M->getContext();
auto I32 = llvm::Type::getInt32Ty(context);
/* Save the produced variable to the array. */
BasicBlock::iterator definition =
(dyn_cast<Instruction>(instruction))->getIterator();
++definition;
IRBuilder<> builder(&*definition);
std::vector<llvm::Value *> gepArgs;
auto inter_warp_index =
builder.CreateLoad(M->getGlobalVariable("inter_warp_index"));
auto intra_warp_index =
builder.CreateLoad(M->getGlobalVariable("intra_warp_index"));
auto thread_idx = builder.CreateBinOp(
Instruction::Add, intra_warp_index,
builder.CreateBinOp(Instruction::Mul, inter_warp_index,
ConstantInt::get(I32, 32)),
"thread_idx");
gepArgs.push_back(thread_idx);
return builder.CreateStore(instruction, builder.CreateGEP(alloca, gepArgs));
}
llvm::Instruction *AddContextRestore(llvm::Value *val,
llvm::Instruction *alloca,
llvm::Instruction *before, bool isAlloca,
bool intra_warp_loop) {
assert(val != NULL);
assert(alloca != NULL);
IRBuilder<> builder(alloca);
if (before != NULL) {
builder.SetInsertPoint(before);
} else if (isa<Instruction>(val)) {
builder.SetInsertPoint(dyn_cast<Instruction>(val));
before = dyn_cast<Instruction>(val);
} else {
assert(false && "Unknown context restore location!");
}
std::vector<llvm::Value *> gepArgs;
auto M = before->getParent()->getParent()->getParent();
auto I32 = llvm::Type::getInt32Ty(M->getContext());
auto inter_warp_index =
builder.CreateLoad(M->getGlobalVariable("inter_warp_index"));
auto intra_warp_index =
builder.CreateLoad(M->getGlobalVariable("intra_warp_index"));
auto thread_idx = builder.CreateBinOp(
Instruction::Add, intra_warp_index,
builder.CreateBinOp(Instruction::Mul, inter_warp_index,
ConstantInt::get(I32, 32)),
"thread_idx");
gepArgs.push_back(thread_idx);
llvm::Instruction *gep =
dyn_cast<Instruction>(builder.CreateGEP(alloca, gepArgs));
if (isAlloca) {
return gep;
}
return builder.CreateLoad(gep);
}
void AddContextSaveRestore(llvm::Instruction *instruction,
bool intra_warp_loop) {
/* Allocate the context data array for the variable. */
llvm::Instruction *alloca = GetContextArray(instruction, intra_warp_loop);
llvm::Instruction *theStore =
AddContextSave(instruction, alloca, intra_warp_loop);
std::vector<Instruction *> uses;
for (Instruction::use_iterator ui = instruction->use_begin(),
ue = instruction->use_end();
ui != ue; ++ui) {
llvm::Instruction *user = cast<Instruction>(ui->getUser());
if (user == NULL)
continue;
if (user == theStore)
continue;
uses.push_back(user);
}
for (auto user : uses) {
Instruction *contextRestoreLocation = user;
llvm::Value *loadedValue =
AddContextRestore(user, alloca, contextRestoreLocation,
isa<AllocaInst>(instruction), intra_warp_loop);
user->replaceUsesOfWith(instruction, loadedValue);
}
}
void handle_alloc(llvm::Function *F) {
auto M = F->getParent();
LLVMContext &C = M->getContext();
auto I32 = llvm::Type::getInt32Ty(C);
std::vector<llvm::Instruction *> instruction_to_fix;
for (auto bb = F->begin(); bb != F->end(); bb++) {
for (auto ii = bb->begin(); ii != bb->end(); ii++) {
if (llvm::AllocaInst *i = dyn_cast<AllocaInst>(ii)) {
instruction_to_fix.push_back(i);
}
}
}
std::vector<llvm::Instruction *> need_remove;
for (auto inst : instruction_to_fix) {
// generate a new alloc
auto block_size_addr = M->getGlobalVariable("block_size");
IRBuilder<> builder(inst);
auto block_size = builder.CreateLoad(block_size_addr);
llvm::Type *elementType = NULL;
if (dyn_cast<AllocaInst>(inst)->getType()->getElementType()) {
elementType = dyn_cast<AllocaInst>(inst)->getType()->getElementType();
}
assert(elementType != NULL);
auto Alloca = builder.CreateAlloca(elementType, block_size,
inst->getName().str() + "inter_warp");
// replace all usage
std::set<Instruction *> replace_user;
for (Instruction::use_iterator ui = inst->use_begin(), ue = inst->use_end();
ui != ue; ++ui) {
replace_user.insert(dyn_cast<Instruction>(ui->getUser()));
}
for (auto user : replace_user) {
IRBuilder<> builder(user);
// std::vector<llvm::Value *> gepArgs;
auto inter_warp_index =
builder.CreateLoad(M->getGlobalVariable("inter_warp_index"));
auto intra_warp_index =
builder.CreateLoad(M->getGlobalVariable("intra_warp_index"));
auto thread_idx = builder.CreateBinOp(
Instruction::Add, intra_warp_index,
builder.CreateBinOp(Instruction::Mul, inter_warp_index,
ConstantInt::get(I32, 32)),
"thread_idx");
auto gep = builder.CreateGEP(Alloca, thread_idx);
user->replaceUsesOfWith(inst, gep);
}
need_remove.push_back(inst);
}
for (auto inst : need_remove) {
inst->dropAllReferences();
inst->eraseFromParent();
}
}
void handle_local_variable_intra_warp(std::vector<ParallelRegion> PRs) {
bool intra_warp_loop = 1;
// we should handle allocation generated by PHI
{
std::vector<llvm::Instruction *> instruction_to_fix;
auto F = PRs[0].start_block->getParent();
for (auto bb = F->begin(); bb != F->end(); bb++) {
for (auto ii = bb->begin(); ii != bb->end(); ii++) {
if (isa<AllocaInst>(&(*ii)))
instruction_to_fix.push_back(&(*ii));
}
for (auto inst : instruction_to_fix) {
AddContextSaveRestore(inst, intra_warp_loop);
}
}
}
for (auto parallel_regions : PRs) {
std::set<llvm::Instruction *> instruction_in_region;
std::vector<llvm::Instruction *> instruction_to_fix;
for (auto bb : parallel_regions.wrapped_block) {
for (llvm::BasicBlock::iterator instr = bb->begin(); instr != bb->end();
++instr) {
llvm::Instruction *instruction = &*instr;
instruction_in_region.insert(instruction);
}
}
/* Find all the instructions that define new values and
check if they need to be context saved. */
for (auto bb : parallel_regions.wrapped_block) {
for (llvm::BasicBlock::iterator instr = bb->begin(); instr != bb->end();
++instr) {
llvm::Instruction *instruction = &*instr;
if (ShouldNotBeContextSaved(instruction))
continue;
for (Instruction::use_iterator ui = instruction->use_begin(),
ue = instruction->use_end();
ui != ue; ++ui) {
llvm::Instruction *user = dyn_cast<Instruction>(ui->getUser());
if (user == NULL)
continue;
if (isa<AllocaInst>(instruction) ||
(instruction_in_region.find(user) ==
instruction_in_region.end())) {
instruction_to_fix.push_back(instruction);
break;
}
}
}
}
for (auto inst : instruction_to_fix) {
AddContextSaveRestore(inst, intra_warp_loop);
}
}
}
BasicBlock *insert_loop_init(llvm::BasicBlock *InsertInitBefore,
bool intra_warp_loop) {
llvm::Module *M = InsertInitBefore->getParent()->getParent();
LLVMContext &context = M->getContext();
auto I32 = llvm::Type::getInt32Ty(context);
std::string block_name =
(intra_warp_loop) ? "intra_warp_init" : "inter_warp_init";
BasicBlock *loop_init = BasicBlock::Create(
context, block_name, InsertInitBefore->getParent(), InsertInitBefore);
IRBuilder<> builder(context);
builder.SetInsertPoint(loop_init);
if (intra_warp_loop) { // intra warp
auto intra_warp_index = M->getGlobalVariable("intra_warp_index");
builder.CreateStore(ConstantInt::get(I32, 0), intra_warp_index);
} else { // inter warp
auto inter_warp_index = M->getGlobalVariable("inter_warp_index");
builder.CreateStore(ConstantInt::get(I32, 0), inter_warp_index);
}
builder.CreateBr(InsertInitBefore);
return loop_init;
}
BasicBlock *insert_loop_cond(llvm::BasicBlock *InsertCondBefore,
llvm::BasicBlock *LoopEnd, bool intra_warp_loop) {
llvm::Module *M = InsertCondBefore->getParent()->getParent();
LLVMContext &context = M->getContext();
auto I32 = llvm::Type::getInt32Ty(context);
std::string block_name =
(intra_warp_loop) ? "intra_warp_cond" : "inter_warp_cond";
BasicBlock *loop_cond = BasicBlock::Create(
context, block_name, InsertCondBefore->getParent(), InsertCondBefore);
IRBuilder<> builder(context);
builder.SetInsertPoint(loop_cond);
llvm::Value *cmpResult = NULL;
if (!intra_warp_loop) {
auto inter_warp_index = M->getGlobalVariable("inter_warp_index");
auto block_size = M->getGlobalVariable("block_size");
auto warp_cnt =
builder.CreateBinOp(Instruction::SDiv, builder.CreateLoad(block_size),
ConstantInt::get(I32, 32), "warp_number");
cmpResult =
builder.CreateICmpULT(builder.CreateLoad(inter_warp_index), warp_cnt);
} else {
auto intra_warp_index = M->getGlobalVariable("intra_warp_index");
auto block_size = M->getGlobalVariable("block_size");
if (!need_nested_loop) {
cmpResult = builder.CreateICmpULT(builder.CreateLoad(intra_warp_index),
builder.CreateLoad(block_size));
} else {
cmpResult = builder.CreateICmpULT(builder.CreateLoad(intra_warp_index),
ConstantInt::get(I32, 32));
}
}
builder.CreateCondBr(cmpResult, InsertCondBefore, LoopEnd);
return loop_cond;
}
BasicBlock *insert_loop_inc(llvm::BasicBlock *InsertIncBefore,
bool intra_warp_loop) {
llvm::Module *M = InsertIncBefore->getParent()->getParent();
LLVMContext &context = M->getContext();
auto I32 = llvm::Type::getInt32Ty(context);
std::string block_name =
(intra_warp_loop) ? "intra_warp_inc" : "inter_warp_inc";
BasicBlock *loop_inc = BasicBlock::Create(
context, block_name, InsertIncBefore->getParent(), InsertIncBefore);
IRBuilder<> builder(context);
builder.SetInsertPoint(loop_inc);
if (intra_warp_loop) { // intra warp
auto intra_warp_index = M->getGlobalVariable("intra_warp_index");
auto new_index = builder.CreateBinOp(
Instruction::Add, builder.CreateLoad(intra_warp_index),
ConstantInt::get(I32, 1), "intra_warp_index_increment");
builder.CreateStore(new_index, intra_warp_index);
} else { // inter warp
auto inter_warp_index = M->getGlobalVariable("inter_warp_index");
auto new_index = builder.CreateBinOp(
Instruction::Add, builder.CreateLoad(inter_warp_index),
ConstantInt::get(I32, 1), "inter_warp_index_increment");
builder.CreateStore(new_index, inter_warp_index);
}
builder.CreateBr(InsertIncBefore);
return loop_inc;
}
void add_warp_loop(std::vector<ParallelRegion> parallel_regions,
bool intra_warp_loop) {
for (auto region : parallel_regions) {
auto start_block = region.start_block;
auto tail_block = region.end_block;
auto next_block = region.successor_block;
auto loop_cond = insert_loop_cond(start_block, next_block, intra_warp_loop);
auto loop_init = insert_loop_init(loop_cond, intra_warp_loop);
auto F = start_block->getParent();
for (Function::iterator i = F->begin(); i != F->end(); ++i) {
llvm::BasicBlock *bb = &(*i);
if (bb == loop_cond)
continue;
bb->getTerminator()->replaceUsesOfWith(start_block, loop_init);
}
auto loop_inc = insert_loop_inc(loop_cond, intra_warp_loop);
tail_block->getTerminator()->replaceUsesOfWith(next_block, loop_inc);
// we have to reset inter/intra warp index to 0, as these maybe used
// outside PR when there are conditional loop/branch
llvm::Module *M = start_block->getParent()->getParent();
LLVMContext &context = M->getContext();
auto I32 = llvm::Type::getInt32Ty(context);
BasicBlock *reset_index = BasicBlock::Create(start_block->getContext(),
"reset_block", F, next_block);
IRBuilder<> builder(start_block->getContext());
builder.SetInsertPoint(reset_index);
if (intra_warp_loop) { // intra warp
auto intra_warp_index = M->getGlobalVariable("intra_warp_index");
builder.CreateStore(ConstantInt::get(I32, 0), intra_warp_index);
} else { // inter warp
auto inter_warp_index = M->getGlobalVariable("inter_warp_index");
builder.CreateStore(ConstantInt::get(I32, 0), inter_warp_index);
}
builder.CreateBr(next_block);
loop_cond->getTerminator()->replaceUsesOfWith(next_block, reset_index);
// add metadata
MDNode *Dummy =
MDNode::getTemporary(context, ArrayRef<Metadata *>()).release();
MDNode *AccessGroupMD = MDNode::getDistinct(context, {});
MDNode *ParallelAccessMD = MDNode::get(
context,
{MDString::get(context, "llvm.loop.parallel_accesses"), AccessGroupMD});
MDNode *Root = MDNode::get(context, {Dummy, ParallelAccessMD});
Root->replaceOperandWith(0, Root);
MDNode::deleteTemporary(Dummy);
// We now have
// !1 = metadata !{metadata !1} <- self-referential root
loop_cond->getTerminator()->setMetadata("llvm.loop", Root);
for (auto bb : region.wrapped_block) {
for (BasicBlock::iterator ii = bb->begin(), ee = bb->end(); ii != ee;
ii++) {
if (!ii->mayReadOrWriteMemory()) {
continue;
}
MDNode *NewMD = MDNode::get(bb->getContext(), AccessGroupMD);
MDNode *OldMD = ii->getMetadata("llvm.mem.parallel_loop_access");
if (OldMD != nullptr) {
NewMD = llvm::MDNode::concatenate(OldMD, NewMD);
}
ii->setMetadata("llvm.mem.parallel_loop_access", NewMD);
}
}
}
}
void print_parallel_region(std::vector<ParallelRegion> parallel_regions) {
printf("get PR:\n");
for (auto region : parallel_regions) {
auto start = region.start_block;
auto end = region.end_block;
auto next = region.successor_block;
printf("parallel region: %s->%s next: %s\n", start->getName().str().c_str(),
end->getName().str().c_str(), next->getName().str().c_str());
printf("have: \n");
for (auto b : region.wrapped_block) {
printf("%s\n", b->getName().str().c_str());
}
}
}
void remove_barrier(llvm::Function *F, bool intra_warp_loop) {
std::vector<Instruction *> need_remove;
for (auto BB = F->begin(); BB != F->end(); ++BB) {
for (auto BI = BB->begin(); BI != BB->end(); BI++) {
if (auto Call = dyn_cast<CallInst>(BI)) {
auto func_name = Call->getCalledFunction()->getName().str();
if (func_name == "llvm.nvvm.bar.warp.sync") {
need_remove.push_back(Call);
}
if (!intra_warp_loop && (func_name == "llvm.nvvm.barrier0" ||
func_name == "llvm.nvvm.barrier.sync")) {
need_remove.push_back(Call);
}
}
}
}
for (auto inst : need_remove) {
inst->eraseFromParent();
}
}
class InsertWarpLoopPass : public llvm::FunctionPass {
public:
static char ID;
bool intra_warp_loop;
DominatorTree *DT;
PostDominatorTree *PDT;
InsertWarpLoopPass(bool intra_warp = 0)
: FunctionPass(ID), intra_warp_loop(intra_warp) {}
virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<PostDominatorTreeWrapperPass>();
}
void getParallelRegionBefore(llvm::BasicBlock *B, bool intra_warp_loop,
std::vector<ParallelRegion> &parallel_regions) {
ParallelRegion current_region;
SmallVector<BasicBlock *, 4> pending_blocks;
BasicBlock *region_entry_barrier = NULL;
BasicBlock *entry = NULL;
BasicBlock *exit = B->getSinglePredecessor();
for (BasicBlock *Pred : predecessors(B)) {
pending_blocks.push_back(Pred);
}
if (pending_blocks.size() > 1) {
// becuase we have insert the sync and split by them,
// so if B has several income edges, it must be a merge point
// for a conditional if. We can safely ignore it
// TODO: we have to further check whether this conditional if
// is for inter warp or intra warp
return;
}
while (!pending_blocks.empty()) {
BasicBlock *current = pending_blocks.back();
pending_blocks.pop_back();
// avoid infinite recursion of loops
if (current_region.wrapped_block.count(current) != 0) {
continue;
}
// If we reach another barrier this must be the
// parallel region entry.
bool has_barrier = 0;
for (auto i = current->begin(), e = current->end(); i != e; ++i) {
if (llvm::CallInst *call_inst = llvm::dyn_cast<llvm::CallInst>(&(*i))) {
auto func_name = call_inst->getCalledFunction()->getName().str();
if (func_name == "llvm.nvvm.barrier0" ||
func_name == "llvm.nvvm.barrier.sync")
has_barrier = 1;
if (func_name == "llvm.nvvm.bar.warp.sync" && intra_warp_loop)
has_barrier = 1;
}
}
// if we reach a block which only has a single condtional branch,
// it is the start point of a B-condition, we have to stop here
bool is_single_conditional_branch_block = 0;
if (auto br = dyn_cast<llvm::BranchInst>(current->getTerminator())) {
if (br->isConditional()) {
if (current->size() == 1) {
is_single_conditional_branch_block = 1;
} else {
// generate by replicate local variable
printf(
"[WARNING] match single conditional branch with HARD CODE\n");
bool branch_to_intra_init = false;
for (unsigned suc = 0; suc < br->getNumSuccessors(); ++suc) {
llvm::BasicBlock *entryCandidate = br->getSuccessor(suc);
auto block_name = entryCandidate->getName().str();
if (find_block_barrier_in_region(current, B)) {
if (block_name.find("warp_init") != block_name.npos) {
is_single_conditional_branch_block = 1;
break;
}
}
}
}
}
}
if (has_barrier || is_single_conditional_branch_block) {
if (region_entry_barrier == NULL)
region_entry_barrier = current;
else if (region_entry_barrier != current) {
// this means there is not PR before B, just return
return;
}
continue;
}
// Non-barrier block, this must be on the region.
current_region.wrapped_block.insert(current);
// Add predecessors to pending queue.
for (BasicBlock *Pred : predecessors(current)) {
pending_blocks.push_back(Pred);
}
}
if (current_region.wrapped_block.empty()) {
return;
}
// if do not find entry node, this means all predecessor
// blocks do not need to execute multiply times
if (region_entry_barrier == NULL) {
return;
}
// Find the entry node.
assert(region_entry_barrier != NULL);
for (unsigned
suc = 0,
num = region_entry_barrier->getTerminator()->getNumSuccessors();
suc < num; ++suc) {
llvm::BasicBlock *entryCandidate =
region_entry_barrier->getTerminator()->getSuccessor(suc);
if (current_region.wrapped_block.count(entryCandidate) == 0)
continue;
entry = entryCandidate;
break;
}
// delete useless PR, those PRs only have branch
if (entry == exit) {
if (entry->size() == 1 && isa<llvm::BranchInst>(entry->begin())) {
return;
}
}
bool is_useless = true;
auto iter = entry;
do {
if (iter->size() != 1 || !isa<llvm::BranchInst>(entry->begin())) {
is_useless = false;
break;
}
if (iter->getTerminator()->getNumSuccessors() > 1) {
is_useless = false;
break;
}
iter = iter->getTerminator()->getSuccessor(0);
} while (iter != exit);
if (is_useless) {
return;
}
assert(current_region.wrapped_block.count(entry) != 0);
current_region.start_block = entry;
current_region.end_block = exit;
current_region.successor_block = B;
parallel_regions.push_back(current_region);
}
std::vector<ParallelRegion> getParallelRegions(llvm::Function *F,
bool intra_warp_loop) {
std::vector<ParallelRegion> parallel_regions;
SmallVector<BasicBlock *, 4> exit_blocks;
for (Function::iterator s = F->begin(); s != F->end(); s++) {
if (llvm::CallInst *call_inst =
llvm::dyn_cast<llvm::CallInst>(s->begin())) {
auto func_name = call_inst->getCalledFunction()->getName().str();
if (func_name == "llvm.nvvm.barrier0" ||
func_name == "llvm.nvvm.barrier.sync") {
exit_blocks.push_back(&(*s));
}
// when handling intra warp loop, we need also split the blocks
// between warp barrier
if (intra_warp_loop && func_name == "llvm.nvvm.bar.warp.sync") {
exit_blocks.push_back(&(*s));
}
}
}
// First find all the ParallelRegions in the Function.
while (!exit_blocks.empty()) {
BasicBlock *exit = exit_blocks.back();
exit_blocks.pop_back();
getParallelRegionBefore(exit, intra_warp_loop, parallel_regions);
}
return parallel_regions;
}
virtual bool runOnFunction(Function &F) {
if (!isKernelFunction(F.getParent(), &F))
return 0;
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
// find parallel region we need to wrap
auto parallel_regions = getParallelRegions(&F, intra_warp_loop);
assert(!parallel_regions.empty() && "can not find any parallel regions\n");
// print_parallel_region(parallel_regions);
add_warp_loop(parallel_regions, intra_warp_loop);
if (intra_warp_loop) {
handle_local_variable_intra_warp(parallel_regions);
}
remove_barrier(&F, intra_warp_loop);
return 1;
}
};
char InsertWarpLoopPass::ID = 0;
namespace {
static RegisterPass<InsertWarpLoopPass> X("insert-warp-loop",
"Insert inter/intra warp loop");
} // namespace
bool has_warp_barrier(llvm::Module *M) {
for (auto F = M->begin(); F != M->end(); ++F)
for (auto BB = F->begin(); BB != F->end(); ++BB) {
for (auto BI = BB->begin(); BI != BB->end(); BI++) {
if (auto Call = dyn_cast<CallInst>(BI)) {
auto func_name = Call->getCalledFunction()->getName().str();
if (func_name == "llvm.nvvm.bar.warp.sync") {
return true;
}
}
}
}
return false;
}
void insert_warp_loop(llvm::Module *M) {
llvm::legacy::PassManager Passes;
need_nested_loop = has_warp_barrier(M);
// use nested loop only when there are warp-level barrier
if (need_nested_loop) {
bool intra_warp = true;
Passes.add(new InsertWarpLoopPass(intra_warp));
// insert inter warp loop
Passes.add(new InsertWarpLoopPass(!intra_warp));
Passes.run(*M);
} else {
bool intra_warp = true;
// only need a single loop, with size=block_size
Passes.add(new InsertWarpLoopPass(intra_warp));
Passes.run(*M);
// remove all barriers
for (auto F = M->begin(); F != M->end(); ++F)
remove_barrier(dyn_cast<llvm::Function>(F), false);
}
}