CuPBoP/compilation/KernelTranslation/lib/generate_x86_format.cpp

120 lines
3.7 KiB
C++

#include "generate_x86_format.h"
#include "tool.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;
// Stamp the module with the x86-64 Linux target triple and the matching data
// layout string, replacing whatever target description it carried before.
void set_meta_data(llvm::Module *M) {
  const char TargetTriple[] = "x86_64-unknown-linux-gnu";
  const char TargetLayout[] =
      "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128";

  M->setTargetTriple(TargetTriple);
  M->setDataLayout(TargetLayout);
}
// pthread entry points accept only a single void* argument, so the kernel
// arguments have to be decoded inside the module itself.
//
// For every kernel function K in M this emits
//     void K_wrapper(i8 *packed)
// which reinterprets `packed` as an array of pointers (one slot per kernel
// parameter), loads the i-th parameter value through the i-th slot, and then
// calls K with the decoded values.
//
// Aborts via report_fatal_error if the name "K_wrapper" is already taken by a
// symbol of a different type, because no wrapper body can be emitted then.
void decode_input(llvm::Module *M) {
  LLVMContext &Ctx = M->getContext();
  llvm::Type *Int32T = Type::getInt32Ty(Ctx);
  llvm::Type *Int8T = Type::getInt8Ty(Ctx);
  // Signature shared by every wrapper: void(i8*).
  llvm::FunctionType *LauncherFuncT = FunctionType::get(
      Type::getVoidTy(Ctx), {PointerType::get(Int8T, 0)}, false);

  // Snapshot the kernels before creating wrappers: getOrInsertFunction adds
  // functions to the module, so the function list must not be walked while
  // wrappers are being inserted.
  SmallVector<Function *, 8> Kernels;
  for (Function &F : *M) {
    if (isKernelFunction(M, &F))
      Kernels.push_back(&F);
  }

  llvm::IRBuilder<> Builder(Ctx);
  for (Function *F : Kernels) {
    const auto wrapper_name = F->getName().str() + "_wrapper";
    FunctionCallee fc = M->getOrInsertFunction(wrapper_name, LauncherFuncT);
    // getOrInsertFunction yields a non-Function constant when the name
    // already exists with another type; fail loudly instead of dereferencing
    // a null pointer.
    Function *WorkGroup = dyn_cast<Function>(fc.getCallee());
    if (!WorkGroup)
      report_fatal_error(Twine("decode_input: symbol '") + wrapper_name +
                         "' already exists with an incompatible type");

    BasicBlock *Block = BasicBlock::Create(Ctx, "", WorkGroup);
    Builder.SetInsertPoint(Block);

    // The wrapper has exactly one parameter; view it as int** so that each
    // slot can be addressed with a GEP.
    Value *input_arg = &*WorkGroup->arg_begin();
    input_arg = Builder.CreateBitOrPointerCast(
        input_arg, PointerType::get(PointerType::get(Int32T, 0), 0));

    SmallVector<Value *, 8> Arguments;
    size_t idx = 0;
    for (const Argument &Param : F->args()) {
      Type *ArgType = Param.getType();
      // Address of slot idx in the packed pointer array.
      Value *Slot = Builder.CreateGEP(input_arg, ConstantInt::get(Int32T, idx));
      // Load the pointer stored in that slot (typed int* at this point).
      Value *ArgPtr = Builder.CreateLoad(Slot);
      // Retype it as a pointer to the kernel parameter's type ...
      ArgPtr = Builder.CreateBitOrPointerCast(ArgPtr,
                                              PointerType::get(ArgType, 0));
      // ... and load the actual argument value through it.
      Arguments.push_back(Builder.CreateLoad(ArgPtr));
      ++idx;
    }

    Builder.CreateCall(F, ArrayRef<llvm::Value *>(Arguments));
    Builder.CreateRetVoid();
  }
}
// Erase every direct call to the NVVM barrier intrinsics
// (llvm.nvvm.bar.warp.sync, llvm.nvvm.barrier0, llvm.nvvm.barrier.sync)
// from all functions in M.
//
// Calls without a statically known callee (getCalledFunction() returns null,
// e.g. indirect calls) cannot be one of these intrinsics and are left alone.
void remove_barrier(llvm::Module *M) {
  // Collect first and erase afterwards so that instructions are never removed
  // from a basic block while it is being iterated.
  std::vector<Instruction *> need_remove;
  for (Function &F : *M) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        Function *Callee = Call->getCalledFunction();
        if (!Callee)
          continue; // no direct callee: nothing to compare by name
        StringRef func_name = Callee->getName();
        if (func_name == "llvm.nvvm.bar.warp.sync" ||
            func_name == "llvm.nvvm.barrier0" ||
            func_name == "llvm.nvvm.barrier.sync") {
          need_remove.push_back(Call);
        }
      }
    }
  }
  for (Instruction *inst : need_remove) {
    inst->eraseFromParent();
  }
}
// Entry point of this file: applies the three module rewrites below to M, in
// this order.
void generate_x86_format(llvm::Module *M) {
// set the module's target triple and data layout to x86-64 Linux
set_meta_data(M);
// emit a "<kernel>_wrapper" for each kernel that decodes the kernel arguments
// from a single pointer parameter
decode_input(M);
// erase calls to the NVVM barrier intrinsics
remove_barrier(M);
}