support CloverLeaf on LLVM14

This commit is contained in:
Ruobing Han 2022-07-13 18:39:59 -04:00
parent 8fddb647bd
commit cf12d604eb
7 changed files with 101 additions and 11 deletions

View File

@ -9,20 +9,77 @@
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/Utils/CtorUtils.h"
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <set> #include <set>
using namespace llvm; using namespace llvm;
/// Given a llvm.global_ctors list that we can understand,
/// return a list of the functions and null terminator as a vector.
static std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) {
if (GV->getInitializer()->isNullValue())
return std::vector<Function *>();
ConstantArray *CA = cast<ConstantArray>(GV->getInitializer());
std::vector<Function *> Result;
Result.reserve(CA->getNumOperands());
for (auto &V : CA->operands()) {
ConstantStruct *CS = cast<ConstantStruct>(V);
Result.push_back(dyn_cast<Function>(CS->getOperand(1)));
}
return Result;
}
void RemoveCudaBuiltin(llvm::Module *M) { void RemoveCudaBuiltin(llvm::Module *M) {
std::set<llvm::Function *> need_remove; std::set<llvm::Function *> need_remove;
if (GlobalVariable *gv = M->getGlobalVariable("llvm.global_ctors")) { // remove cuda built-in from Ctors
gv->dropAllReferences(); if (GlobalVariable *GV = M->getGlobalVariable("llvm.global_ctors")) {
gv->eraseFromParent(); std::vector<Function *> Ctors = parseGlobalCtors(GV);
if (!Ctors.empty()) {
ConstantArray *OldCA = cast<ConstantArray>(GV->getInitializer());
SmallVector<Constant *, 10> CAList;
for (int i = 0; i < OldCA->getNumOperands(); i++) {
if (!Ctors[i])
continue;
if (Ctors[i]->hasName() &&
Ctors[i]->getName().str().find("__cuda") == std::string::npos) {
std::cout << "keep: " << Ctors[i]->getName().str() << std::endl
<< std::flush;
CAList.push_back(OldCA->getOperand(i));
} }
}
// Create the new array initializer.
ArrayType *ATy =
ArrayType::get(OldCA->getType()->getElementType(), CAList.size());
Constant *CA = ConstantArray::get(ATy, CAList);
// If we didn't change the number of elements, don't create a new GV.
if (CA->getType() == OldCA->getType()) {
GV->setInitializer(CA);
} else {
// Create the new global and insert it next to the existing list.
GlobalVariable *NGV = new GlobalVariable(
CA->getType(), GV->isConstant(), GV->getLinkage(), CA, "",
GV->getThreadLocalMode());
GV->getParent()->getGlobalList().insert(GV->getIterator(), NGV);
NGV->takeName(GV);
// Nuke the old list, replacing any uses with the new one.
if (!GV->use_empty()) {
Constant *V = NGV;
if (V->getType() != GV->getType())
V = ConstantExpr::getBitCast(V, GV->getType());
GV->replaceAllUsesWith(V);
}
GV->eraseFromParent();
}
}
}
Function *c_tor = NULL; Function *c_tor = NULL;
if (c_tor = M->getFunction("__cuda_module_ctor")) { if (c_tor = M->getFunction("__cuda_module_ctor")) {
c_tor->dropAllReferences(); c_tor->dropAllReferences();

View File

@ -360,6 +360,12 @@ void replace_cuda_math_built_in(llvm::Module *M) {
if (func_name.find("_ZL3expd") != std::string::npos) { if (func_name.find("_ZL3expd") != std::string::npos) {
F->deleteBody(); F->deleteBody();
} }
if (func_name.find("_ZL8copysigndd") != std::string::npos) {
F->deleteBody();
}
if (func_name.find("_ZL8copysigndd.8") != std::string::npos) {
F->deleteBody();
}
} }
} }
@ -370,6 +376,8 @@ void init_block(llvm::Module *M, std::ofstream &fout) {
remove_cuda_built_in(M); remove_cuda_built_in(M);
// replace CUDA math function, like expf // replace CUDA math function, like expf
replace_cuda_math_built_in(M); replace_cuda_math_built_in(M);
// replace CUDA math function, like expf
replace_cuda_math_built_in(M);
// lower ConstantExpression // lower ConstantExpression
bool modified; bool modified;

View File

@ -464,7 +464,9 @@ void replace_built_in_function(llvm::Module *M) {
std::vector<Value *> Indices; std::vector<Value *> Indices;
Indices.push_back(ConstantInt::get(I32, 0)); Indices.push_back(ConstantInt::get(I32, 0));
Indices.push_back(ConstantInt::get(I32, i)); Indices.push_back(ConstantInt::get(I32, i));
auto new_GEP = GetElementPtrInst::Create(NULL, // Pointee type auto new_GEP = GetElementPtrInst::Create(
cast<PointerType>(src_alloc->getType()->getScalarType())
->getElementType(),
src_alloc, // Alloca src_alloc, // Alloca
Indices, // Indices Indices, // Indices
"", Call); "", Call);
@ -503,8 +505,14 @@ void replace_built_in_function(llvm::Module *M) {
Call->getCalledFunction()->setName("__nvvm_lohi_i2d"); Call->getCalledFunction()->setName("__nvvm_lohi_i2d");
} else if (func_name == "llvm.nvvm.fabs.f") { } else if (func_name == "llvm.nvvm.fabs.f") {
Call->getCalledFunction()->setName("__nvvm_fabs_f"); Call->getCalledFunction()->setName("__nvvm_fabs_f");
} else if (func_name == "llvm.nvvm.fabs.d") {
Call->getCalledFunction()->setName("__nv_fabsd");
} else if (func_name == "llvm.nvvm.mul24.i") { } else if (func_name == "llvm.nvvm.mul24.i") {
Call->getCalledFunction()->setName("__nvvm_mul24_i"); Call->getCalledFunction()->setName("__nvvm_mul24_i");
} else if (func_name == "llvm.nvvm.fmin.d") {
Call->getCalledFunction()->setName("__nv_fmind");
} else if (func_name == "llvm.nvvm.fmax.d") {
Call->getCalledFunction()->setName("__nv_fmaxd");
} }
} }
} }

