| //===-------- SYCLSplitModule.cpp - Split a module into call graphs -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // See comments in the header. |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Transforms/Utils/SYCLSplitModule.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallPtrSet.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/Bitcode/BitcodeWriterPass.h" |
| #include "llvm/IR/Constants.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/InstIterator.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/IR/PassManager.h" |
| #include "llvm/IR/PassManagerImpl.h" |
| #include "llvm/IRPrinter/IRPrintingPasses.h" |
| #include "llvm/Support/Compiler.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/Error.h" |
| #include "llvm/Support/FileSystem.h" |
| #include "llvm/Support/LineIterator.h" |
| #include "llvm/Support/MemoryBuffer.h" |
| #include "llvm/Transforms/IPO/GlobalDCE.h" |
| #include "llvm/Transforms/IPO/StripDeadPrototypes.h" |
| #include "llvm/Transforms/IPO/StripSymbols.h" |
| #include "llvm/Transforms/Utils/Cloning.h" |
| #include "llvm/Transforms/Utils/SYCLUtils.h" |
| |
| #include <map> |
| #include <utility> |
| |
| using namespace llvm; |
| |
| #define DEBUG_TYPE "sycl-split-module" |
| |
| static bool isKernel(const Function &F) { |
| return F.getCallingConv() == CallingConv::SPIR_KERNEL || |
| F.getCallingConv() == CallingConv::AMDGPU_KERNEL; |
| } |
| |
| static bool isEntryPoint(const Function &F) { |
| // Skip declarations, if any: they should not be included into a vector of |
| // entry points groups or otherwise we will end up with incorrectly generated |
| // list of symbols. |
| if (F.isDeclaration()) |
| return false; |
| |
| // Kernels are always considered to be entry points |
| return isKernel(F); |
| } |
| |
| namespace { |
| |
| // A vector that contains all entry point functions in a split module. |
| using EntryPointSet = SetVector<const Function *>; |
| |
| /// Represents a named group entry points. |
| struct EntryPointGroup { |
| std::string GroupName; |
| EntryPointSet Functions; |
| |
| EntryPointGroup() = default; |
| EntryPointGroup(const EntryPointGroup &) = default; |
| EntryPointGroup &operator=(const EntryPointGroup &) = default; |
| EntryPointGroup(EntryPointGroup &&) = default; |
| EntryPointGroup &operator=(EntryPointGroup &&) = default; |
| |
| EntryPointGroup(StringRef GroupName, |
| EntryPointSet Functions = EntryPointSet()) |
| : GroupName(GroupName), Functions(std::move(Functions)) {} |
| |
| #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| LLVM_DUMP_METHOD void dump() const { |
| constexpr size_t INDENT = 4; |
| dbgs().indent(INDENT) << "ENTRY POINTS" |
| << " " << GroupName << " {\n"; |
| for (const Function *F : Functions) |
| dbgs().indent(INDENT) << " " << F->getName() << "\n"; |
| |
| dbgs().indent(INDENT) << "}\n"; |
| } |
| #endif |
| }; |
| |
| /// Annotates an llvm::Module with information necessary to perform and track |
| /// the result of device code (llvm::Module instances) splitting: |
| /// - entry points group from the module. |
| class ModuleDesc { |
| std::unique_ptr<Module> M; |
| EntryPointGroup EntryPoints; |
| |
| public: |
| ModuleDesc() = delete; |
| ModuleDesc(const ModuleDesc &) = delete; |
| ModuleDesc &operator=(const ModuleDesc &) = delete; |
| ModuleDesc(ModuleDesc &&) = default; |
| ModuleDesc &operator=(ModuleDesc &&) = default; |
| |
| ModuleDesc(std::unique_ptr<Module> M, |
| EntryPointGroup EntryPoints = EntryPointGroup()) |
| : M(std::move(M)), EntryPoints(std::move(EntryPoints)) { |
| assert(this->M && "Module should be non-null"); |
| } |
| |
| const EntryPointSet &entries() const { return EntryPoints.Functions; } |
| const EntryPointGroup &getEntryPointGroup() const { return EntryPoints; } |
| EntryPointSet &entries() { return EntryPoints.Functions; } |
| Module &getModule() { return *M; } |
| const Module &getModule() const { return *M; } |
| |
| // Cleans up module IR - removes dead globals, debug info etc. |
| void cleanup() { |
| ModuleAnalysisManager MAM; |
| MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); |
| ModulePassManager MPM; |
| MPM.addPass(GlobalDCEPass()); // Delete unreachable globals. |
| MPM.addPass(StripDeadDebugInfoPass()); // Remove dead debug info. |
| MPM.addPass(StripDeadPrototypesPass()); // Remove dead func decls. |
| MPM.run(*M, MAM); |
| } |
| |
| std::string makeSymbolTable() const { |
| SmallString<128> ST; |
| for (const Function *F : EntryPoints.Functions) { |
| ST += F->getName(); |
| ST += "\n"; |
| } |
| |
| return std::string(ST); |
| } |
| |
| #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| LLVM_DUMP_METHOD void dump() const { |
| dbgs() << "ModuleDesc[" << M->getName() << "] {\n"; |
| EntryPoints.dump(); |
| dbgs() << "}\n"; |
| } |
| #endif |
| }; |
| |
| // Represents "dependency" or "use" graph of global objects (functions and |
| // global variables) in a module. It is used during device code split to |
| // understand which global variables and functions (other than entry points) |
| // should be included into a split module. |
| // |
| // Nodes of the graph represent LLVM's GlobalObjects, edges "A" -> "B" represent |
| // the fact that if "A" is included into a module, then "B" should be included |
| // as well. |
| // |
| // Examples of dependencies which are represented in this graph: |
| // - Function FA calls function FB |
| // - Function FA uses global variable GA |
| // - Global variable GA references (initialized with) function FB |
| // - Function FA stores address of a function FB somewhere |
| // |
| // The following cases are treated as dependencies between global objects: |
| // 1. Global object A is used within by a global object B in any way (store, |
| // bitcast, phi node, call, etc.): "A" -> "B" edge will be added to the |
| // graph; |
| // 2. function A performs an indirect call of a function with signature S and |
| // there is a function B with signature S. "A" -> "B" edge will be added to |
| // the graph; |
| class DependencyGraph { |
| public: |
| using GlobalSet = SmallPtrSet<const GlobalValue *, 16>; |
| |
| DependencyGraph(const Module &M) { |
| // Group functions by their signature to handle case (2) described above |
| DenseMap<const FunctionType *, DependencyGraph::GlobalSet> |
| FuncTypeToFuncsMap; |
| for (const auto &F : M.functions()) { |
| // Kernels can't be called (either directly or indirectly) in SYCL |
| if (isKernel(F)) |
| continue; |
| |
| FuncTypeToFuncsMap[F.getFunctionType()].insert(&F); |
| } |
| |
| for (const auto &F : M.functions()) { |
| // case (1), see comment above the class definition |
| for (const Value *U : F.users()) |
| addUserToGraphRecursively(cast<const User>(U), &F); |
| |
| // case (2), see comment above the class definition |
| for (const auto &I : instructions(F)) { |
| const auto *CI = dyn_cast<CallInst>(&I); |
| if (!CI || !CI->isIndirectCall()) // Direct calls were handled above |
| continue; |
| |
| const FunctionType *Signature = CI->getFunctionType(); |
| const auto &PotentialCallees = FuncTypeToFuncsMap[Signature]; |
| Graph[&F].insert(PotentialCallees.begin(), PotentialCallees.end()); |
| } |
| } |
| |
| // And every global variable (but their handling is a bit simpler) |
| for (const auto &GV : M.globals()) |
| for (const Value *U : GV.users()) |
| addUserToGraphRecursively(cast<const User>(U), &GV); |
| } |
| |
| iterator_range<GlobalSet::const_iterator> |
| dependencies(const GlobalValue *Val) const { |
| auto It = Graph.find(Val); |
| return (It == Graph.end()) |
| ? make_range(EmptySet.begin(), EmptySet.end()) |
| : make_range(It->second.begin(), It->second.end()); |
| } |
| |
| private: |
| void addUserToGraphRecursively(const User *Root, const GlobalValue *V) { |
| SmallVector<const User *, 8> WorkList; |
| WorkList.push_back(Root); |
| |
| while (!WorkList.empty()) { |
| const User *U = WorkList.pop_back_val(); |
| if (const auto *I = dyn_cast<const Instruction>(U)) { |
| const auto *UFunc = I->getFunction(); |
| Graph[UFunc].insert(V); |
| } else if (isa<const Constant>(U)) { |
| if (const auto *GV = dyn_cast<const GlobalVariable>(U)) |
| Graph[GV].insert(V); |
| // This could be a global variable or some constant expression (like |
| // bitcast or gep). We trace users of this constant further to reach |
| // global objects they are used by and add them to the graph. |
| for (const auto *UU : U->users()) |
| WorkList.push_back(UU); |
| } else |
| llvm_unreachable("Unhandled type of function user"); |
| } |
| } |
| |
| DenseMap<const GlobalValue *, GlobalSet> Graph; |
| SmallPtrSet<const GlobalValue *, 1> EmptySet; |
| }; |
| |
| void collectFunctionsAndGlobalVariablesToExtract( |
| SetVector<const GlobalValue *> &GVs, const Module &M, |
| const EntryPointGroup &ModuleEntryPoints, const DependencyGraph &DG) { |
| // We start with module entry points |
| for (const auto *F : ModuleEntryPoints.Functions) |
| GVs.insert(F); |
| |
| // Non-discardable global variables are also include into the initial set |
| for (const auto &GV : M.globals()) |
| if (!GV.isDiscardableIfUnused()) |
| GVs.insert(&GV); |
| |
| // GVs has SetVector type. This type inserts a value only if it is not yet |
| // present there. So, recursion is not expected here. |
| size_t Idx = 0; |
| while (Idx < GVs.size()) { |
| const GlobalValue *Obj = GVs[Idx++]; |
| |
| for (const GlobalValue *Dep : DG.dependencies(Obj)) { |
| if (const auto *Func = dyn_cast<const Function>(Dep)) { |
| if (!Func->isDeclaration()) |
| GVs.insert(Func); |
| } else |
| GVs.insert(Dep); // Global variables are added unconditionally |
| } |
| } |
| } |
| |
| ModuleDesc extractSubModule(const ModuleDesc &MD, |
| const SetVector<const GlobalValue *> &GVs, |
| EntryPointGroup ModuleEntryPoints) { |
| const Module &M = MD.getModule(); |
| // For each group of entry points collect all dependencies. |
| ValueToValueMapTy VMap; |
| // Clone definitions only for needed globals. Others will be added as |
| // declarations and removed later. |
| std::unique_ptr<Module> SubM = CloneModule( |
| M, VMap, [&](const GlobalValue *GV) { return GVs.count(GV); }); |
| // Replace entry points with cloned ones. |
| EntryPointSet NewEPs; |
| const EntryPointSet &EPs = ModuleEntryPoints.Functions; |
| std::for_each(EPs.begin(), EPs.end(), [&](const Function *F) { |
| NewEPs.insert(cast<Function>(VMap[F])); |
| }); |
| ModuleEntryPoints.Functions = std::move(NewEPs); |
| return ModuleDesc{std::move(SubM), std::move(ModuleEntryPoints)}; |
| } |
| |
| // The function produces a copy of input LLVM IR module M with only those |
| // functions and globals that can be called from entry points that are specified |
| // in ModuleEntryPoints vector, in addition to the entry point functions. |
| ModuleDesc extractCallGraph(const ModuleDesc &MD, |
| EntryPointGroup ModuleEntryPoints, |
| const DependencyGraph &DG) { |
| SetVector<const GlobalValue *> GVs; |
| collectFunctionsAndGlobalVariablesToExtract(GVs, MD.getModule(), |
| ModuleEntryPoints, DG); |
| |
| ModuleDesc SplitM = extractSubModule(MD, GVs, std::move(ModuleEntryPoints)); |
| LLVM_DEBUG(SplitM.dump()); |
| SplitM.cleanup(); |
| return SplitM; |
| } |
| |
| using EntryPointGroupVec = SmallVector<EntryPointGroup, 0>; |
| |
| /// Module Splitter. |
| /// It gets a module (in a form of module descriptor, to get additional info) |
| /// and a collection of entry points groups. Each group specifies subset entry |
| /// points from input module that should be included in a split module. |
| class ModuleSplitter { |
| private: |
| ModuleDesc Input; |
| EntryPointGroupVec Groups; |
| DependencyGraph DG; |
| |
| private: |
| EntryPointGroup drawEntryPointGroup() { |
| assert(Groups.size() > 0 && "Reached end of entry point groups list."); |
| EntryPointGroup Group = std::move(Groups.back()); |
| Groups.pop_back(); |
| return Group; |
| } |
| |
| public: |
| ModuleSplitter(ModuleDesc MD, EntryPointGroupVec GroupVec) |
| : Input(std::move(MD)), Groups(std::move(GroupVec)), |
| DG(Input.getModule()) { |
| assert(!Groups.empty() && "Entry points groups collection is empty!"); |
| } |
| |
| /// Gets next subsequence of entry points in an input module and provides |
| /// split submodule containing these entry points and their dependencies. |
| ModuleDesc getNextSplit() { |
| return extractCallGraph(Input, drawEntryPointGroup(), DG); |
| } |
| |
| /// Check that there are still submodules to split. |
| bool hasMoreSplits() const { return Groups.size() > 0; } |
| }; |
| |
| } // namespace |
| |
| static EntryPointGroupVec selectEntryPointGroups(const Module &M, |
| IRSplitMode Mode) { |
| // std::map is used here to ensure stable ordering of entry point groups, |
| // which is based on their contents, this greatly helps LIT tests |
| std::map<std::string, EntryPointSet> EntryPointsMap; |
| |
| static constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id"; |
| for (const auto &F : M.functions()) { |
| if (!isEntryPoint(F)) |
| continue; |
| |
| std::string Key; |
| switch (Mode) { |
| case IRSplitMode::IRSM_PER_KERNEL: |
| Key = F.getName(); |
| break; |
| case IRSplitMode::IRSM_PER_TU: |
| Key = F.getFnAttribute(ATTR_SYCL_MODULE_ID).getValueAsString(); |
| break; |
| case IRSplitMode::IRSM_NONE: |
| llvm_unreachable(""); |
| } |
| |
| EntryPointsMap[Key].insert(&F); |
| } |
| |
| EntryPointGroupVec Groups; |
| if (EntryPointsMap.empty()) { |
| // No entry points met, record this. |
| Groups.emplace_back("-", EntryPointSet()); |
| } else { |
| Groups.reserve(EntryPointsMap.size()); |
| // Start with properties of a source module |
| for (auto &[Key, EntryPoints] : EntryPointsMap) |
| Groups.emplace_back(Key, std::move(EntryPoints)); |
| } |
| |
| return Groups; |
| } |
| |
| static Error saveModuleIRInFile(Module &M, StringRef FilePath, |
| bool OutputAssembly) { |
| int FD = -1; |
| if (std::error_code EC = sys::fs::openFileForWrite(FilePath, FD)) |
| return errorCodeToError(EC); |
| |
| raw_fd_ostream OS(FD, true); |
| ModulePassManager MPM; |
| ModuleAnalysisManager MAM; |
| MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); |
| if (OutputAssembly) |
| MPM.addPass(PrintModulePass(OS)); |
| else |
| MPM.addPass(BitcodeWriterPass(OS)); |
| |
| MPM.run(M, MAM); |
| return Error::success(); |
| } |
| |
| static Expected<ModuleAndSYCLMetadata> |
| saveModuleDesc(ModuleDesc &MD, std::string Prefix, bool OutputAssembly) { |
| Prefix += OutputAssembly ? ".ll" : ".bc"; |
| if (Error E = saveModuleIRInFile(MD.getModule(), Prefix, OutputAssembly)) |
| return E; |
| |
| ModuleAndSYCLMetadata SM; |
| SM.ModuleFilePath = Prefix; |
| SM.Symbols = MD.makeSymbolTable(); |
| return SM; |
| } |
| |
| namespace llvm { |
| |
| Expected<SmallVector<ModuleAndSYCLMetadata, 0>> |
| parseModuleAndSYCLMetadataFromFile(StringRef File) { |
| auto EntriesMBOrErr = llvm::MemoryBuffer::getFile(File); |
| if (!EntriesMBOrErr) |
| return createFileError(File, EntriesMBOrErr.getError()); |
| |
| line_iterator LI(**EntriesMBOrErr); |
| if (LI.is_at_eof() || *LI != "[Code|Symbols]") |
| return createStringError(inconvertibleErrorCode(), |
| "invalid SYCL Table file."); |
| |
| // "Code" and "Symbols" at the moment. |
| static constexpr int NUMBER_COLUMNS = 2; |
| ++LI; |
| SmallVector<ModuleAndSYCLMetadata, 0> Modules; |
| while (!LI.is_at_eof()) { |
| StringRef Line = *LI; |
| if (Line.empty()) |
| return createStringError("invalid SYCL table row."); |
| |
| SmallVector<StringRef, NUMBER_COLUMNS> Parts; |
| Line.split(Parts, "|"); |
| if (Parts.size() != NUMBER_COLUMNS) |
| return createStringError("invalid SYCL Table row."); |
| |
| auto [IRFilePath, SymbolsFilePath] = std::tie(Parts[0], Parts[1]); |
| if (SymbolsFilePath.empty()) |
| return createStringError("invalid SYCL Table row."); |
| |
| auto MBOrErr = MemoryBuffer::getFile(SymbolsFilePath); |
| if (!MBOrErr) |
| return createFileError(SymbolsFilePath, MBOrErr.getError()); |
| |
| auto &MB2 = *MBOrErr; |
| std::string Symbols = |
| std::string(MB2->getBufferStart(), MB2->getBufferEnd()); |
| Modules.emplace_back(IRFilePath, std::move(Symbols)); |
| ++LI; |
| } |
| |
| return Modules; |
| } |
| |
| std::optional<IRSplitMode> convertStringToSplitMode(StringRef S) { |
| static const StringMap<IRSplitMode> Values = { |
| {"source", IRSplitMode::IRSM_PER_TU}, |
| {"kernel", IRSplitMode::IRSM_PER_KERNEL}, |
| {"none", IRSplitMode::IRSM_NONE}}; |
| |
| auto It = Values.find(S); |
| if (It == Values.end()) |
| return std::nullopt; |
| |
| return It->second; |
| } |
| |
| Expected<SmallVector<ModuleAndSYCLMetadata, 0>> |
| SYCLSplitModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings) { |
| SmallVector<ModuleAndSYCLMetadata, 0> OutputImages; |
| if (Settings.Mode == IRSplitMode::IRSM_NONE) { |
| ModuleDesc MD = std::move(M); |
| std::string OutIRFileName = (Settings.OutputPrefix + Twine("_0")).str(); |
| auto ImageOrErr = |
| saveModuleDesc(MD, OutIRFileName, Settings.OutputAssembly); |
| if (!ImageOrErr) |
| return ImageOrErr.takeError(); |
| |
| OutputImages.emplace_back(std::move(*ImageOrErr)); |
| return OutputImages; |
| } |
| |
| EntryPointGroupVec Groups = selectEntryPointGroups(*M, Settings.Mode); |
| ModuleDesc MD = std::move(M); |
| ModuleSplitter Splitter(std::move(MD), std::move(Groups)); |
| size_t ID = 0; |
| while (Splitter.hasMoreSplits()) { |
| ModuleDesc MD = Splitter.getNextSplit(); |
| |
| std::string OutIRFileName = (Settings.OutputPrefix + "_" + Twine(ID)).str(); |
| auto SplitImageOrErr = |
| saveModuleDesc(MD, OutIRFileName, Settings.OutputAssembly); |
| if (!SplitImageOrErr) |
| return SplitImageOrErr.takeError(); |
| |
| OutputImages.emplace_back(std::move(*SplitImageOrErr)); |
| ++ID; |
| } |
| |
| return OutputImages; |
| } |
| |
| } // namespace llvm |