#include "generate_x86_format.h"
#include "tool.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <map>
#include <set>
#include <string>
#include <vector>

using namespace llvm;

namespace {

// Returns the module's thread-local `thread_memory_size` i32 global,
// declaring it (external linkage, no initializer) the first time it is needed.
// Looking it up first matters: constructing a second GlobalVariable with the
// same name would make LLVM rename it, and later wrappers would then load from
// a different symbol.
GlobalVariable *get_or_declare_thread_memory_size(Module *M, Type *Int32T) {
  if (GlobalVariable *Existing = M->getGlobalVariable("thread_memory_size"))
    return Existing;
  return new GlobalVariable(*M, Int32T, false, GlobalValue::ExternalLinkage,
                            nullptr, "thread_memory_size", nullptr,
                            GlobalValue::GeneralDynamicTLSModel, 0, false);
}

// Ensures the module contains `_wrapper_global_data`, a hidden, weak_odr,
// comdat-any function with no parameters that returns `share_memory`.
// The body is emitted only once, even when several kernels are wrapped.
void define_wrapper_global_data_getter(Module *M, GlobalVariable *share_memory,
                                       Type *Int32T) {
  FunctionType *GetterT = FunctionType::get(
      PointerType::get(PointerType::get(Int32T, 0), 0), false);
  FunctionCallee Callee =
      M->getOrInsertFunction("_wrapper_global_data", GetterT);
  // cast<> (not dyn_cast<>): the callee must be a plain Function here.
  Function *Getter = cast<Function>(Callee.getCallee());

  Getter->setLinkage(GlobalValue::WeakODRLinkage);
  Getter->setVisibility(GlobalValue::HiddenVisibility);
  Comdat *Group = M->getOrInsertComdat("_wrapper_global_data");
  Group->setSelectionKind(Comdat::SelectionKind::Any);
  Getter->setComdat(Group);

  if (!Getter->empty())
    return; // body already emitted while wrapping an earlier kernel

  BasicBlock *Body = BasicBlock::Create(M->getContext(), "", Getter);
  IRBuilder<> GetterBuilder(M->getContext());
  GetterBuilder.SetInsertPoint(Body);
  GetterBuilder.CreateRet(share_memory);
}

// Emits, at Builder's insertion point, the dynamic memory setup for a wrapper:
// load `thread_memory_size`, alloca that many i8 elements, and store the
// resulting pointer into `share_memory` (accessed as an i8** slot).
void emit_dynamic_memory_setup(Module *M, IRBuilder<> &Builder,
                               GlobalVariable *share_memory, Type *Int32T,
                               Type *Int8T) {
  GlobalVariable *SizeVar = get_or_declare_thread_memory_size(M, Int32T);
  Value *ByteCount = Builder.CreateLoad(SizeVar);

  define_wrapper_global_data_getter(M, share_memory, Int32T);

  AllocaInst *Buffer = Builder.CreateAlloca(Int8T, ByteCount, "new_arr");
  // Alignment is left at the default; an explicit alignment was disabled:
  // Buffer->setAlignment(llvm::MaybeAlign(16));
  Value *Slot = Builder.CreateBitOrPointerCast(
      share_memory, PointerType::get(PointerType::get(Int8T, 0), 0));
  Builder.CreateStore(Buffer, Slot);
}

} // namespace

// Sets the module's target triple and data layout to x86_64 Linux.
void set_meta_data(llvm::Module *M) {
  M->setTargetTriple("x86_64-unknown-linux-gnu");
  M->setDataLayout(
      "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128");
}