View File

@ -70,7 +70,9 @@ void handle_warp_vote(llvm::Module *M) {
new LoadInst(intra_warp_index_addr->getType()->getPointerElementType(), new LoadInst(intra_warp_index_addr->getType()->getPointerElementType(),
intra_warp_index_addr, "intra_warp_index", sync_inst); intra_warp_index_addr, "intra_warp_index", sync_inst);
auto GEP = GetElementPtrInst::Create(NULL, // Pointee type auto GEP = GetElementPtrInst::Create(
cast<PointerType>(warp_vote_ptr->getType()->getScalarType())
->getElementType(),
warp_vote_ptr, // Alloca warp_vote_ptr, // Alloca
{zero, intra_warp_index}, // Indices {zero, intra_warp_index}, // Indices
"", sync_inst); "", sync_inst);

View File

@ -19,7 +19,11 @@ float __nv_fmodf(float, float);
int __nv_isnanf(float); int __nv_isnanf(float);
int __nv_isinff(float); int __nv_isinff(float);
float __nv_fabsf(float); float __nv_fabsf(float);
double __nv_fabsd(double);
double __nv_fmind(double, double);
double __nv_fmaxd(double, double);
int __nvvm_mul24_i(int, int); int __nvvm_mul24_i(int, int);
double _ZL3expd(double); double _ZL3expd(double);
double _ZL8copysigndd(double, double);
} }
#endif #endif

View File

@ -15,5 +15,9 @@ float __nv_fmodf(float x, float y) { return fmod(x, y); }
int __nv_isnanf(float v) { return isnan(v); } int __nv_isnanf(float v) { return isnan(v); }
int __nv_isinff(float v) { return isinf(v); } int __nv_isinff(float v) { return isinf(v); }
float __nv_fabsf(float v) { return abs(v); } float __nv_fabsf(float v) { return abs(v); }
double __nv_fabsd(double v) { return abs(v); }
double __nv_fmind(double a, double b) { return (a < b) ? a : b; }
double __nv_fmaxd(double a, double b) { return (a > b) ? a : b; }
int __nvvm_mul24_i(int a, int b) { return a * b; } int __nvvm_mul24_i(int a, int b) { return a * b; }
double _ZL3expd(double base) { return exp(base); } double _ZL3expd(double base) { return exp(base); }
double _ZL8copysigndd(double x, double y) { return y > 0 ? abs(x) : -abs(x); };

View File

@ -17,7 +17,11 @@
Initialize the device Initialize the device
*/ */
int device_max_compute_units = 1; int device_max_compute_units = 1;
bool device_initilized = false;
int init_device() { int init_device() {
if (device_initilized)
return 0;
device_initilized = true;
cu_device *device = (cu_device *)calloc(1, sizeof(cu_device)); cu_device *device = (cu_device *)calloc(1, sizeof(cu_device));
if (device == NULL) if (device == NULL)
return C_ERROR_MEMALLOC; return C_ERROR_MEMALLOC;
@ -231,6 +235,9 @@ void scheduler_uninit() {
Counting Barrier basically Counting Barrier basically
*/ */
void cuSynchronizeBarrier() { void cuSynchronizeBarrier() {
if (!device_initilized) {
init_device();
}
while (1) { while (1) {
// (TODO): currently, we assume each kernel launch will have a // (TODO): currently, we assume each kernel launch will have a
// following sync // following sync