#include "ReplaceKernelLaunch.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include #include #include using namespace llvm; // Change to i8* bitcast (i8* (i8*)* @_Z9vecPKiS0_Pii_wrapper to i8*) // Original: i8* bitcast (void (i32*, i32*, i32*, i32)* @_Z9vecPKiS0_Pii to i8*) void ReplaceKernelLaunch(llvm::Module *M) { LLVMContext &context = M->getContext(); auto VoidTy = llvm::Type::getVoidTy(context); auto I8 = llvm::Type::getInt8PtrTy(context); std::map kernels; LLVMContext *C = &M->getContext(); llvm::Type *Int32T = Type::getInt32Ty(*C); llvm::Type *Int8T = Type::getInt8Ty(*C); llvm::FunctionType *LauncherFuncT = FunctionType::get(Type::getVoidTy(*C), NULL); llvm::FunctionType *LaunchFun2 = FunctionType::get(PointerType::get(PointerType::get(Int32T, 0), 0), NULL); bool done = false; for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) { Function *F = &(*i); auto func_name = F->getName().str(); for (Function::iterator b = F->begin(); b != F->end(); ++b) { BasicBlock *B = &(*b); for (BasicBlock::iterator i = B->begin(); i != B->end(); ++i) { Instruction *inst = &(*i); if (llvm::CallInst *callInst = llvm::dyn_cast(inst)) { if (Function *calledFunction = callInst->getCalledFunction()) { if (calledFunction->getName().startswith("cudaLaunchKernel")) { Value *callOperand = callInst->getArgOperand(0); Function *functionOperand = dyn_cast(callInst->getArgOperand(0)); // call function is wrapped in a bitcast if (functionOperand == NULL) { std::vector arg_sizes; functionOperand = dyn_cast(callOperand->stripPointerCasts()); FunctionType *ft = calledFunction->getFunctionType(); std::cout << " Parent (Caller) Function Name: " << func_name << ", cudaLaunchKernel Function: " << functionOperand->getName().str() << ", args " << functionOperand->arg_size() << std::endl; auto rep = kernels.find(functionOperand->getName().str()); if (rep != kernels.end()) { callInst->setArgOperand(0, rep->second); continue; } std::vector Params; Params.push_back(I8); FunctionType *FT = FunctionType::get(VoidTy, Params, false); std::string newName = functionOperand->getName().str() + "_wrapper"; Function *F = Function::Create(FT, Function::ExternalLinkage, newName, M); F->setDSOLocal(true); BitCastInst *BC = new BitCastInst(F, I8, "", callInst); callInst->setArgOperand(0, BC); kernels.insert({functionOperand->getName().str(), BC}); } } } } } } } }