|  | //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===// | 
|  | // | 
|  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | #include "llvm/ADT/Triple.h" | 
|  | #include "llvm/Analysis/CGSCCPassManager.h" | 
|  | #include "llvm/Analysis/LoopAnalysisManager.h" | 
|  | #include "llvm/AsmParser/Parser.h" | 
|  | #include "llvm/CodeGen/MachineModuleInfo.h" | 
|  | #include "llvm/CodeGen/MachinePassManager.h" | 
|  | #include "llvm/IR/LLVMContext.h" | 
|  | #include "llvm/IR/Module.h" | 
|  | #include "llvm/Passes/PassBuilder.h" | 
|  | #include "llvm/Support/Host.h" | 
|  | #include "llvm/Support/SourceMgr.h" | 
|  | #include "llvm/Support/TargetRegistry.h" | 
|  | #include "llvm/Support/TargetSelect.h" | 
|  | #include "llvm/Target/TargetMachine.h" | 
|  | #include "gtest/gtest.h" | 
|  |  | 
|  | using namespace llvm; | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> { | 
|  | public: | 
|  | struct Result { | 
|  | Result(int Count) : InstructionCount(Count) {} | 
|  | int InstructionCount; | 
|  | }; | 
|  |  | 
|  | /// Run the analysis pass over the function and return a result. | 
|  | Result run(Function &F, FunctionAnalysisManager &AM) { | 
|  | int Count = 0; | 
|  | for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) | 
|  | for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE; | 
|  | ++II) | 
|  | ++Count; | 
|  | return Result(Count); | 
|  | } | 
|  |  | 
|  | private: | 
|  | friend AnalysisInfoMixin<TestFunctionAnalysis>; | 
|  | static AnalysisKey Key; | 
|  | }; | 
|  |  | 
|  | AnalysisKey TestFunctionAnalysis::Key; | 
|  |  | 
|  | class TestMachineFunctionAnalysis | 
|  | : public AnalysisInfoMixin<TestMachineFunctionAnalysis> { | 
|  | public: | 
|  | struct Result { | 
|  | Result(int Count) : InstructionCount(Count) {} | 
|  | int InstructionCount; | 
|  | }; | 
|  |  | 
|  | /// Run the analysis pass over the machine function and return a result. | 
|  | Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) { | 
|  | auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM); | 
|  | // Query function analysis result. | 
|  | TestFunctionAnalysis::Result &FAR = | 
|  | MFAM.getResult<TestFunctionAnalysis>(MF.getFunction()); | 
|  | // + 5 | 
|  | return FAR.InstructionCount; | 
|  | } | 
|  |  | 
|  | private: | 
|  | friend AnalysisInfoMixin<TestMachineFunctionAnalysis>; | 
|  | static AnalysisKey Key; | 
|  | }; | 
|  |  | 
|  | AnalysisKey TestMachineFunctionAnalysis::Key; | 
|  |  | 
// Messages carried by the StringErrors TestMachineFunctionPass returns when it
// is forced to fail in doInitialization / doFinalization; the test compares
// the propagated error text against these.
const std::string DoInitErrMsg{"doInitialization failed"};
const std::string DoFinalErrMsg{"doFinalization failed"};
|  |  | 
|  | struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> { | 
|  | TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization, | 
|  | std::vector<int> &BeforeFinalization, | 
|  | std::vector<int> &MachineFunctionPassCount) | 
|  | : Count(Count), BeforeInitialization(BeforeInitialization), | 
|  | BeforeFinalization(BeforeFinalization), | 
|  | MachineFunctionPassCount(MachineFunctionPassCount) {} | 
|  |  | 
|  | Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) { | 
|  | // Force doInitialization fail by starting with big `Count`. | 
|  | if (Count > 10000) | 
|  | return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode()); | 
|  |  | 
|  | // + 1 | 
|  | ++Count; | 
|  | BeforeInitialization.push_back(Count); | 
|  | return Error::success(); | 
|  | } | 
|  | Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) { | 
|  | // Force doFinalization fail by starting with big `Count`. | 
|  | if (Count > 1000) | 
|  | return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode()); | 
|  |  | 
|  | // + 1 | 
|  | ++Count; | 
|  | BeforeFinalization.push_back(Count); | 
|  | return Error::success(); | 
|  | } | 
|  |  | 
|  | PreservedAnalyses run(MachineFunction &MF, | 
|  | MachineFunctionAnalysisManager &MFAM) { | 
|  | // Query function analysis result. | 
|  | TestFunctionAnalysis::Result &FAR = | 
|  | MFAM.getResult<TestFunctionAnalysis>(MF.getFunction()); | 
|  | // 3 + 1 + 1 = 5 | 
|  | Count += FAR.InstructionCount; | 
|  |  | 
|  | // Query module analysis result. | 
|  | MachineModuleInfo &MMI = | 
|  | MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent()); | 
|  | // 1 + 1 + 1 = 3 | 
|  | Count += (MMI.getModule() == MF.getFunction().getParent()); | 
|  |  | 
|  | // Query machine function analysis result. | 
|  | TestMachineFunctionAnalysis::Result &MFAR = | 
|  | MFAM.getResult<TestMachineFunctionAnalysis>(MF); | 
|  | // 3 + 1 + 1 = 5 | 
|  | Count += MFAR.InstructionCount; | 
|  |  | 
|  | MachineFunctionPassCount.push_back(Count); | 
|  |  | 
|  | return PreservedAnalyses::none(); | 
|  | } | 
|  |  | 
|  | int &Count; | 
|  | std::vector<int> &BeforeInitialization; | 
|  | std::vector<int> &BeforeFinalization; | 
|  | std::vector<int> &MachineFunctionPassCount; | 
|  | }; | 
|  |  | 
|  | struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> { | 
|  | TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount) | 
|  | : Count(Count), MachineModulePassCount(MachineModulePassCount) {} | 
|  |  | 
|  | Error run(Module &M, MachineFunctionAnalysisManager &MFAM) { | 
|  | MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M); | 
|  | // + 1 | 
|  | Count += (MMI.getModule() == &M); | 
|  | MachineModulePassCount.push_back(Count); | 
|  | return Error::success(); | 
|  | } | 
|  |  | 
|  | PreservedAnalyses run(MachineFunction &MF, | 
|  | MachineFunctionAnalysisManager &AM) { | 
|  | llvm_unreachable( | 
|  | "This should never be reached because this is machine module pass"); | 
|  | } | 
|  |  | 
|  | int &Count; | 
|  | std::vector<int> &MachineModulePassCount; | 
|  | }; | 
|  |  | 
|  | std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) { | 
|  | SMDiagnostic Err; | 
|  | return parseAssemblyString(IR, Err, Context); | 
|  | } | 
|  |  | 
|  | class PassManagerTest : public ::testing::Test { | 
|  | protected: | 
|  | LLVMContext Context; | 
|  | std::unique_ptr<Module> M; | 
|  | std::unique_ptr<TargetMachine> TM; | 
|  |  | 
|  | public: | 
|  | PassManagerTest() | 
|  | : M(parseIR(Context, "define void @f() {\n" | 
|  | "entry:\n" | 
|  | "  call void @g()\n" | 
|  | "  call void @h()\n" | 
|  | "  ret void\n" | 
|  | "}\n" | 
|  | "define void @g() {\n" | 
|  | "  ret void\n" | 
|  | "}\n" | 
|  | "define void @h() {\n" | 
|  | "  ret void\n" | 
|  | "}\n")) { | 
|  | // MachineModuleAnalysis needs a TargetMachine instance. | 
|  | llvm::InitializeAllTargets(); | 
|  |  | 
|  | std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple()); | 
|  | std::string Error; | 
|  | const Target *TheTarget = | 
|  | TargetRegistry::lookupTarget(TripleName, Error); | 
|  | if (!TheTarget) | 
|  | return; | 
|  |  | 
|  | TargetOptions Options; | 
|  | TM.reset(TheTarget->createTargetMachine(TripleName, "", "", | 
|  | Options, None)); | 
|  | } | 
|  | }; | 
|  |  | 
|  | TEST_F(PassManagerTest, Basic) { | 
|  | if (!TM) | 
|  | return; | 
|  |  | 
|  | LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get()); | 
|  | M->setDataLayout(TM->createDataLayout()); | 
|  |  | 
|  | LoopAnalysisManager LAM(/*DebugLogging=*/true); | 
|  | FunctionAnalysisManager FAM(/*DebugLogging=*/true); | 
|  | CGSCCAnalysisManager CGAM(/*DebugLogging=*/true); | 
|  | ModuleAnalysisManager MAM(/*DebugLogging=*/true); | 
|  | PassBuilder PB(TM.get()); | 
|  | PB.registerModuleAnalyses(MAM); | 
|  | PB.registerFunctionAnalyses(FAM); | 
|  | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); | 
|  |  | 
|  | FAM.registerPass([&] { return TestFunctionAnalysis(); }); | 
|  | FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); | 
|  | MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); }); | 
|  | MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); | 
|  |  | 
|  | MachineFunctionAnalysisManager MFAM; | 
|  | { | 
|  | // Test move assignment. | 
|  | MachineFunctionAnalysisManager NestedMFAM(FAM, MAM, | 
|  | /*DebugLogging*/ true); | 
|  | NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); }); | 
|  | NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); }); | 
|  | MFAM = std::move(NestedMFAM); | 
|  | } | 
|  |  | 
|  | int Count = 0; | 
|  | std::vector<int> BeforeInitialization[2]; | 
|  | std::vector<int> BeforeFinalization[2]; | 
|  | std::vector<int> TestMachineFunctionCount[2]; | 
|  | std::vector<int> TestMachineModuleCount[2]; | 
|  |  | 
|  | MachineFunctionPassManager MFPM; | 
|  | { | 
|  | // Test move assignment. | 
|  | MachineFunctionPassManager NestedMFPM(/*DebugLogging*/ true); | 
|  | NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0])); | 
|  | NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0], | 
|  | BeforeFinalization[0], | 
|  | TestMachineFunctionCount[0])); | 
|  | NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1])); | 
|  | NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], | 
|  | BeforeFinalization[1], | 
|  | TestMachineFunctionCount[1])); | 
|  | MFPM = std::move(NestedMFPM); | 
|  | } | 
|  |  | 
|  | ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM))); | 
|  |  | 
|  | // Check first machine module pass | 
|  | EXPECT_EQ(1u, TestMachineModuleCount[0].size()); | 
|  | EXPECT_EQ(3, TestMachineModuleCount[0][0]); | 
|  |  | 
|  | // Check first machine function pass | 
|  | EXPECT_EQ(1u, BeforeInitialization[0].size()); | 
|  | EXPECT_EQ(1, BeforeInitialization[0][0]); | 
|  | EXPECT_EQ(3u, TestMachineFunctionCount[0].size()); | 
|  | EXPECT_EQ(10, TestMachineFunctionCount[0][0]); | 
|  | EXPECT_EQ(13, TestMachineFunctionCount[0][1]); | 
|  | EXPECT_EQ(16, TestMachineFunctionCount[0][2]); | 
|  | EXPECT_EQ(1u, BeforeFinalization[0].size()); | 
|  | EXPECT_EQ(31, BeforeFinalization[0][0]); | 
|  |  | 
|  | // Check second machine module pass | 
|  | EXPECT_EQ(1u, TestMachineModuleCount[1].size()); | 
|  | EXPECT_EQ(17, TestMachineModuleCount[1][0]); | 
|  |  | 
|  | // Check second machine function pass | 
|  | EXPECT_EQ(1u, BeforeInitialization[1].size()); | 
|  | EXPECT_EQ(2, BeforeInitialization[1][0]); | 
|  | EXPECT_EQ(3u, TestMachineFunctionCount[1].size()); | 
|  | EXPECT_EQ(24, TestMachineFunctionCount[1][0]); | 
|  | EXPECT_EQ(27, TestMachineFunctionCount[1][1]); | 
|  | EXPECT_EQ(30, TestMachineFunctionCount[1][2]); | 
|  | EXPECT_EQ(1u, BeforeFinalization[1].size()); | 
|  | EXPECT_EQ(32, BeforeFinalization[1][0]); | 
|  |  | 
|  | EXPECT_EQ(32, Count); | 
|  |  | 
|  | // doInitialization returns error | 
|  | Count = 10000; | 
|  | MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], | 
|  | BeforeFinalization[1], | 
|  | TestMachineFunctionCount[1])); | 
|  | std::string Message; | 
|  | llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { | 
|  | Message = Error.getMessage(); | 
|  | }); | 
|  | EXPECT_EQ(Message, DoInitErrMsg); | 
|  |  | 
|  | // doFinalization returns error | 
|  | Count = 1000; | 
|  | MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], | 
|  | BeforeFinalization[1], | 
|  | TestMachineFunctionCount[1])); | 
|  | llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { | 
|  | Message = Error.getMessage(); | 
|  | }); | 
|  | EXPECT_EQ(Message, DoFinalErrMsg); | 
|  | } | 
|  |  | 
|  | } // namespace |