| //===- UniformityAnalysis.cpp ---------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Analysis/UniformityAnalysis.h" |
| #include "llvm/ADT/GenericUniformityImpl.h" |
| #include "llvm/ADT/SmallBitVector.h" |
| #include "llvm/Analysis/CycleAnalysis.h" |
| #include "llvm/Analysis/TargetTransformInfo.h" |
| #include "llvm/IR/Dominators.h" |
| #include "llvm/IR/InstIterator.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/InitializePasses.h" |
| |
| using namespace llvm; |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs( |
| const Instruction &I) const { |
| return isDivergent((const Value *)&I); |
| } |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent( |
| const Instruction &Instr) { |
| return markDivergent(cast<Value>(&Instr)); |
| } |
| |
| template <> |
| void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( |
| const Value *V) { |
| for (const auto *User : V->users()) { |
| if (const auto *UserInstr = dyn_cast<const Instruction>(User)) { |
| markDivergent(*UserInstr); |
| } |
| } |
| } |
| |
| template <> |
| void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( |
| const Instruction &Instr) { |
| assert(!isAlwaysUniform(Instr)); |
| if (Instr.isTerminator()) |
| return; |
| pushUsers(cast<Value>(&Instr)); |
| } |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::printDivergentArgs( |
| raw_ostream &OS) const { |
| bool HaveDivergentArgs = false; |
| for (const auto &Arg : F.args()) { |
| if (isDivergent(&Arg)) { |
| if (!HaveDivergentArgs) { |
| OS << "DIVERGENT ARGUMENTS:\n"; |
| HaveDivergentArgs = true; |
| } |
| OS << " DIVERGENT: " << Context.print(&Arg) << '\n'; |
| } |
| } |
| return HaveDivergentArgs; |
| } |
| |
| template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() { |
| // Pre-populate UniformValues with uniform values, then seed divergence. |
| // NeverUniform values are not inserted -- they are divergent by definition |
| // and will be reported as such by isDivergent() (not in UniformValues). |
| SmallVector<const Value *, 4> DivergentArgs; |
| for (auto &Arg : F.args()) { |
| if (TTI->getValueUniformity(&Arg) == ValueUniformity::NeverUniform) |
| DivergentArgs.push_back(&Arg); |
| else |
| UniformValues.insert(&Arg); |
| } |
| for (auto &I : instructions(F)) { |
| ValueUniformity IU = TTI->getValueUniformity(&I); |
| switch (IU) { |
| case ValueUniformity::AlwaysUniform: |
| UniformValues.insert(&I); |
| addUniformOverride(I); |
| continue; |
| case ValueUniformity::NeverUniform: |
| // Skip inserting -- divergent by definition. Add to Worklist directly |
| // so compute() propagates divergence to users. |
| if (I.isTerminator()) |
| DivergentTermBlocks.insert(I.getParent()); |
| Worklist.push_back(&I); |
| continue; |
| case ValueUniformity::Custom: |
| UniformValues.insert(&I); |
| addCustomUniformityCandidate(&I); |
| continue; |
| case ValueUniformity::Default: |
| UniformValues.insert(&I); |
| break; |
| } |
| } |
| // Arguments are not instructions and cannot go on the Worklist, so we |
| // propagate their divergence to users explicitly here. This must happen |
| // after all instructions are in UniformValues so markDivergent (called |
| // inside pushUsers) can successfully erase user instructions from the set. |
| for (const Value *Arg : DivergentArgs) |
| pushUsers(Arg); |
| } |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( |
| const Instruction &I, const Cycle &DefCycle) const { |
| assert(!isAlwaysUniform(I)); |
| for (const Use &U : I.operands()) { |
| if (auto *I = dyn_cast<Instruction>(&U)) { |
| if (DefCycle.contains(I->getParent())) |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| template <> |
| void llvm::GenericUniformityAnalysisImpl< |
| SSAContext>::propagateTemporalDivergence(const Instruction &I, |
| const Cycle &DefCycle) { |
| for (auto *User : I.users()) { |
| auto *UserInstr = cast<Instruction>(User); |
| if (DefCycle.contains(UserInstr->getParent())) |
| continue; |
| markDivergent(*UserInstr); |
| recordTemporalDivergence(&I, UserInstr, &DefCycle); |
| } |
| } |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse( |
| const Use &U) const { |
| const auto *V = U.get(); |
| if (isDivergent(V)) |
| return true; |
| if (const auto *DefInstr = dyn_cast<Instruction>(V)) { |
| const auto *UseInstr = cast<Instruction>(U.getUser()); |
| return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); |
| } |
| return false; |
| } |
| |
| template <> |
| bool GenericUniformityAnalysisImpl<SSAContext>::isCustomUniform( |
| const Instruction &I) const { |
| SmallBitVector UniformArgs(I.getNumOperands()); |
| for (auto [Idx, Use] : enumerate(I.operands())) |
| UniformArgs[Idx] = !isDivergentUse(Use); |
| return TTI->isUniform(&I, UniformArgs); |
| } |
| |
| // Explicit instantiations for SSAContext. Instantiating the deleter ensures |
| // that GenericUniformityAnalysisImplDeleter::operator() is defined here. |
| template class llvm::GenericUniformityInfo<SSAContext>; |
| template struct llvm::GenericUniformityAnalysisImplDeleter< |
| llvm::GenericUniformityAnalysisImpl<SSAContext>>; |
| |
| //===----------------------------------------------------------------------===// |
| // UniformityInfoAnalysis and related pass implementations |
| //===----------------------------------------------------------------------===// |
| |
| llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, |
| FunctionAnalysisManager &FAM) { |
| TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); |
| if (!TTI.hasBranchDivergence(&F)) |
| return UniformityInfo{}; |
| DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); |
| CycleInfo &CI = FAM.getResult<CycleAnalysis>(F); |
| UniformityInfo UI{DT, CI, &TTI}; |
| UI.compute(); |
| return UI; |
| } |
| |
| AnalysisKey UniformityInfoAnalysis::Key; |
| |
| UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS) |
| : OS(OS) {} |
| |
| PreservedAnalyses UniformityInfoPrinterPass::run(Function &F, |
| FunctionAnalysisManager &AM) { |
| OS << "UniformityInfo for function '" << F.getName() << "':\n"; |
| AM.getResult<UniformityInfoAnalysis>(F).print(OS); |
| |
| return PreservedAnalyses::all(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UniformityInfoWrapperPass Implementation |
| //===----------------------------------------------------------------------===// |
| |
| char UniformityInfoWrapperPass::ID = 0; |
| |
| UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {} |
| |
| INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity", |
| "Uniformity Analysis", false, true) |
| INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
| INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass) |
| INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
| INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", |
| "Uniformity Analysis", false, true) |
| |
| void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { |
| AU.setPreservesAll(); |
| AU.addRequired<DominatorTreeWrapperPass>(); |
| AU.addRequiredTransitive<CycleInfoWrapperPass>(); |
| AU.addRequired<TargetTransformInfoWrapperPass>(); |
| } |
| |
| bool UniformityInfoWrapperPass::runOnFunction(Function &F) { |
| TargetTransformInfo &TTI = |
| getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
| |
| Fn = &F; |
| |
| if (!TTI.hasBranchDivergence(Fn)) { |
| UI = UniformityInfo{}; |
| return false; |
| } |
| |
| CycleInfo &CI = getAnalysis<CycleInfoWrapperPass>().getResult(); |
| DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
| UI = UniformityInfo{DT, CI, &TTI}; |
| UI.compute(); |
| return false; |
| } |
| |
| void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const { |
| OS << "UniformityInfo for function '" << Fn->getName() << "':\n"; |
| UI.print(OS); |
| } |
| |
| void UniformityInfoWrapperPass::releaseMemory() { |
| UI = UniformityInfo{}; |
| Fn = nullptr; |
| } |