apply divergence analysis for replicating local variables
This commit is contained in:
parent
8da1ecc5fd
commit
e99205aa8b
|
@ -8,9 +8,12 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "llvm/ADT/Statistic.h"
|
#include "llvm/ADT/Statistic.h"
|
||||||
|
#include "llvm/ADT/Triple.h"
|
||||||
|
#include "llvm/Analysis/DivergenceAnalysis.h"
|
||||||
#include "llvm/Analysis/LoopInfo.h"
|
#include "llvm/Analysis/LoopInfo.h"
|
||||||
#include "llvm/Analysis/LoopPass.h"
|
#include "llvm/Analysis/LoopPass.h"
|
||||||
#include "llvm/Analysis/PostDominators.h"
|
#include "llvm/Analysis/PostDominators.h"
|
||||||
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||||
#include "llvm/IR/CFG.h"
|
#include "llvm/IR/CFG.h"
|
||||||
#include "llvm/IR/DataLayout.h"
|
#include "llvm/IR/DataLayout.h"
|
||||||
#include "llvm/IR/Function.h"
|
#include "llvm/IR/Function.h"
|
||||||
|
@ -23,9 +26,12 @@
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/ValueSymbolTable.h"
|
#include "llvm/IR/ValueSymbolTable.h"
|
||||||
#include "llvm/InitializePasses.h"
|
#include "llvm/InitializePasses.h"
|
||||||
|
#include "llvm/MC/TargetRegistry.h"
|
||||||
#include "llvm/PassInfo.h"
|
#include "llvm/PassInfo.h"
|
||||||
#include "llvm/PassRegistry.h"
|
#include "llvm/PassRegistry.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
#include "llvm/Target/TargetMachine.h"
|
||||||
|
#include "llvm/Target/TargetOptions.h"
|
||||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||||
#include "llvm/Transforms/Utils/Cloning.h"
|
#include "llvm/Transforms/Utils/Cloning.h"
|
||||||
|
@ -85,7 +91,6 @@ bool ShouldNotBeContextSaved(llvm::Instruction *instr) {
|
||||||
if (load_addr == M->getGlobalVariable("warp_vote"))
|
if (load_addr == M->getGlobalVariable("warp_vote"))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: we should further analyze whether the local variable
|
// TODO: we should further analyze whether the local variable
|
||||||
// is same among all threads within a wrap
|
// is same among all threads within a wrap
|
||||||
return false;
|
return false;
|
||||||
|
@ -314,7 +319,8 @@ void handle_alloc(llvm::Function *F) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void handle_local_variable_intra_warp(std::vector<ParallelRegion> PRs) {
|
void handle_local_variable_intra_warp(std::vector<ParallelRegion> PRs,
|
||||||
|
DivergenceInfo &DI) {
|
||||||
bool intra_warp_loop = 1;
|
bool intra_warp_loop = 1;
|
||||||
// we should handle allocation generated by PHI
|
// we should handle allocation generated by PHI
|
||||||
{
|
{
|
||||||
|
@ -324,6 +330,24 @@ void handle_local_variable_intra_warp(std::vector<ParallelRegion> PRs) {
|
||||||
for (auto ii = bb->begin(); ii != bb->end(); ii++) {
|
for (auto ii = bb->begin(); ii != bb->end(); ii++) {
|
||||||
if (isa<AllocaInst>(&(*ii))) {
|
if (isa<AllocaInst>(&(*ii))) {
|
||||||
auto alloc = dyn_cast<AllocaInst>(&(*ii));
|
auto alloc = dyn_cast<AllocaInst>(&(*ii));
|
||||||
|
// if this alloc's write are all non-divergence, then no need to
|
||||||
|
// replicate
|
||||||
|
bool allStoreNonDivergence = true;
|
||||||
|
for (Instruction::use_iterator ui = alloc->use_begin(),
|
||||||
|
ue = alloc->use_end();
|
||||||
|
ui != ue; ++ui) {
|
||||||
|
llvm::Instruction *user = dyn_cast<Instruction>(ui->getUser());
|
||||||
|
if (isa<StoreInst>(user)) {
|
||||||
|
auto storeVar = user->getOperand(0);
|
||||||
|
if (DI.isDivergent(*storeVar)) {
|
||||||
|
allStoreNonDivergence = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (allStoreNonDivergence) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
// Do not duplicate var used outside PRs
|
// Do not duplicate var used outside PRs
|
||||||
bool used_in_non_PR = false;
|
bool used_in_non_PR = false;
|
||||||
for (Instruction::use_iterator ui = alloc->use_begin(),
|
for (Instruction::use_iterator ui = alloc->use_begin(),
|
||||||
|
@ -595,8 +619,6 @@ class InsertWarpLoopPass : public llvm::FunctionPass {
|
||||||
public:
|
public:
|
||||||
static char ID;
|
static char ID;
|
||||||
bool intra_warp_loop;
|
bool intra_warp_loop;
|
||||||
DominatorTree *DT;
|
|
||||||
PostDominatorTree *PDT;
|
|
||||||
|
|
||||||
InsertWarpLoopPass(bool intra_warp = 0)
|
InsertWarpLoopPass(bool intra_warp = 0)
|
||||||
: FunctionPass(ID), intra_warp_loop(intra_warp) {}
|
: FunctionPass(ID), intra_warp_loop(intra_warp) {}
|
||||||
|
@ -604,6 +626,8 @@ public:
|
||||||
virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const {
|
virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const {
|
||||||
AU.addRequired<DominatorTreeWrapperPass>();
|
AU.addRequired<DominatorTreeWrapperPass>();
|
||||||
AU.addRequired<PostDominatorTreeWrapperPass>();
|
AU.addRequired<PostDominatorTreeWrapperPass>();
|
||||||
|
AU.addRequired<LoopInfoWrapperPass>();
|
||||||
|
AU.addRequired<TargetTransformInfoWrapperPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void getParallelRegionBefore(llvm::BasicBlock *B, bool intra_warp_loop,
|
void getParallelRegionBefore(llvm::BasicBlock *B, bool intra_warp_loop,
|
||||||
|
@ -789,8 +813,22 @@ public:
|
||||||
tempInstructionIds.clear();
|
tempInstructionIds.clear();
|
||||||
tempInstructionIndex = 0;
|
tempInstructionIndex = 0;
|
||||||
|
|
||||||
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
|
// get DivergenceInfo
|
||||||
PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
|
auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
|
||||||
|
auto PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
|
||||||
|
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
|
||||||
|
llvm::Triple triple("nvptx64-nvidia-cuda");
|
||||||
|
std::string Error;
|
||||||
|
const Target *TheTarget = TargetRegistry::lookupTarget("", triple, Error);
|
||||||
|
llvm::TargetOptions Options;
|
||||||
|
llvm::TargetMachine *target_machine = TheTarget->createTargetMachine(
|
||||||
|
triple.getTriple(), "sm_35", "+ptx50", Options, llvm::Reloc::Static,
|
||||||
|
llvm::CodeModel::Small, llvm::CodeGenOpt::Aggressive);
|
||||||
|
|
||||||
|
llvm::FunctionAnalysisManager DummyFAM;
|
||||||
|
llvm::TargetTransformInfo TTI =
|
||||||
|
target_machine->getTargetIRAnalysis().run(F, DummyFAM);
|
||||||
|
DivergenceInfo DI(F, *DT, *PDT, LI, TTI, /*KnownReducible*/ true);
|
||||||
|
|
||||||
// find parallel region we need to wrap
|
// find parallel region we need to wrap
|
||||||
auto parallel_regions = getParallelRegions(&F, intra_warp_loop);
|
auto parallel_regions = getParallelRegions(&F, intra_warp_loop);
|
||||||
|
@ -800,7 +838,7 @@ public:
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (intra_warp_loop) {
|
if (intra_warp_loop) {
|
||||||
handle_local_variable_intra_warp(parallel_regions);
|
handle_local_variable_intra_warp(parallel_regions, DI);
|
||||||
}
|
}
|
||||||
add_warp_loop(parallel_regions, intra_warp_loop);
|
add_warp_loop(parallel_regions, intra_warp_loop);
|
||||||
remove_barrier(&F, intra_warp_loop);
|
remove_barrier(&F, intra_warp_loop);
|
||||||
|
|
Loading…
Reference in New Issue