// As pthread only accepts a single void* for input, the kernel arguments have
// to be decoded from that single pointer.
//
// For every kernel function `K` in the module (as decided by
// isKernelFunction), this creates `void K_wrapper(i8*)`. The wrapper treats
// its argument as an array of pointers, one per kernel parameter, loads each
// parameter value through the corresponding pointer, and calls `K`.
//
// If the module has a global named `wrapper_global_data`, each wrapper also
// allocates `thread_memory_size` bytes with alloca and stores that buffer's
// address into `wrapper_global_data` before decoding the arguments.
void decode_input(llvm::Module *M) {
  LLVMContext &Ctx = M->getContext();
  Type *Int32T = Type::getInt32Ty(Ctx);
  Type *Int8T = Type::getInt8Ty(Ctx);

  // Wrapper function type: void(i8*).
  FunctionType *LauncherFuncT =
      FunctionType::get(Type::getVoidTy(Ctx), {PointerType::get(Int8T, 0)},
                        false);

  // Collect the kernels before creating any wrapper, so the module's function
  // list is not modified while it is being iterated.
  std::vector<Function *> Kernels;
  for (Function &F : *M)
    if (isKernelFunction(M, &F))
      Kernels.push_back(&F);

  for (Function *F : Kernels) {
    const std::string func_name = F->getName().str();

    FunctionCallee fc =
        M->getOrInsertFunction(func_name + "_wrapper", LauncherFuncT);
    // cast<> (not dyn_cast<>): a non-Function callee is a hard error here,
    // not a case to continue from with a null pointer.
    Function *WorkGroup = cast<Function>(fc.getCallee());

    BasicBlock *Block = BasicBlock::Create(Ctx, "", WorkGroup);
    IRBuilder<> Builder(Ctx);
    Builder.SetInsertPoint(Block);

    // WorkGroup has only a single input; view it as int**.
    Value *input_arg = Builder.CreateBitOrPointerCast(
        &*WorkGroup->arg_begin(),
        PointerType::get(PointerType::get(Int32T, 0), 0));

    // Dynamic memory load in the wrapper function.
    if (GlobalVariable *share_memory =
            M->getGlobalVariable("wrapper_global_data"))
      emit_dynamic_memory_setup(M, Builder, share_memory, Int32T, Int8T);

    SmallVector<Value *, 8> Arguments;
    size_t idx = 0;
    for (const Argument &KernelArg : F->args()) {
      // Address of the idx-th pointer in the packed argument array.
      Value *SlotAddr =
          Builder.CreateGEP(input_arg, ConstantInt::get(Int32T, idx));
      // Load the pointer stored there, retype it to point at the kernel
      // parameter's type, then load the parameter value itself.
      Value *RawPtr = Builder.CreateLoad(SlotAddr);
      Value *TypedPtr = Builder.CreateBitOrPointerCast(
          RawPtr, PointerType::get(KernelArg.getType(), 0));
      Arguments.push_back(Builder.CreateLoad(TypedPtr));
      ++idx;
    }

    Builder.CreateCall(F, ArrayRef<Value *>(Arguments));
    Builder.CreateRetVoid();
  }
}

// Erases every direct call to the NVVM barrier intrinsics
// `llvm.nvvm.bar.warp.sync`, `llvm.nvvm.barrier0` and
// `llvm.nvvm.barrier.sync` from the module. Inline-asm calls and indirect
// calls are left untouched.
void remove_barrier(llvm::Module *M) {
  // Collect first, erase afterwards: erasing while walking a basic block would
  // invalidate the instruction iterator.
  std::vector<CallInst *> need_remove;
  for (Function &F : *M) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call || Call->isInlineAsm())
          continue;
        // getCalledFunction() is null for indirect calls, which cannot be one
        // of the named intrinsics.
        const Function *Callee = Call->getCalledFunction();
        if (!Callee)
          continue;
        const StringRef func_name = Callee->getName();
        if (func_name == "llvm.nvvm.bar.warp.sync" ||
            func_name == "llvm.nvvm.barrier0" ||
            func_name == "llvm.nvvm.barrier.sync") {
          need_remove.push_back(Call);
        }
      }
    }
  }

  for (CallInst *inst : need_remove)
    inst->eraseFromParent();
}

// Erases the globals `intra_warp_index` and `inter_warp_index` from the
// module. A global that is not present is skipped.
void remove_useless_var(llvm::Module *M) {
  for (const char *Name : {"intra_warp_index", "inter_warp_index"}) {
    if (GlobalVariable *GV = M->getGlobalVariable(Name))
      GV->eraseFromParent();
  }
}

// Rewrites the module for x86 by running, in order: metadata update,
// kernel-argument wrapper generation, barrier removal, and removal of the
// warp-index globals.
void generate_x86_format(llvm::Module *M) {
  // change metadata
  set_meta_data(M);
  // decode argument
  decode_input(M);
  // remove barrier
  remove_barrier(M);
  // remove useless func/variable
  remove_useless_var(M);
}