update compilation with DEBUG mode

This commit is contained in:
Ruobing Han 2022-09-15 12:33:28 -04:00
parent 9152feb24f
commit bb3724c486
10 changed files with 49 additions and 88 deletions

View File

@ -33,6 +33,11 @@ else()
message(FATAL_ERROR "llvm-config is required")
endif()
option(DEBUG "Print debug information." OFF)
if(DEBUG)
add_definitions(-DDEBUG)
endif()
set(CMAKE_CXX_FLAGS "${LLVM_CXX_FLAG} ${CMAKE_CXX_FLAGS}")
set(GCC_COVERAGE_LINK_FLAGS

View File

@ -33,33 +33,23 @@ int main(int argc, char **argv) {
handle_warp_vote(program);
// replace warp shuffle
// VerifyModule(program);
handle_warp_shfl(program);
// insert sync
// VerifyModule(program);
insert_sync(program);
// split block by sync
// VerifyModule(program);
std::cout << "split\n" << std::flush;
split_block_by_sync(program);
// add loop for intra&intera thread
// VerifyModule(program);
std::cout << "insert\n" << std::flush;
insert_warp_loop(program);
// VerifyModule(program);
// (TODO): replace this patch
std::cout << "replace\n" << std::flush;
replace_built_in_function(program);
// VerifyModule(program);
std::cout << "generate\n" << std::flush;
// TODO: replace with a more general function
// Not only for x86 backend
generate_x86_format(program);
// VerifyModule(program);
// performance optimization
performance_optimization(program);
@ -68,6 +58,5 @@ int main(int argc, char **argv) {
DumpModule(program, argv[2]);
fout.close();
return 0;
}

View File

@ -4,6 +4,13 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#ifdef DEBUG
#define DEBUG_INFO(...) fprintf(stderr, __VA_ARGS__)
#else
#define DEBUG_INFO(...)
#endif // DEBUG
llvm::Module *LoadModuleFromFilr(char *file_name);
void DumpModule(llvm::Module *M, char *file_name);
bool isKernelFunction(llvm::Module *M, llvm::Function *F);

View File

@ -111,11 +111,8 @@ void decode_input(llvm::Module *M) {
auto PT = dyn_cast<PointerType>(share_memory->getType());
auto element_type = PT->getElementType();
// std::cout << element_type->getTypeID() << " Got global memor $$$$$$"
// << share_memory->getName().str() << std::endl;
AllocaInst *new_arr = Builder.CreateAlloca(Int8T, loadedValue, "new_arr");
// new_arr->setAlignment(llvm::MaybeAlign(16));
Value *new_ar = new_arr;
Value *gptr = Builder.CreateBitOrPointerCast(
share_memory, PointerType::get(PointerType::get(Int8T, 0), 0));
@ -176,6 +173,7 @@ void remove_useless_var(llvm::Module *M) {
}
void generate_x86_format(llvm::Module *M) {
DEBUG_INFO("generate x86 format\n");
// change metadata
set_meta_data(M);
// decode argument

View File

@ -51,6 +51,7 @@ void split_block_by_sync(llvm::Function *F) {
}
void split_block_by_sync(llvm::Module *M) {
DEBUG_INFO("split block by sync\n");
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
Function *F = &(*i);
if (isKernelFunction(M, F))

View File

@ -74,6 +74,7 @@ bool find_sreg_inst(llvm::Function *F) {
}
return false;
}
bool inline_func_with_tid(llvm::Module *M) {
bool changed = false;
std::set<llvm::Function *> need_remove;
@ -87,7 +88,7 @@ bool inline_func_with_tid(llvm::Module *M) {
if (CallInst *c = dyn_cast<CallInst>(BI++)) {
if (c->getCalledFunction()) {
if (find_sreg_inst(c->getCalledFunction())) {
printf("inline: %s\n",
DEBUG_INFO("inline: %s\n",
c->getCalledFunction()->getName().str().c_str());
need_inline.insert(c);
need_remove.insert(c->getCalledFunction());
@ -276,7 +277,7 @@ void llvm_preprocess(llvm::Module *M) {
Pass *thispass = PIs->createPass();
Passes.add(thispass);
} else {
printf("Pass: %s not found\n", pass.c_str());
DEBUG_INFO("Pass: %s not found\n", pass.c_str());
}
}
Passes.run(*M);
@ -334,8 +335,6 @@ bool lower_constant_expr(llvm::Module *M) {
auto get_from = get_element_ptr->getOperand(0);
if (auto addr_cast = dyn_cast<llvm::ConstantExpr>(get_from)) {
modified = true;
// auto ReplInst = addr_cast->getAsInstruction();
// ReplInst->insertBefore(get_element_ptr);
std::vector<Instruction *> Users;
// Do not replace use during iteration of use. Do it in another loop
for (auto U : addr_cast->users()) {
@ -374,6 +373,7 @@ void replace_cuda_math_built_in(llvm::Module *M) {
}
void init_block(llvm::Module *M, std::ofstream &fout) {
DEBUG_INFO("init block\n");
// using official llvm preprocess
llvm_preprocess(M);
// remove useles Cuda function

View File

@ -245,17 +245,14 @@ public:
PDT->getPostDomTree().dominates(merge_point, curr)) {
// we should insert barrier at the beginning and
// end of its predecessor
printf("insert [255]: %s\n", curr->getName().str().c_str());
if (has_warp_barrier(b)) {
CreateIntraWarpBarrier(&(*curr->begin()));
for (BasicBlock *Pred : predecessors(curr)) {
printf("insert [262]: %s\n", Pred->getName().str().c_str());
CreateIntraWarpBarrier(&(*Pred->getTerminator()));
}
} else {
CreateInterWarpBarrier(&(*curr->begin()));
for (BasicBlock *Pred : predecessors(curr)) {
printf("insert [268]: %s\n", Pred->getName().str().c_str());
CreateInterWarpBarrier(&(*Pred->getTerminator()));
}
}
@ -342,7 +339,7 @@ public:
BasicBlock *merge_point = find_merge_point(head, PDT->getPostDomTree());
assert(PDT->getPostDomTree().dominates(merge_point, head));
if (!find_barrier_in_region(head, merge_point)) {
printf("do not need to handle tri-income if: %s\n",
DEBUG_INFO("do not need to handle tri-income if: %s\n",
merge_point->getName().str().c_str());
continue;
}
@ -465,7 +462,7 @@ public:
}
} else {
// handle break in for-loop
printf("loop has multiply exists\n");
DEBUG_INFO("loop has multiply exists\n");
// this time, we have also insert sync before the for-body
auto header_block = L->getHeader();
assert(header_block->getTerminator()->getNumSuccessors() == 2 &&
@ -524,7 +521,12 @@ static RegisterPass<InsertBuiltInBarrier>
"Insert built in barriers");
} // namespace
/*
This function inserts implicit synchronization for conditional statements,
please refer to https://dl.acm.org/doi/abs/10.1145/3554736 for detail
*/
void insert_sync(llvm::Module *M) {
DEBUG_INFO("insert sync\n");
auto Registry = PassRegistry::getPassRegistry();
llvm::legacy::PassManager Passes;

View File

@ -72,10 +72,6 @@ int need_nested_loop;
bool ShouldNotBeContextSaved(llvm::Instruction *instr) {
if (isa<BranchInst>(instr))
return true;
// if (isa<AddrSpaceCastInst>(instr))
// return true;
// if (isa<CastInst>(instr))
// return true;
llvm::Module *M = instr->getParent()->getParent()->getParent();
llvm::LoadInst *load = dyn_cast<llvm::LoadInst>(instr);
@ -134,47 +130,6 @@ llvm::Instruction *GetContextArray(llvm::Instruction *instruction,
Type *AllocType = elementType;
AllocaInst *InstCast = dyn_cast<AllocaInst>(instruction);
/*
if (InstCast) {
unsigned Alignment = InstCast->getAlignment();
uint64_t StoreSize = Layout.getTypeStoreSize(InstCast->getAllocatedType());
if ((Alignment > 1) && (StoreSize & (Alignment - 1))) {
uint64_t AlignedSize = (StoreSize & (~(Alignment - 1))) + Alignment;
assert(AlignedSize > StoreSize);
uint64_t RequiredExtraBytes = AlignedSize - StoreSize;
if (isa<ArrayType>(elementType)) {
ArrayType *StructPadding = ArrayType::get(
Type::getInt8Ty(M->getContext()), RequiredExtraBytes);
std::vector<Type *> PaddedStructElements;
PaddedStructElements.push_back(elementType);
PaddedStructElements.push_back(StructPadding);
const ArrayRef<Type *> NewStructElements(PaddedStructElements);
AllocType = StructType::get(M->getContext(), NewStructElements, true);
uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType);
assert(NewStoreSize == AlignedSize);
} else if (isa<StructType>(elementType)) {
StructType *OldStruct = dyn_cast<StructType>(elementType);
ArrayType *StructPadding = ArrayType::get(
Type::getInt8Ty(M->getContext()), RequiredExtraBytes);
std::vector<Type *> PaddedStructElements;
for (unsigned j = 0; j < OldStruct->getNumElements(); j++)
PaddedStructElements.push_back(OldStruct->getElementType(j));
PaddedStructElements.push_back(StructPadding);
const ArrayRef<Type *> NewStructElements(PaddedStructElements);
AllocType = StructType::get(OldStruct->getContext(), NewStructElements,
OldStruct->isPacked());
uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType);
assert(NewStoreSize == AlignedSize);
}
}
}
*/
llvm::Value *ItemSize = nullptr;
llvm::AllocaInst *Alloca = nullptr;
@ -336,7 +291,6 @@ void handle_alloc(llvm::Function *F) {
for (auto user : replace_user) {
IRBuilder<> builder(user);
// std::vector<llvm::Value *> gepArgs;
auto inter_warp_index =
createLoad(builder, M->getGlobalVariable("inter_warp_index"));
auto intra_warp_index =
@ -597,16 +551,17 @@ void add_warp_loop(std::vector<ParallelRegion> parallel_regions,
}
void print_parallel_region(std::vector<ParallelRegion> parallel_regions) {
printf("get PR:\n");
DEBUG_INFO("get PR:\n");
for (auto region : parallel_regions) {
auto start = region.start_block;
auto end = region.end_block;
auto next = region.successor_block;
printf("parallel region: %s->%s next: %s\n", start->getName().str().c_str(),
end->getName().str().c_str(), next->getName().str().c_str());
printf("have: \n");
DEBUG_INFO("parallel region: %s->%s next: %s\n",
start->getName().str().c_str(), end->getName().str().c_str(),
next->getName().str().c_str());
DEBUG_INFO("have: \n");
for (auto b : region.wrapped_block) {
printf("%s\n", b->getName().str().c_str());
DEBUG_INFO("%s\n", b->getName().str().c_str());
}
}
}
@ -839,7 +794,9 @@ public:
// find parallel region we need to wrap
auto parallel_regions = getParallelRegions(&F, intra_warp_loop);
assert(!parallel_regions.empty() && "can not find any parallel regions\n");
// print_parallel_region(parallel_regions);
#ifdef DEBUG
print_parallel_region(parallel_regions);
#endif
if (intra_warp_loop) {
handle_local_variable_intra_warp(parallel_regions);
@ -874,7 +831,12 @@ bool has_warp_barrier(llvm::Module *M) {
return false;
}
/*
This function wrap the ParallelRegion with inter/intra warp loops,
please refer to https://dl.acm.org/doi/abs/10.1145/3554736 for detail.
*/
void insert_warp_loop(llvm::Module *M) {
DEBUG_INFO("insert warp loop\n");
llvm::legacy::PassManager Passes;
need_nested_loop = has_warp_barrier(M);
// use nested loop only when there are warp-level barrier

View File

@ -1,4 +1,5 @@
#include "performance.h"
#include "tool.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
@ -40,6 +41,7 @@
using namespace llvm;
void performance_optimization(llvm::Module *M) {
DEBUG_INFO("performance optimization\n");
for (auto F = M->begin(); F != M->end(); F++) {
for (auto I = F->arg_begin(); I != F->arg_end(); ++I) {
if (I->getType()->isPointerTy()) {
@ -54,10 +56,7 @@ void performance_optimization(llvm::Module *M) {
std::string Error;
const Target *TheTarget = TargetRegistry::lookupTarget("", triple, Error);
if (!TheTarget) {
printf("Error: %s\n", Error.c_str());
assert(0);
}
assert(TheTarget && "No Target Information\n");
llvm::TargetOptions Options;
Options.FloatABIType = FloatABI::Hard;

View File

@ -46,7 +46,6 @@ void VerifyModule(llvm::Module *program) {
}
void DumpModule(llvm::Module *M, char *file_name) {
// modify the program, add a wrapper
std::string msg;
llvm::raw_string_ostream os(msg);
std::error_code EC;
@ -541,8 +540,7 @@ void replace_asm_call(llvm::Module *M) {
if (Call->isInlineAsm()) {
auto asm_inst = dyn_cast<InlineAsm>(Call->getCalledOperand());
if (asm_inst->getAsmString() != "mov.u32 $0, %laneid;") {
printf("unknown InlineAsm\n");
exit(1);
assert(0 && "unknown InlineAsm\n");
}
// return the rank within the warp
IRBuilder<> builder(context);