#include "generate_x86_format.h"
#include "tool.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

#include <string>
#include <vector>

using namespace llvm;

// Retarget the module to x86-64 Linux: overwrite its target triple and
// data layout string in place.
void set_meta_data(llvm::Module *M) {
  M->setTargetTriple("x86_64-unknown-linux-gnu");
  M->setDataLayout(
      "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128");
}

// A pthread start routine accepts only a single void*, so the kernel
// arguments have to be unpacked from that pointer inside generated code.
//
// For every function F for which isKernelFunction(M, F) holds, this emits
//   void <F's name>_wrapper(i8* input)
// which treats `input` as an array of pointer-sized slots. Slot i holds a
// pointer to the storage of F's i-th argument. The wrapper loads each slot,
// casts it to a pointer to that argument's type, loads the value, and calls
// F with the unpacked arguments.
//
// Wrappers that already have a body are left untouched. A pre-existing
// symbol with the wrapper's name but an incompatible type is a fatal error.
void decode_input(llvm::Module *M) {
  LLVMContext &C = M->getContext();
  llvm::Type *Int32T = Type::getInt32Ty(C);
  llvm::Type *Int8T = Type::getInt8Ty(C);
  // Element type of the slot array: each slot is loaded as an i32* and then
  // re-cast to the real argument pointer type, so it only serves as a
  // pointer-sized placeholder.
  llvm::PointerType *SlotT = PointerType::get(Int32T, 0);
  // Wrapper signature: void (i8*).
  llvm::FunctionType *LauncherFuncT = FunctionType::get(
      Type::getVoidTy(C), {PointerType::get(Int8T, 0)}, false);

  // Collect the kernels before creating anything. getOrInsertFunction appends
  // new functions to the module's function list, and iterating that list
  // while inserting would also visit the freshly created wrappers.
  std::vector<Function *> Kernels;
  for (Function &F : *M)
    if (isKernelFunction(M, &F))
      Kernels.push_back(&F);

  for (Function *F : Kernels) {
    const std::string WrapperName = F->getName().str() + "_wrapper";
    FunctionCallee FC = M->getOrInsertFunction(WrapperName, LauncherFuncT);

    // getOrInsertFunction returns a cast expression rather than a Function
    // when a symbol with this name already exists with a different type.
    auto *Wrapper = dyn_cast<Function>(FC.getCallee());
    if (!Wrapper)
      report_fatal_error(Twine("decode_input: '") + WrapperName +
                         "' already exists with an incompatible type");
    // Already generated (e.g. decode_input ran twice); adding a second entry
    // block would produce invalid IR.
    if (!Wrapper->empty())
      continue;

    IRBuilder<> Builder(C);
    Builder.SetInsertPoint(BasicBlock::Create(C, "", Wrapper));

    // The wrapper's single parameter, reinterpreted as a slot array (i32**).
    Value *Slots = Builder.CreateBitOrPointerCast(
        &*Wrapper->arg_begin(), PointerType::get(SlotT, 0));

    SmallVector<Value *, 8> Arguments;
    unsigned Idx = 0;
    for (Argument &A : F->args()) {
      Type *ArgType = A.getType();
      // &Slots[Idx]
      Value *SlotAddr =
          Builder.CreateGEP(SlotT, Slots, ConstantInt::get(Int32T, Idx));
      // Slots[Idx]: pointer to this argument's storage.
      Value *ArgPtr = Builder.CreateLoad(SlotT, SlotAddr);
      ArgPtr =
          Builder.CreateBitOrPointerCast(ArgPtr, PointerType::get(ArgType, 0));
      Arguments.push_back(Builder.CreateLoad(ArgType, ArgPtr));
      ++Idx;
    }
    Builder.CreateCall(F, Arguments);
    Builder.CreateRetVoid();
  }
}

// Erase every direct call to llvm.nvvm.bar.warp.sync, llvm.nvvm.barrier0 and
// llvm.nvvm.barrier.sync from all function bodies in the module.
// NOTE(review): presumably barrier semantics are provided elsewhere on the
// CPU side (or are unnecessary there) — confirm against the launcher.
void remove_barrier(llvm::Module *M) {
  std::vector<CallInst *> need_remove;
  for (Function &F : *M)
    for (BasicBlock &BB : F)
      for (Instruction &I : BB) {
        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        // Indirect calls and inline asm have no statically known callee;
        // getCalledFunction() returns null for them.
        Function *Callee = Call->getCalledFunction();
        if (!Callee)
          continue;
        StringRef func_name = Callee->getName();
        if (func_name == "llvm.nvvm.bar.warp.sync" ||
            func_name == "llvm.nvvm.barrier0" ||
            func_name == "llvm.nvvm.barrier.sync")
          need_remove.push_back(Call);
      }
  // Erase only after the walk so the instruction iterators stay valid.
  for (CallInst *Call : need_remove)
    Call->eraseFromParent();
}

// Entry point: rewrite an NVVM module in place so it can be compiled for
// x86-64. Sets the target triple/data layout, generates the void* argument
// wrappers for kernels, and strips NVVM barrier calls.
void generate_x86_format(llvm::Module *M) {
  // change metadata
  set_meta_data(M);
  // decode argument
  decode_input(M);
  // remove barrier
  remove_barrier(M);
}