update compilation with DEBUG mode
This commit is contained in:
parent
9152feb24f
commit
bb3724c486
|
@ -33,6 +33,11 @@ else()
|
|||
message(FATAL_ERROR "llvm-config is required")
|
||||
endif()
|
||||
|
||||
option(DEBUG "Print debug information." OFF)
|
||||
if(DEBUG)
|
||||
add_definitions(-DDEBUG)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${LLVM_CXX_FLAG} ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
set(GCC_COVERAGE_LINK_FLAGS
|
||||
|
|
|
@ -33,33 +33,23 @@ int main(int argc, char **argv) {
|
|||
handle_warp_vote(program);
|
||||
|
||||
// replace warp shuffle
|
||||
// VerifyModule(program);
|
||||
handle_warp_shfl(program);
|
||||
|
||||
// insert sync
|
||||
// VerifyModule(program);
|
||||
insert_sync(program);
|
||||
|
||||
// split block by sync
|
||||
// VerifyModule(program);
|
||||
std::cout << "split\n" << std::flush;
|
||||
split_block_by_sync(program);
|
||||
// add loop for intra & inter thread
|
||||
|
||||
// VerifyModule(program);
|
||||
std::cout << "insert\n" << std::flush;
|
||||
insert_warp_loop(program);
|
||||
|
||||
// VerifyModule(program);
|
||||
|
||||
// (TODO): replace this patch
|
||||
std::cout << "replace\n" << std::flush;
|
||||
replace_built_in_function(program);
|
||||
|
||||
// VerifyModule(program);
|
||||
std::cout << "generate\n" << std::flush;
|
||||
// TODO: replace with a more general function
|
||||
// Not only for x86 backend
|
||||
generate_x86_format(program);
|
||||
|
||||
// VerifyModule(program);
|
||||
|
||||
// performance optimization
|
||||
performance_optimization(program);
|
||||
|
||||
|
@ -68,6 +58,5 @@ int main(int argc, char **argv) {
|
|||
DumpModule(program, argv[2]);
|
||||
|
||||
fout.close();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -4,6 +4,13 @@
|
|||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
||||
#ifdef DEBUG
|
||||
#define DEBUG_INFO(...) fprintf(stderr, __VA_ARGS__)
|
||||
#else
|
||||
#define DEBUG_INFO(...)
|
||||
#endif // DEBUG
|
||||
|
||||
llvm::Module *LoadModuleFromFilr(char *file_name);
|
||||
void DumpModule(llvm::Module *M, char *file_name);
|
||||
bool isKernelFunction(llvm::Module *M, llvm::Function *F);
|
||||
|
|
|
@ -111,11 +111,8 @@ void decode_input(llvm::Module *M) {
|
|||
|
||||
auto PT = dyn_cast<PointerType>(share_memory->getType());
|
||||
auto element_type = PT->getElementType();
|
||||
// std::cout << element_type->getTypeID() << " Got global memor $$$$$$"
|
||||
// << share_memory->getName().str() << std::endl;
|
||||
|
||||
AllocaInst *new_arr = Builder.CreateAlloca(Int8T, loadedValue, "new_arr");
|
||||
// new_arr->setAlignment(llvm::MaybeAlign(16));
|
||||
Value *new_ar = new_arr;
|
||||
Value *gptr = Builder.CreateBitOrPointerCast(
|
||||
share_memory, PointerType::get(PointerType::get(Int8T, 0), 0));
|
||||
|
@ -176,6 +173,7 @@ void remove_useless_var(llvm::Module *M) {
|
|||
}
|
||||
|
||||
void generate_x86_format(llvm::Module *M) {
|
||||
DEBUG_INFO("generate x86 format\n");
|
||||
// change metadata
|
||||
set_meta_data(M);
|
||||
// decode argument
|
||||
|
|
|
@ -51,6 +51,7 @@ void split_block_by_sync(llvm::Function *F) {
|
|||
}
|
||||
|
||||
void split_block_by_sync(llvm::Module *M) {
|
||||
DEBUG_INFO("split block by sync\n");
|
||||
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
|
||||
Function *F = &(*i);
|
||||
if (isKernelFunction(M, F))
|
||||
|
|
|
@ -74,6 +74,7 @@ bool find_sreg_inst(llvm::Function *F) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool inline_func_with_tid(llvm::Module *M) {
|
||||
bool changed = false;
|
||||
std::set<llvm::Function *> need_remove;
|
||||
|
@ -87,8 +88,8 @@ bool inline_func_with_tid(llvm::Module *M) {
|
|||
if (CallInst *c = dyn_cast<CallInst>(BI++)) {
|
||||
if (c->getCalledFunction()) {
|
||||
if (find_sreg_inst(c->getCalledFunction())) {
|
||||
printf("inline: %s\n",
|
||||
c->getCalledFunction()->getName().str().c_str());
|
||||
DEBUG_INFO("inline: %s\n",
|
||||
c->getCalledFunction()->getName().str().c_str());
|
||||
need_inline.insert(c);
|
||||
need_remove.insert(c->getCalledFunction());
|
||||
}
|
||||
|
@ -276,7 +277,7 @@ void llvm_preprocess(llvm::Module *M) {
|
|||
Pass *thispass = PIs->createPass();
|
||||
Passes.add(thispass);
|
||||
} else {
|
||||
printf("Pass: %s not found\n", pass.c_str());
|
||||
DEBUG_INFO("Pass: %s not found\n", pass.c_str());
|
||||
}
|
||||
}
|
||||
Passes.run(*M);
|
||||
|
@ -334,8 +335,6 @@ bool lower_constant_expr(llvm::Module *M) {
|
|||
auto get_from = get_element_ptr->getOperand(0);
|
||||
if (auto addr_cast = dyn_cast<llvm::ConstantExpr>(get_from)) {
|
||||
modified = true;
|
||||
// auto ReplInst = addr_cast->getAsInstruction();
|
||||
// ReplInst->insertBefore(get_element_ptr);
|
||||
std::vector<Instruction *> Users;
|
||||
// Do not replace use during iteration of use. Do it in another loop
|
||||
for (auto U : addr_cast->users()) {
|
||||
|
@ -374,6 +373,7 @@ void replace_cuda_math_built_in(llvm::Module *M) {
|
|||
}
|
||||
|
||||
void init_block(llvm::Module *M, std::ofstream &fout) {
|
||||
DEBUG_INFO("init block\n");
|
||||
// using official llvm preprocess
|
||||
llvm_preprocess(M);
|
||||
// remove useless CUDA functions
|
||||
|
|
|
@ -245,17 +245,14 @@ public:
|
|||
PDT->getPostDomTree().dominates(merge_point, curr)) {
|
||||
// we should insert barrier at the beginning and
|
||||
// end of its predecessor
|
||||
printf("insert [255]: %s\n", curr->getName().str().c_str());
|
||||
if (has_warp_barrier(b)) {
|
||||
CreateIntraWarpBarrier(&(*curr->begin()));
|
||||
for (BasicBlock *Pred : predecessors(curr)) {
|
||||
printf("insert [262]: %s\n", Pred->getName().str().c_str());
|
||||
CreateIntraWarpBarrier(&(*Pred->getTerminator()));
|
||||
}
|
||||
} else {
|
||||
CreateInterWarpBarrier(&(*curr->begin()));
|
||||
for (BasicBlock *Pred : predecessors(curr)) {
|
||||
printf("insert [268]: %s\n", Pred->getName().str().c_str());
|
||||
CreateInterWarpBarrier(&(*Pred->getTerminator()));
|
||||
}
|
||||
}
|
||||
|
@ -342,8 +339,8 @@ public:
|
|||
BasicBlock *merge_point = find_merge_point(head, PDT->getPostDomTree());
|
||||
assert(PDT->getPostDomTree().dominates(merge_point, head));
|
||||
if (!find_barrier_in_region(head, merge_point)) {
|
||||
printf("do not need to handle tri-income if: %s\n",
|
||||
merge_point->getName().str().c_str());
|
||||
DEBUG_INFO("do not need to handle tri-income if: %s\n",
|
||||
merge_point->getName().str().c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -465,7 +462,7 @@ public:
|
|||
}
|
||||
} else {
|
||||
// handle break in for-loop
|
||||
printf("loop has multiply exists\n");
|
||||
DEBUG_INFO("loop has multiply exists\n");
|
||||
// this time, we have also insert sync before the for-body
|
||||
auto header_block = L->getHeader();
|
||||
assert(header_block->getTerminator()->getNumSuccessors() == 2 &&
|
||||
|
@ -524,7 +521,12 @@ static RegisterPass<InsertBuiltInBarrier>
|
|||
"Insert built in barriers");
|
||||
} // namespace
|
||||
|
||||
/*
|
||||
This function inserts implicit synchronization for conditional statements,
|
||||
please refer to https://dl.acm.org/doi/abs/10.1145/3554736 for details.
|
||||
*/
|
||||
void insert_sync(llvm::Module *M) {
|
||||
DEBUG_INFO("insert sync\n");
|
||||
auto Registry = PassRegistry::getPassRegistry();
|
||||
|
||||
llvm::legacy::PassManager Passes;
|
||||
|
|
|
@ -72,10 +72,6 @@ int need_nested_loop;
|
|||
bool ShouldNotBeContextSaved(llvm::Instruction *instr) {
|
||||
if (isa<BranchInst>(instr))
|
||||
return true;
|
||||
// if (isa<AddrSpaceCastInst>(instr))
|
||||
// return true;
|
||||
// if (isa<CastInst>(instr))
|
||||
// return true;
|
||||
|
||||
llvm::Module *M = instr->getParent()->getParent()->getParent();
|
||||
llvm::LoadInst *load = dyn_cast<llvm::LoadInst>(instr);
|
||||
|
@ -134,47 +130,6 @@ llvm::Instruction *GetContextArray(llvm::Instruction *instruction,
|
|||
|
||||
Type *AllocType = elementType;
|
||||
AllocaInst *InstCast = dyn_cast<AllocaInst>(instruction);
|
||||
/*
|
||||
if (InstCast) {
|
||||
unsigned Alignment = InstCast->getAlignment();
|
||||
|
||||
uint64_t StoreSize = Layout.getTypeStoreSize(InstCast->getAllocatedType());
|
||||
|
||||
if ((Alignment > 1) && (StoreSize & (Alignment - 1))) {
|
||||
uint64_t AlignedSize = (StoreSize & (~(Alignment - 1))) + Alignment;
|
||||
assert(AlignedSize > StoreSize);
|
||||
uint64_t RequiredExtraBytes = AlignedSize - StoreSize;
|
||||
|
||||
if (isa<ArrayType>(elementType)) {
|
||||
|
||||
ArrayType *StructPadding = ArrayType::get(
|
||||
Type::getInt8Ty(M->getContext()), RequiredExtraBytes);
|
||||
|
||||
std::vector<Type *> PaddedStructElements;
|
||||
PaddedStructElements.push_back(elementType);
|
||||
PaddedStructElements.push_back(StructPadding);
|
||||
const ArrayRef<Type *> NewStructElements(PaddedStructElements);
|
||||
AllocType = StructType::get(M->getContext(), NewStructElements, true);
|
||||
uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType);
|
||||
assert(NewStoreSize == AlignedSize);
|
||||
} else if (isa<StructType>(elementType)) {
|
||||
StructType *OldStruct = dyn_cast<StructType>(elementType);
|
||||
|
||||
ArrayType *StructPadding = ArrayType::get(
|
||||
Type::getInt8Ty(M->getContext()), RequiredExtraBytes);
|
||||
std::vector<Type *> PaddedStructElements;
|
||||
for (unsigned j = 0; j < OldStruct->getNumElements(); j++)
|
||||
PaddedStructElements.push_back(OldStruct->getElementType(j));
|
||||
PaddedStructElements.push_back(StructPadding);
|
||||
const ArrayRef<Type *> NewStructElements(PaddedStructElements);
|
||||
AllocType = StructType::get(OldStruct->getContext(), NewStructElements,
|
||||
OldStruct->isPacked());
|
||||
uint64_t NewStoreSize = Layout.getTypeStoreSize(AllocType);
|
||||
assert(NewStoreSize == AlignedSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
llvm::Value *ItemSize = nullptr;
|
||||
llvm::AllocaInst *Alloca = nullptr;
|
||||
|
||||
|
@ -336,7 +291,6 @@ void handle_alloc(llvm::Function *F) {
|
|||
for (auto user : replace_user) {
|
||||
|
||||
IRBuilder<> builder(user);
|
||||
// std::vector<llvm::Value *> gepArgs;
|
||||
auto inter_warp_index =
|
||||
createLoad(builder, M->getGlobalVariable("inter_warp_index"));
|
||||
auto intra_warp_index =
|
||||
|
@ -597,16 +551,17 @@ void add_warp_loop(std::vector<ParallelRegion> parallel_regions,
|
|||
}
|
||||
|
||||
void print_parallel_region(std::vector<ParallelRegion> parallel_regions) {
|
||||
printf("get PR:\n");
|
||||
DEBUG_INFO("get PR:\n");
|
||||
for (auto region : parallel_regions) {
|
||||
auto start = region.start_block;
|
||||
auto end = region.end_block;
|
||||
auto next = region.successor_block;
|
||||
printf("parallel region: %s->%s next: %s\n", start->getName().str().c_str(),
|
||||
end->getName().str().c_str(), next->getName().str().c_str());
|
||||
printf("have: \n");
|
||||
DEBUG_INFO("parallel region: %s->%s next: %s\n",
|
||||
start->getName().str().c_str(), end->getName().str().c_str(),
|
||||
next->getName().str().c_str());
|
||||
DEBUG_INFO("have: \n");
|
||||
for (auto b : region.wrapped_block) {
|
||||
printf("%s\n", b->getName().str().c_str());
|
||||
DEBUG_INFO("%s\n", b->getName().str().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -839,7 +794,9 @@ public:
|
|||
// find parallel region we need to wrap
|
||||
auto parallel_regions = getParallelRegions(&F, intra_warp_loop);
|
||||
assert(!parallel_regions.empty() && "can not find any parallel regions\n");
|
||||
// print_parallel_region(parallel_regions);
|
||||
#ifdef DEBUG
|
||||
print_parallel_region(parallel_regions);
|
||||
#endif
|
||||
|
||||
if (intra_warp_loop) {
|
||||
handle_local_variable_intra_warp(parallel_regions);
|
||||
|
@ -874,7 +831,12 @@ bool has_warp_barrier(llvm::Module *M) {
|
|||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
This function wraps the ParallelRegion with inter/intra warp loops,
|
||||
please refer to https://dl.acm.org/doi/abs/10.1145/3554736 for details.
|
||||
*/
|
||||
void insert_warp_loop(llvm::Module *M) {
|
||||
DEBUG_INFO("insert warp loop\n");
|
||||
llvm::legacy::PassManager Passes;
|
||||
need_nested_loop = has_warp_barrier(M);
|
||||
// use nested loop only when there are warp-level barrier
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "performance.h"
|
||||
#include "tool.h"
|
||||
#include "llvm/ADT/Statistic.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/Triple.h"
|
||||
|
@ -40,6 +41,7 @@
|
|||
using namespace llvm;
|
||||
|
||||
void performance_optimization(llvm::Module *M) {
|
||||
DEBUG_INFO("performance optimization\n");
|
||||
for (auto F = M->begin(); F != M->end(); F++) {
|
||||
for (auto I = F->arg_begin(); I != F->arg_end(); ++I) {
|
||||
if (I->getType()->isPointerTy()) {
|
||||
|
@ -54,10 +56,7 @@ void performance_optimization(llvm::Module *M) {
|
|||
|
||||
std::string Error;
|
||||
const Target *TheTarget = TargetRegistry::lookupTarget("", triple, Error);
|
||||
if (!TheTarget) {
|
||||
printf("Error: %s\n", Error.c_str());
|
||||
assert(0);
|
||||
}
|
||||
assert(TheTarget && "No Target Information\n");
|
||||
llvm::TargetOptions Options;
|
||||
Options.FloatABIType = FloatABI::Hard;
|
||||
|
||||
|
|
|
@ -46,7 +46,6 @@ void VerifyModule(llvm::Module *program) {
|
|||
}
|
||||
|
||||
void DumpModule(llvm::Module *M, char *file_name) {
|
||||
// modify the program, add a wrapper
|
||||
std::string msg;
|
||||
llvm::raw_string_ostream os(msg);
|
||||
std::error_code EC;
|
||||
|
@ -541,8 +540,7 @@ void replace_asm_call(llvm::Module *M) {
|
|||
if (Call->isInlineAsm()) {
|
||||
auto asm_inst = dyn_cast<InlineAsm>(Call->getCalledOperand());
|
||||
if (asm_inst->getAsmString() != "mov.u32 $0, %laneid;") {
|
||||
printf("unknown InlineAsm\n");
|
||||
exit(1);
|
||||
assert(0 && "unknown InlineAsm\n");
|
||||
}
|
||||
// return the rank within the warp
|
||||
IRBuilder<> builder(context);
|
||||
|
|
Loading…
Reference in New Issue