fix bug for inserting sync after kernelLaunch

This commit is contained in:
RobinHan 2022-06-18 13:39:26 -04:00
parent 4791dfc9c9
commit 7d29a409f6
1 changed files with 5 additions and 22 deletions

View File

@ -26,7 +26,7 @@ void InsertSyncAfterKernelLaunch(llvm::Module *M) {
llvm::FunctionCallee _f = llvm::FunctionCallee _f =
M->getOrInsertFunction("cudaDeviceSynchronize", LauncherFuncT); M->getOrInsertFunction("cudaDeviceSynchronize", LauncherFuncT);
llvm::Function *func_launch = llvm::cast<llvm::Function>(_f.getCallee()); llvm::Function *func_launch = llvm::cast<llvm::Function>(_f.getCallee());
std::set<std::string> launch_function_name; std::set<llvm::Instruction *> kernel_launch_instruction;
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) { for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
Function *F = &(*i); Function *F = &(*i);
auto func_name = F->getName().str(); auto func_name = F->getName().str();
@ -40,33 +40,16 @@ void InsertSyncAfterKernelLaunch(llvm::Module *M) {
if (Function *calledFunction = callInst->getCalledFunction()) { if (Function *calledFunction = callInst->getCalledFunction()) {
if (calledFunction->getName().startswith("cudaLaunchKernel")) { if (calledFunction->getName().startswith("cudaLaunchKernel")) {
// F is a kernel launch function // F is a kernel launch function
launch_function_name.insert(func_name); kernel_launch_instruction.insert(callInst);
} }
} }
} }
} }
} }
} }
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) { for (auto call : kernel_launch_instruction) {
Function *F = &(*i); auto sync_call = llvm::CallInst::Create(func_launch, "inserted_sync");
for (Function::iterator b = F->begin(); b != F->end(); ++b) { sync_call->insertAfter(call);
BasicBlock *B = &(*b);
for (BasicBlock::iterator i = B->begin(); i != B->end(); ++i) {
Instruction *inst = &(*i);
if (llvm::CallBase *callInst = llvm::dyn_cast<llvm::CallBase>(inst)) {
if (Function *calledFunction = callInst->getCalledFunction()) {
if (launch_function_name.find(calledFunction->getName().str()) !=
launch_function_name.end()) {
// insert a sync after launch
auto sync_call =
llvm::CallInst::Create(func_launch, "inserted_sync");
sync_call->insertAfter(callInst);
}
}
}
}
}
} }
} }