CuPBoP/compilation/KernelTranslation/lib/generate_x86_format.cpp

179 lines
6.0 KiB
C++

#include "generate_x86_format.h"
#include "tool.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <iostream>
using namespace llvm;
// Retargets the module at 64-bit x86 Linux by overwriting its target triple
// and data layout string, so it can be compiled for the host CPU.
void set_meta_data(llvm::Module *M) {
  static const char *const kHostTriple = "x86_64-unknown-linux-gnu";
  static const char *const kHostDataLayout =
      "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128";
  M->setTargetTriple(kHostTriple);
  M->setDataLayout(kHostDataLayout);
}
// as pthread only accept a single void* for input
// we have to decode this input inside the kernel
void decode_input(llvm::Module *M) {
std::set<llvm::Function *> need_remove;
LLVMContext *C = &M->getContext();
llvm::Type *Int32T = Type::getInt32Ty(*C);
llvm::Type *Int8T = Type::getInt8Ty(*C);
llvm::FunctionType *LauncherFuncT = FunctionType::get(
Type::getVoidTy(*C), {PointerType::get(Int8T, 0)}, false);
std::set<GlobalVariable *> dynmaic_memory;
std::map<GlobalVariable *, Value *> corres_dynamic_memory_load_address;
// generate Wrapper Function type
// now we only support a single int32*
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
Function *F = &(*i);
if (!isKernelFunction(M, F))
continue;
auto func_name = F->getName().str();
llvm::IRBuilder<> Builder(M->getContext());
FunctionCallee fc =
M->getOrInsertFunction(func_name + "_wrapper", LauncherFuncT);
Function *WorkGroup = dyn_cast<Function>(fc.getCallee());
BasicBlock *Block = BasicBlock::Create(M->getContext(), "", WorkGroup);
Builder.SetInsertPoint(Block);
// WorkGroup has only a single input
Function::arg_iterator ai = WorkGroup->arg_begin();
SmallVector<Value *, 8> Arguments;
Value *input_arg = &*ai;
// convert to int**
input_arg = Builder.CreateBitOrPointerCast(
input_arg, PointerType::get(PointerType::get(Int32T, 0), 0));
// dynamic memory load in the wrapper function
GlobalVariable *share_memory = M->getGlobalVariable("wrapper_global_data");
if (share_memory != NULL) {
dynmaic_memory.insert(share_memory);
llvm::GlobalVariable *global_mem = new llvm::GlobalVariable(
*M, Int32T, false, llvm::GlobalValue::ExternalLinkage, NULL,
"thread_memory_size", NULL, llvm::GlobalValue::GeneralDynamicTLSModel,
0, false);
Value *loadedValue = Builder.CreateLoad(global_mem);
llvm::FunctionType *LaunchFun2 = FunctionType::get(
PointerType::get(PointerType::get(Int32T, 0), 0), NULL);
FunctionCallee fc2 =
M->getOrInsertFunction("_wrapper_global_data", LaunchFun2);
Function *WorkGroup2 = dyn_cast<Function>(fc2.getCallee());
WorkGroup2->setLinkage(GlobalValue::WeakODRLinkage);
WorkGroup2->setVisibility(GlobalValue::HiddenVisibility);
Comdat *co = M->getOrInsertComdat("_wrapper_global_data");
co->setSelectionKind(Comdat::SelectionKind::Any);
WorkGroup2->setComdat(co);
BasicBlock *Block2 = BasicBlock::Create(M->getContext(), "", WorkGroup2);
llvm::IRBuilder<> Builder2(M->getContext());
Builder2.SetInsertPoint(Block2);
Builder2.CreateRet(share_memory);
auto PT = dyn_cast<PointerType>(share_memory->getType());
auto element_type = PT->getElementType();
// std::cout << element_type->getTypeID() << " Got global memor $$$$$$"
// << share_memory->getName().str() << std::endl;
AllocaInst *new_arr = Builder.CreateAlloca(Int8T, loadedValue, "new_arr");
// new_arr->setAlignment(llvm::MaybeAlign(16));
Value *new_ar = new_arr;
Value *gptr = Builder.CreateBitOrPointerCast(
share_memory, PointerType::get(PointerType::get(Int8T, 0), 0));
Builder.CreateStore(new_ar, gptr);
}
size_t idx = 0;
for (Function::const_arg_iterator ii = F->arg_begin(), ee = F->arg_end();
ii != ee; ++ii) {
Type *ArgType = ii->getType();
// calculate addr
Value *GEP = Builder.CreateGEP(input_arg, ConstantInt::get(Int32T, idx));
// load corresponding int*
GEP = Builder.CreateLoad(GEP);
// bitcast
GEP = Builder.CreateBitOrPointerCast(GEP, PointerType::get(ArgType, 0));
Value *Arg = Builder.CreateLoad(GEP);
Arguments.push_back(Arg);
++idx;
}
CallInst *c = Builder.CreateCall(F, ArrayRef<llvm::Value *>(Arguments));
Builder.CreateRetVoid();
}
for (auto f : need_remove) {
f->dropAllReferences();
f->eraseFromParent();
}
}
// Deletes every direct call to the NVVM barrier intrinsics
// (llvm.nvvm.bar.warp.sync, llvm.nvvm.barrier0, llvm.nvvm.barrier.sync).
// Calls are collected first and erased afterwards so the instruction lists
// are not mutated while being iterated.
void remove_barrier(llvm::Module *M) {
  std::vector<Instruction *> need_remove;
  for (Function &F : *M)
    for (BasicBlock &BB : F)
      for (Instruction &I : BB) {
        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        // Null for inline asm, indirect calls and calls through a cast
        // callee; the original dereferenced it unconditionally and crashed.
        Function *Callee = Call->getCalledFunction();
        if (!Callee)
          continue;
        StringRef func_name = Callee->getName();
        if (func_name == "llvm.nvvm.bar.warp.sync" ||
            func_name == "llvm.nvvm.barrier0" ||
            func_name == "llvm.nvvm.barrier.sync")
          need_remove.push_back(Call);
      }
  for (Instruction *inst : need_remove)
    inst->eraseFromParent();
}
// Erases the globals `intra_warp_index` and `inter_warp_index` from M.
// A variable that is absent is skipped (the original dereferenced a null
// lookup result). A variable that still has live uses after dead constant
// users are dropped is kept, because erasing a value with remaining uses
// leaves dangling references.
void remove_useless_var(llvm::Module *M) {
  static const char *const kUselessVars[] = {"intra_warp_index",
                                             "inter_warp_index"};
  for (const char *name : kUselessVars) {
    // getNamedGlobal also finds internal-linkage definitions.
    GlobalVariable *GV = M->getNamedGlobal(name);
    if (!GV)
      continue;
    GV->removeDeadConstantUsers();
    if (GV->use_empty())
      GV->eraseFromParent();
  }
}
// Converts a translated NVPTX module into x86-compatible form by running, in
// order: retargeting (triple + data layout), per-kernel argument-unpacking
// wrappers, barrier-intrinsic removal, and removal of leftover warp-index
// globals.
void generate_x86_format(llvm::Module *M) {
  void (*const stages[])(llvm::Module *) = {
      set_meta_data,      // host triple / data layout
      decode_input,       // void*-taking wrappers for pthread launch
      remove_barrier,     // drop NVVM barrier calls
      remove_useless_var, // drop unused warp-index globals
  };
  for (auto stage : stages)
    stage(M);
}