| //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// |
| // |
| // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This is a library that provides a shared implementation for command line |
| // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM |
| // IR before JIT-compiling and executing the latter. |
| // |
| // The translation can be customized by providing an MLIR to MLIR |
| // transformation. |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Support/JitRunner.h" |
| |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/ExecutionEngine/ExecutionEngine.h" |
| #include "mlir/ExecutionEngine/OptUtils.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/Parser.h" |
| #include "mlir/Support/FileUtilities.h" |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/LLVMContext.h" |
| #include "llvm/IR/LegacyPassNameParser.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FileUtilities.h" |
| #include "llvm/Support/InitLLVM.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include "llvm/Support/StringSaver.h" |
| #include "llvm/Support/TargetSelect.h" |
| #include "llvm/Support/ToolOutputFile.h" |
| #include <numeric> |
| |
| using namespace mlir; |
| using llvm::Error; |
| |
// Positional argument: the MLIR file to execute. The default "-" is passed
// unchanged to openInputFile() (presumably meaning stdin — see that helper).
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
                                                llvm::cl::desc("<input file>"),
                                                llvm::cl::init("-"));
// -e: name of the function looked up and invoked after JIT compilation.
static llvm::cl::opt<std::string>
    mainFuncName("e", llvm::cl::desc("The function to be called"),
                 llvm::cl::value_desc("<function name>"),
                 llvm::cl::init("main"));
// -entry-point-result: selects how the entry point is validated and invoked.
// JitRunnerMain recognizes "f32" (the default) and "void"; any other value is
// reported as an unsupported function type.
static llvm::cl::opt<std::string> mainFuncType(
    "entry-point-result",
    llvm::cl::desc("Textual description of the function type to be called"),
    llvm::cl::value_desc("f32 | void"), llvm::cl::init("f32"));
| |
// Option category grouping the opt-like flags declared below.
static llvm::cl::OptionCategory optFlags("opt-like flags");

// Explicit LLVM passes to run, one flag per pass. JitRunnerMain copies them in
// order and uses each occurrence's command-line position to decide where the
// -On pipeline is spliced in among them.
static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
    llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
               llvm::cl::cat(optFlags));

// -O0 .. -O3. getCommandLineOptLevel() checks them from -O0 upwards and returns
// the first one set; that level drives both the LLVM pass pipeline
// (makeLLVMPassesTransformer) and JIT code generation (compileAndExecute).
static llvm::cl::opt<bool>
    optO0("O0", llvm::cl::desc("Run opt passes and codegen at O0"),
          llvm::cl::cat(optFlags));
static llvm::cl::opt<bool>
    optO1("O1", llvm::cl::desc("Run opt passes and codegen at O1"),
          llvm::cl::cat(optFlags));
static llvm::cl::opt<bool>
    optO2("O2", llvm::cl::desc("Run opt passes and codegen at O2"),
          llvm::cl::cat(optFlags));
static llvm::cl::opt<bool>
    optO3("O3", llvm::cl::desc("Run opt passes and codegen at O3"),
          llvm::cl::cat(optFlags));
| |
// Option category for the linking-related flags.
static llvm::cl::OptionCategory clOptionsCategory("linking options");
// -shared-libs: libraries handed to ExecutionEngine::create by
// compileAndExecute. Values are comma-separated and the flag may be repeated
// (or omitted entirely).
static llvm::cl::list<std::string>
    clSharedLibs("shared-libs", llvm::cl::desc("Libraries to link dynamically"),
                 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
                 llvm::cl::cat(clOptionsCategory));
| |
// CLI variables for debugging.
// -dump-object-file: once the entry point has been looked up successfully,
// compileAndExecute writes the JIT-compiled object to disk before calling it.
static llvm::cl::opt<bool> dumpObjectFile(
    "dump-object-file",
    llvm::cl::desc("Dump JITted-compiled object to file specified with "
                   "-object-filename (<input file>.o by default)."));

// -object-filename: destination path for -dump-object-file. When left empty,
// compileAndExecute uses "<input file>.o" instead.
// NOTE(review): the help text below describes the default destination rather
// than what this option sets; consider rewording it.
static llvm::cl::opt<std::string> objectFilename(
    "object-filename",
    llvm::cl::desc("Dump JITted-compiled object to file <input file>.o"));
