diff --git a/CMakeLists.txt b/CMakeLists.txt index f1a9d56..325d894 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,11 @@ else() message(FATAL_ERROR "llvm-config is required") endif() +option(DEBUG "Print debug information." OFF) +if(DEBUG) + add_definitions(-DDEBUG) +endif() + set(CMAKE_CXX_FLAGS "${LLVM_CXX_FLAG} ${CMAKE_CXX_FLAGS}") set(GCC_COVERAGE_LINK_FLAGS diff --git a/compilation/KernelTranslation.cpp b/compilation/KernelTranslation.cpp index 77cc494..2dd318b 100644 --- a/compilation/KernelTranslation.cpp +++ b/compilation/KernelTranslation.cpp @@ -33,33 +33,23 @@ int main(int argc, char **argv) { handle_warp_vote(program); // replace warp shuffle - // VerifyModule(program); handle_warp_shfl(program); + // insert sync - // VerifyModule(program); insert_sync(program); + // split block by sync - // VerifyModule(program); - std::cout << "split\n" << std::flush; split_block_by_sync(program); // add loop for intra&intera thread - - // VerifyModule(program); - std::cout << "insert\n" << std::flush; insert_warp_loop(program); - // VerifyModule(program); - // (TODO): replace this patch - std::cout << "replace\n" << std::flush; replace_built_in_function(program); - // VerifyModule(program); - std::cout << "generate\n" << std::flush; + // TODO: replace with a more general function + // Not only for x86 backend generate_x86_format(program); - // VerifyModule(program); - // performance optimization performance_optimization(program); @@ -68,6 +58,5 @@ int main(int argc, char **argv) { DumpModule(program, argv[2]); fout.close(); - return 0; } diff --git a/compilation/KernelTranslation/include/x86/tool.h b/compilation/KernelTranslation/include/x86/tool.h index c4538f1..7e5947b 100644 --- a/compilation/KernelTranslation/include/x86/tool.h +++ b/compilation/KernelTranslation/include/x86/tool.h @@ -4,6 +4,13 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" + +#ifdef DEBUG +#define DEBUG_INFO(...) fprintf(stderr, __VA_ARGS__) +#else +#define DEBUG_INFO(...) +#endif // DEBUG + llvm::Module *LoadModuleFromFilr(char *file_name); void DumpModule(llvm::Module *M, char *file_name); bool isKernelFunction(llvm::Module *M, llvm::Function *F); diff --git a/compilation/KernelTranslation/src/x86/generate_x86_format.cpp b/compilation/KernelTranslation/src/x86/generate_x86_format.cpp index 9d563ff..67d23cd 100644 --- a/compilation/KernelTranslation/src/x86/generate_x86_format.cpp +++ b/compilation/KernelTranslation/src/x86/generate_x86_format.cpp @@ -111,11 +111,8 @@ void decode_input(llvm::Module *M) { auto PT = dyn_cast(share_memory->getType()); auto element_type = PT->getElementType(); - // std::cout << element_type->getTypeID() << " Got global memor $$$$$$" - // << share_memory->getName().str() << std::endl; AllocaInst *new_arr = Builder.CreateAlloca(Int8T, loadedValue, "new_arr"); - // new_arr->setAlignment(llvm::MaybeAlign(16)); Value *new_ar = new_arr; Value *gptr = Builder.CreateBitOrPointerCast( share_memory, PointerType::get(PointerType::get(Int8T, 0), 0)); @@ -176,6 +173,7 @@ void remove_useless_var(llvm::Module *M) { } void generate_x86_format(llvm::Module *M) { + DEBUG_INFO("generate x86 format\n"); // change metadata set_meta_data(M); // decode argument diff --git a/compilation/KernelTranslation/src/x86/handle_sync.cpp b/compilation/KernelTranslation/src/x86/handle_sync.cpp index 565d636..44a5876 100644 --- a/compilation/KernelTranslation/src/x86/handle_sync.cpp +++ b/compilation/KernelTranslation/src/x86/handle_sync.cpp @@ -51,6 +51,7 @@ void split_block_by_sync(llvm::Function *F) { } void split_block_by_sync(llvm::Module *M) { + DEBUG_INFO("split block by sync\n"); for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) { Function *F = &(*i); if (isKernelFunction(M, F)) diff --git a/compilation/KernelTranslation/src/x86/init.cpp b/compilation/KernelTranslation/src/x86/init.cpp index 4b80064..1feedd9 100644 --- a/compilation/KernelTranslation/src/x86/init.cpp +++ b/compilation/KernelTranslation/src/x86/init.cpp @@ -74,6 +74,7 @@ bool find_sreg_inst(llvm::Function *F) { } return false; } + bool inline_func_with_tid(llvm::Module *M) { bool changed = false; std::set need_remove; @@ -87,8 +88,8 @@ bool inline_func_with_tid(llvm::Module *M) { if (CallInst *c = dyn_cast(BI++)) { if (c->getCalledFunction()) { if (find_sreg_inst(c->getCalledFunction())) { - printf("inline: %s\n", - c->getCalledFunction()->getName().str().c_str()); + DEBUG_INFO("inline: %s\n", + c->getCalledFunction()->getName().str().c_str()); need_inline.insert(c); need_remove.insert(c->getCalledFunction()); } @@ -276,7 +277,7 @@ void llvm_preprocess(llvm::Module *M) { Pass *thispass = PIs->createPass(); Passes.add(thispass); } else { - printf("Pass: %s not found\n", pass.c_str()); + DEBUG_INFO("Pass: %s not found\n", pass.c_str()); } } Passes.run(*M); @@ -334,8 +335,6 @@ bool lower_constant_expr(llvm::Module *M) { auto get_from = get_element_ptr->getOperand(0); if (auto addr_cast = dyn_cast(get_from)) { modified = true; - // auto ReplInst = addr_cast->getAsInstruction(); - // ReplInst->insertBefore(get_element_ptr); std::vector Users; // Do not replace use during iteration of use. Do it in another loop for (auto U : addr_cast->users()) { @@ -374,6 +373,7 @@ void replace_cuda_math_built_in(llvm::Module *M) { } void init_block(llvm::Module *M, std::ofstream &fout) { + DEBUG_INFO("init block\n"); // using official llvm preprocess llvm_preprocess(M); // remove useles Cuda function diff --git a/compilation/KernelTranslation/src/x86/insert_sync.cpp b/compilation/KernelTranslation/src/x86/insert_sync.cpp index f7483fe..f04fd5a 100644 --- a/compilation/KernelTranslation/src/x86/insert_sync.cpp +++ b/compilation/KernelTranslation/src/x86/insert_sync.cpp @@ -245,17 +245,14 @@ public: PDT->getPostDomTree().dominates(merge_point, curr)) { // we should insert barrier at the beginning and // end of its predecessor - printf("insert [255]: %s\n", curr->getName().str().c_str()); if (has_warp_barrier(b)) { CreateIntraWarpBarrier(&(*curr->begin())); for (BasicBlock *Pred : predecessors(curr)) { - printf("insert [262]: %s\n", Pred->getName().str().c_str()); CreateIntraWarpBarrier(&(*Pred->getTerminator())); } } else { CreateInterWarpBarrier(&(*curr->begin())); for (BasicBlock *Pred : predecessors(curr)) { - printf("insert [268]: %s\n", Pred->getName().str().c_str()); CreateInterWarpBarrier(&(*Pred->getTerminator())); } } @@ -342,8 +339,8 @@ public: BasicBlock *merge_point = find_merge_point(head, PDT->getPostDomTree()); assert(PDT->getPostDomTree().dominates(merge_point, head)); if (!find_barrier_in_region(head, merge_point)) { - printf("do not need to handle tri-income if: %s\n", - merge_point->getName().str().c_str()); + DEBUG_INFO("do not need to handle tri-income if: %s\n", + merge_point->getName().str().c_str()); continue; } @@ -465,7 +462,7 @@ public: } } else { // handle break in for-loop - printf("loop has multiply exists\n"); + DEBUG_INFO("loop has multiply exists\n"); // this time, we have also insert sync before the for-body auto header_block = L->getHeader(); assert(header_block->getTerminator()->getNumSuccessors() == 2 && @@ -524,7 +521,12 @@ static RegisterPass "Insert built in barriers"); } // namespace +/* +This function inserts implicit synchronization for conditional statements, +please refer to https://dl.acm.org/doi/abs/10.1145/3554736 for detail +*/ void insert_sync(llvm::Module *M) { + DEBUG_INFO("insert sync\n"); auto Registry = PassRegistry::getPassRegistry(); llvm::legacy::PassManager Passes; diff --git a/compilation/KernelTranslation/src/x86/insert_warp_loop.cpp b/compilation/KernelTranslation/src/x86/insert_warp_loop.cpp index 549ac31..6d3b997 100644 --- a/compilation/KernelTranslation/src/x86/insert_warp_loop.cpp +++ b/compilation/KernelTranslation/src/x86/insert_warp_loop.cpp @@ -72,10 +72,6 @@ int need_nested_loop; bool ShouldNotBeContextSaved(llvm::Instruction *instr) { if (isa(instr)) return true; - // if (isa(instr)) - // return true; - // if (isa(instr)) - // return true; llvm::Module *M = instr->getParent()->getParent()->getParent(); llvm::LoadInst *load = dyn_cast(instr); @@ -134,47 +130,6 @@ llvm::Instruction *GetContextArray(llvm::Instruction *instruction, Type *AllocType = elementType; AllocaInst *InstCast = dyn_cast(instruction); - /* - if (InstCast) { - unsigned Alignment = InstCast->getAlignment(); - - uint64_t StoreSize = Layout.getTypeStoreSize(InstCast->getAllocatedType()); - - if ((Alignment > 1) && (StoreSize & (Alignment - 1))) { - uint64_t AlignedSize = (StoreSize & (~(Alignment - 1))) + Alignment; - assert(AlignedSize > StoreSize); - uint64_t RequiredExtraBytes = AlignedSize - StoreSize; - - if (isa(elementType)) { - - ArrayType *StructPadding = ArrayType::get( - Type::getInt8Ty(M->getContext()), RequiredExtraBytes); - - std::vector PaddedStructElements; - PaddedStructElements.push_back(elementType); - PaddedStructElements.push_back(StructPadding); - const ArrayRef NewStructElements(PaddedStructElements); - AllocType = StructType::get(M->getContext(), NewStructElements, true); - uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType); - assert(NewStoreSize == AlignedSize); - } else if (isa(elementType)) { - StructType *OldStruct = dyn_cast(elementType); - - ArrayType *StructPadding = ArrayType::get( - Type::getInt8Ty(M->getContext()), RequiredExtraBytes); - std::vector PaddedStructElements; - for (unsigned j = 0; j < OldStruct->getNumElements(); j++) - PaddedStructElements.push_back(OldStruct->getElementType(j)); - PaddedStructElements.push_back(StructPadding); - const ArrayRef NewStructElements(PaddedStructElements); - AllocType = StructType::get(OldStruct->getContext(), NewStructElements, - OldStruct->isPacked()); - uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType); - assert(NewStoreSize == AlignedSize); - } - } - } - */ llvm::Value *ItemSize = nullptr; llvm::AllocaInst *Alloca = nullptr; @@ -336,7 +291,6 @@ void handle_alloc(llvm::Function *F) { for (auto user : replace_user) { IRBuilder<> builder(user); - // std::vector gepArgs; auto inter_warp_index = createLoad(builder, M->getGlobalVariable("inter_warp_index")); auto intra_warp_index = @@ -597,16 +551,17 @@ void add_warp_loop(std::vector parallel_regions, } void print_parallel_region(std::vector parallel_regions) { - printf("get PR:\n"); + DEBUG_INFO("get PR:\n"); for (auto region : parallel_regions) { auto start = region.start_block; auto end = region.end_block; auto next = region.successor_block; - printf("parallel region: %s->%s next: %s\n", start->getName().str().c_str(), - end->getName().str().c_str(), next->getName().str().c_str()); - printf("have: \n"); + DEBUG_INFO("parallel region: %s->%s next: %s\n", + start->getName().str().c_str(), end->getName().str().c_str(), + next->getName().str().c_str()); + DEBUG_INFO("have: \n"); for (auto b : region.wrapped_block) { - printf("%s\n", b->getName().str().c_str()); + DEBUG_INFO("%s\n", b->getName().str().c_str()); } } } @@ -839,7 +794,9 @@ public: // find parallel region we need to wrap auto parallel_regions = getParallelRegions(&F, intra_warp_loop); assert(!parallel_regions.empty() && "can not find any parallel regions\n"); - // print_parallel_region(parallel_regions); +#ifdef DEBUG + print_parallel_region(parallel_regions); +#endif if (intra_warp_loop) { handle_local_variable_intra_warp(parallel_regions); @@ -874,7 +831,12 @@ bool has_warp_barrier(llvm::Module *M) { return false; } +/* +This function wrap the ParallelRegion with inter/intra warp loops, +please refer to https://dl.acm.org/doi/abs/10.1145/3554736 for detail. +*/ void insert_warp_loop(llvm::Module *M) { + DEBUG_INFO("insert warp loop\n"); llvm::legacy::PassManager Passes; need_nested_loop = has_warp_barrier(M); // use nested loop only when there are warp-level barrier diff --git a/compilation/KernelTranslation/src/x86/performance.cpp b/compilation/KernelTranslation/src/x86/performance.cpp index e583ae1..57bba6e 100644 --- a/compilation/KernelTranslation/src/x86/performance.cpp +++ b/compilation/KernelTranslation/src/x86/performance.cpp @@ -1,4 +1,5 @@ #include "performance.h" +#include "tool.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" @@ -40,6 +41,7 @@ using namespace llvm; void performance_optimization(llvm::Module *M) { + DEBUG_INFO("performance optimization\n"); for (auto F = M->begin(); F != M->end(); F++) { for (auto I = F->arg_begin(); I != F->arg_end(); ++I) { if (I->getType()->isPointerTy()) { @@ -54,10 +56,7 @@ void performance_optimization(llvm::Module *M) { std::string Error; const Target *TheTarget = TargetRegistry::lookupTarget("", triple, Error); - if (!TheTarget) { - printf("Error: %s\n", Error.c_str()); - assert(0); - } + assert(TheTarget && "No Target Information\n"); llvm::TargetOptions Options; Options.FloatABIType = FloatABI::Hard; diff --git a/compilation/KernelTranslation/src/x86/tool.cpp b/compilation/KernelTranslation/src/x86/tool.cpp index 44f5a7b..3b52c77 100644 --- a/compilation/KernelTranslation/src/x86/tool.cpp +++ b/compilation/KernelTranslation/src/x86/tool.cpp @@ -46,7 +46,6 @@ void VerifyModule(llvm::Module *program) { } void DumpModule(llvm::Module *M, char *file_name) { - // modify the program, add a wrapper std::string msg; llvm::raw_string_ostream os(msg); std::error_code EC; @@ -541,8 +540,7 @@ void replace_asm_call(llvm::Module *M) { if (Call->isInlineAsm()) { auto asm_inst = dyn_cast(Call->getCalledOperand()); if (asm_inst->getAsmString() != "mov.u32 $0, %laneid;") { - printf("unknown InlineAsm\n"); - exit(1); + assert(0 && "unknown InlineAsm\n"); } // return the rank within the warp IRBuilder<> builder(context);