apply divergence analysis for replicating local variables

This commit is contained in:
Ruobing Han 2022-09-22 16:15:38 -04:00
parent 8da1ecc5fd
commit e99205aa8b
1 changed files with 45 additions and 7 deletions

View File

@ -8,9 +8,12 @@
#include <set> #include <set>
#include "llvm/ADT/Statistic.h" #include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h" #include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h" #include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h" #include "llvm/IR/Function.h"
@ -23,9 +26,12 @@
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/IR/ValueSymbolTable.h" #include "llvm/IR/ValueSymbolTable.h"
#include "llvm/InitializePasses.h" #include "llvm/InitializePasses.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/PassInfo.h" #include "llvm/PassInfo.h"
#include "llvm/PassRegistry.h" #include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h" #include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Cloning.h"
@ -85,7 +91,6 @@ bool ShouldNotBeContextSaved(llvm::Instruction *instr) {
if (load_addr == M->getGlobalVariable("warp_vote")) if (load_addr == M->getGlobalVariable("warp_vote"))
return true; return true;
} }
// TODO: we should further analyze whether the local variable // TODO: we should further analyze whether the local variable
// is the same among all threads within a warp // is the same among all threads within a warp
return false; return false;
@ -314,7 +319,8 @@ void handle_alloc(llvm::Function *F) {
} }
} }
void handle_local_variable_intra_warp(std::vector<ParallelRegion> PRs) { void handle_local_variable_intra_warp(std::vector<ParallelRegion> PRs,
DivergenceInfo &DI) {
bool intra_warp_loop = 1; bool intra_warp_loop = 1;
// we should handle allocation generated by PHI // we should handle allocation generated by PHI
{ {
@ -324,6 +330,24 @@ void handle_local_variable_intra_warp(std::vector<ParallelRegion> PRs) {
for (auto ii = bb->begin(); ii != bb->end(); ii++) { for (auto ii = bb->begin(); ii != bb->end(); ii++) {
if (isa<AllocaInst>(&(*ii))) { if (isa<AllocaInst>(&(*ii))) {
auto alloc = dyn_cast<AllocaInst>(&(*ii)); auto alloc = dyn_cast<AllocaInst>(&(*ii));
// if every store to this alloca writes a non-divergent value, there is no
// need to replicate it
bool allStoreNonDivergence = true;
for (Instruction::use_iterator ui = alloc->use_begin(),
ue = alloc->use_end();
ui != ue; ++ui) {
llvm::Instruction *user = dyn_cast<Instruction>(ui->getUser());
if (isa<StoreInst>(user)) {
auto storeVar = user->getOperand(0);
if (DI.isDivergent(*storeVar)) {
allStoreNonDivergence = false;
break;
}
}
}
if (allStoreNonDivergence) {
continue;
}
// Do not duplicate var used outside PRs // Do not duplicate var used outside PRs
bool used_in_non_PR = false; bool used_in_non_PR = false;
for (Instruction::use_iterator ui = alloc->use_begin(), for (Instruction::use_iterator ui = alloc->use_begin(),
@ -595,8 +619,6 @@ class InsertWarpLoopPass : public llvm::FunctionPass {
public: public:
static char ID; static char ID;
bool intra_warp_loop; bool intra_warp_loop;
DominatorTree *DT;
PostDominatorTree *PDT;
InsertWarpLoopPass(bool intra_warp = 0) InsertWarpLoopPass(bool intra_warp = 0)
: FunctionPass(ID), intra_warp_loop(intra_warp) {} : FunctionPass(ID), intra_warp_loop(intra_warp) {}
@ -604,6 +626,8 @@ public:
virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const { virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const {
AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<PostDominatorTreeWrapperPass>(); AU.addRequired<PostDominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
} }
void getParallelRegionBefore(llvm::BasicBlock *B, bool intra_warp_loop, void getParallelRegionBefore(llvm::BasicBlock *B, bool intra_warp_loop,
@ -789,8 +813,22 @@ public:
tempInstructionIds.clear(); tempInstructionIds.clear();
tempInstructionIndex = 0; tempInstructionIndex = 0;
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); // get DivergenceInfo
PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
llvm::Triple triple("nvptx64-nvidia-cuda");
std::string Error;
const Target *TheTarget = TargetRegistry::lookupTarget("", triple, Error);
llvm::TargetOptions Options;
llvm::TargetMachine *target_machine = TheTarget->createTargetMachine(
triple.getTriple(), "sm_35", "+ptx50", Options, llvm::Reloc::Static,
llvm::CodeModel::Small, llvm::CodeGenOpt::Aggressive);
llvm::FunctionAnalysisManager DummyFAM;
llvm::TargetTransformInfo TTI =
target_machine->getTargetIRAnalysis().run(F, DummyFAM);
DivergenceInfo DI(F, *DT, *PDT, LI, TTI, /*KnownReducible*/ true);
// find parallel region we need to wrap // find parallel region we need to wrap
auto parallel_regions = getParallelRegions(&F, intra_warp_loop); auto parallel_regions = getParallelRegions(&F, intra_warp_loop);
@ -800,7 +838,7 @@ public:
#endif #endif
if (intra_warp_loop) { if (intra_warp_loop) {
handle_local_variable_intra_warp(parallel_regions); handle_local_variable_intra_warp(parallel_regions, DI);
} }
add_warp_loop(parallel_regions, intra_warp_loop); add_warp_loop(parallel_regions, intra_warp_loop);
remove_barrier(&F, intra_warp_loop); remove_barrier(&F, intra_warp_loop);