| |
| static OwningModuleRef parseMLIRInput(StringRef inputFilename, |
| MLIRContext *context) { |
| // Set up the input file. |
| std::string errorMessage; |
| auto file = openInputFile(inputFilename, &errorMessage); |
| if (!file) { |
| llvm::errs() << errorMessage << "\n"; |
| return nullptr; |
| } |
| |
| llvm::SourceMgr sourceMgr; |
| sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); |
| return OwningModuleRef(parseSourceFile(sourceMgr, context)); |
| } |
| |
| // Initialize the relevant subsystems of LLVM. |
| static void initializeLLVM() { |
| llvm::InitializeNativeTarget(); |
| llvm::InitializeNativeTargetAsmPrinter(); |
| } |
| |
| static inline Error make_string_error(const Twine &message) { |
| return llvm::make_error<llvm::StringError>(message.str(), |
| llvm::inconvertibleErrorCode()); |
| } |
| |
| static Optional<unsigned> getCommandLineOptLevel() { |
| Optional<unsigned> optLevel; |
| SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ |
| optO0, optO1, optO2, optO3}; |
| |
| // Determine if there is an optimization flag present. |
| for (unsigned j = 0; j < 4; ++j) { |
| auto &flag = optFlags[j].get(); |
| if (flag) { |
| optLevel = j; |
| break; |
| } |
| } |
| return optLevel; |
| } |
| |
| // JIT-compile the given module and run "entryPoint" with "args" as arguments. |
| static Error |
| compileAndExecute(ModuleOp module, StringRef entryPoint, |
| std::function<llvm::Error(llvm::Module *)> transformer, |
| void **args) { |
| Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; |
| if (auto clOptLevel = getCommandLineOptLevel()) |
| jitCodeGenOptLevel = |
| static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); |
| SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end()); |
| auto expectedEngine = mlir::ExecutionEngine::create(module, transformer, |
| jitCodeGenOptLevel, libs); |
| if (!expectedEngine) |
| return expectedEngine.takeError(); |
| |
| auto engine = std::move(*expectedEngine); |
| auto expectedFPtr = engine->lookup(entryPoint); |
| if (!expectedFPtr) |
| return expectedFPtr.takeError(); |
| |
| if (dumpObjectFile) |
| engine->dumpToObjectFile(objectFilename.empty() ? inputFilename + ".o" |
| : objectFilename); |
| |
| void (*fptr)(void **) = *expectedFPtr; |
| (*fptr)(args); |
| |
| return Error::success(); |
| } |
| |
| static Error compileAndExecuteVoidFunction( |
| ModuleOp module, StringRef entryPoint, |
| std::function<llvm::Error(llvm::Module *)> transformer) { |
| auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); |
| if (!mainFunction || mainFunction.getBlocks().empty()) |
| return make_string_error("entry point not found"); |
| void *empty = nullptr; |
| return compileAndExecute(module, entryPoint, transformer, &empty); |
| } |
| |
| static Error compileAndExecuteSingleFloatReturnFunction( |
| ModuleOp module, StringRef entryPoint, |
| std::function<llvm::Error(llvm::Module *)> transformer) { |
| auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); |
| if (!mainFunction || mainFunction.isExternal()) |
| return make_string_error("entry point not found"); |
| |
| if (mainFunction.getType().getFunctionNumParams() != 0) |
| return make_string_error("function inputs not supported"); |
| |
| if (!mainFunction.getType().getFunctionResultType().isFloatTy()) |
| return make_string_error("only single llvm.f32 function result supported"); |
| |
| float res; |
| struct { |
| void *data; |
| } data; |
| data.data = &res; |
| if (auto error = |
| compileAndExecute(module, entryPoint, transformer, (void **)&data)) |
| return error; |
| |
| // Intentional printing of the output so we can test. |
| llvm::outs() << res << '\n'; |
| |
| return Error::success(); |
| } |
| |
| // Entry point for all CPU runners. Expects the common argc/argv arguments for |
| // standard C++ main functions and an mlirTransformer. |
| // The latter is applied after parsing the input into MLIR IR and before passing |
| // the MLIR module to the ExecutionEngine. |
| int mlir::JitRunnerMain( |
| int argc, char **argv, |
| function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) { |
| llvm::InitLLVM y(argc, argv); |
| |
| initializeLLVM(); |
| mlir::initializeLLVMPasses(); |
| |
| llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); |
| |
| Optional<unsigned> optLevel = getCommandLineOptLevel(); |
| SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ |
| optO0, optO1, optO2, optO3}; |
| unsigned optCLIPosition = 0; |
| // Determine if there is an optimization flag present, and its CLI position |
| // (optCLIPosition). |
| for (unsigned j = 0; j < 4; ++j) { |
| auto &flag = optFlags[j].get(); |
| if (flag) { |
| optCLIPosition = flag.getPosition(); |
| break; |
| } |
| } |
| // Generate vector of pass information, plus the index at which we should |
| // insert any optimization passes in that vector (optPosition). |
| SmallVector<const llvm::PassInfo *, 4> passes; |
| unsigned optPosition = 0; |
| for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) { |
| passes.push_back(llvmPasses[i]); |
| if (optCLIPosition < llvmPasses.getPosition(i)) { |
| optPosition = i; |
| optCLIPosition = UINT_MAX; // To ensure we never insert again |
| } |
| } |
| |
| MLIRContext context; |
| auto m = parseMLIRInput(inputFilename, &context); |
| if (!m) { |
| llvm::errs() << "could not parse the input IR\n"; |
| return 1; |
| } |
| |
| if (mlirTransformer) |
| if (failed(mlirTransformer(m.get()))) |
| return EXIT_FAILURE; |
| |
| auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); |
| if (!tmBuilderOrError) { |
| llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; |
| return EXIT_FAILURE; |
| } |
| auto tmOrError = tmBuilderOrError->createTargetMachine(); |
| if (!tmOrError) { |
| llvm::errs() << "Failed to create a TargetMachine for the host\n"; |
| return EXIT_FAILURE; |
| } |
| |
| auto transformer = mlir::makeLLVMPassesTransformer( |
| passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); |
| |
| // Get the function used to compile and execute the module. |
| using CompileAndExecuteFnT = Error (*)( |
| ModuleOp, StringRef, std::function<llvm::Error(llvm::Module *)>); |
| auto compileAndExecuteFn = |
| llvm::StringSwitch<CompileAndExecuteFnT>(mainFuncType.getValue()) |
| .Case("f32", compileAndExecuteSingleFloatReturnFunction) |
| .Case("void", compileAndExecuteVoidFunction) |
| .Default(nullptr); |
| |
| Error error = |
| compileAndExecuteFn |
| ? compileAndExecuteFn(m.get(), mainFuncName.getValue(), transformer) |
| : make_string_error("unsupported function type"); |
| |
| int exitCode = EXIT_SUCCESS; |
| llvm::handleAllErrors(std::move(error), |
| [&exitCode](const llvm::ErrorInfoBase &info) { |
| llvm::errs() << "Error: "; |
| info.log(llvm::errs()); |
| llvm::errs() << '\n'; |
| exitCode = EXIT_FAILURE; |
| }); |
| |
| return exitCode; |
| } |