blob: cf462ddf6f17cae5914b7febbfe8eb23a7b7abbe [file] [log] [blame]
//===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Tools/ParseUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include <cstdint>
#include <numeric>
#include <optional>
#include <utility>
#define DEBUG_TYPE "jit-runner"
using namespace mlir;
using llvm::Error;
namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
struct Options {
// Positional input file argument; defaults to "-".
llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-")};
// Name of the entry-point function to look up and call ("-e", default "main").
llvm::cl::opt<std::string> mainFuncName{
"e", llvm::cl::desc("The function to be called"),
llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
// Textual result type of the entry point; selects which execution helper is
// used by JitRunnerMain (default "f32").
llvm::cl::opt<std::string> mainFuncType{
"entry-point-result",
llvm::cl::desc("Textual description of the function type to be called"),
llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
// Category for the opt-like flags. It is declared before the options that
// reference it because members are initialized in declaration order.
llvm::cl::OptionCategory optFlags{"opt-like flags"};
// CLI variables for -On options.
llvm::cl::opt<bool> optO0{"O0",
llvm::cl::desc("Run opt passes and codegen at O0"),
llvm::cl::cat(optFlags)};
llvm::cl::opt<bool> optO1{"O1",
llvm::cl::desc("Run opt passes and codegen at O1"),
llvm::cl::cat(optFlags)};
llvm::cl::opt<bool> optO2{"O2",
llvm::cl::desc("Run opt passes and codegen at O2"),
llvm::cl::cat(optFlags)};
llvm::cl::opt<bool> optO3{"O3",
llvm::cl::desc("Run opt passes and codegen at O3"),
llvm::cl::cat(optFlags)};
// Comma-separated list of target features ("-mattr").
llvm::cl::list<std::string> mAttrs{
"mattr", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags)};
// Architecture name override ("-march"); empty when not given.
llvm::cl::opt<std::string> mArch{
"march",
llvm::cl::desc("Architecture to generate code for (see --version)")};
// Category for the linking options; declared before `clSharedLibs`, which
// references it.
llvm::cl::OptionCategory clOptionsCategory{"linking options"};
// Comma-separated list of shared libraries handed to the execution engine.
llvm::cl::list<std::string> clSharedLibs{
"shared-libs", llvm::cl::desc("Libraries to link dynamically"),
llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};
/// CLI variables for debugging.
// When set, the JIT-compiled object is written to a file after compilation.
llvm::cl::opt<bool> dumpObjectFile{
"dump-object-file",
llvm::cl::desc("Dump JITted-compiled object to file specified with "
"-object-filename (<input file>.o by default).")};
// Destination of the object dump; when empty, "<input file>.o" is used.
llvm::cl::opt<std::string> objectFilename{
"object-filename",
llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
// Hidden flag: JitRunnerMain only reports whether an LLJIT instance can be
// created on this host and then returns.
llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
llvm::cl::desc("Report host JIT support"),
llvm::cl::Hidden};
// When set, parsing does not wrap the input in an implicit top-level module.
llvm::cl::opt<bool> noImplicitModule{
"no-implicit-module",
llvm::cl::desc(
"Disable implicit addition of a top-level module op during parsing"),
llvm::cl::init(false)};
};
/// Bundle of customization hooks that JitRunnerMain hands to the
/// compile-and-execute helpers.
struct CompileAndExecuteConfig {
/// LLVM module transformer that is passed to ExecutionEngine.
std::function<llvm::Error(llvm::Module *)> transformer;
/// A custom function that is passed to ExecutionEngine. It processes MLIR
/// module and creates LLVM IR module.
/// NOTE(review): llvm::function_ref does not own its callable, so the
/// callable presumably has to outlive this config -- confirm with callers.
llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *,
llvm::LLVMContext &)>
llvmModuleBuilder;
/// A custom function that is passed to ExecutionEngine to register symbols at
/// runtime.
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
runtimeSymbolMap;
};
} // namespace
/// Opens `inputFilename`, parses it in `context`, and returns the resulting
/// top-level operation. `insertImplicitModule` is forwarded to the parser.
///
/// Returns a null reference when the file cannot be opened (the open error is
/// printed to stderr), when parsing fails, or when the parsed top-level
/// operation does not have the SymbolTable trait (a diagnostic is printed).
static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
                                               bool insertImplicitModule,
                                               MLIRContext *context) {
  // Load the input buffer, reporting any I/O problem to the user.
  std::string openError;
  auto inputBuffer = openInputFile(inputFilename, &openError);
  if (!inputBuffer) {
    llvm::errs() << openError << "\n";
    return nullptr;
  }

  // Hand the buffer over to a source manager and parse it.
  auto sources = std::make_shared<llvm::SourceMgr>();
  sources->AddNewSourceBuffer(std::move(inputBuffer), SMLoc());
  OwningOpRef<Operation *> parsed =
      parseSourceFileForTool(sources, context, insertImplicitModule);
  if (!parsed)
    return nullptr;

  // Entry points are looked up by symbol, so the root must be a symbol table.
  if (parsed.get()->hasTrait<OpTrait::SymbolTable>())
    return parsed;

  llvm::errs() << "Error: top-level op must be a symbol table.\n";
  return nullptr;
}
/// Wraps `message` in an llvm::StringError carrying the inconvertible error
/// code and returns it as an llvm::Error.
static inline Error makeStringError(const Twine &message) {
  // Materialize the twine first so the error owns its own text.
  std::string text = message.str();
  return llvm::make_error<llvm::StringError>(std::move(text),
                                             llvm::inconvertibleErrorCode());
}
/// Returns the optimization level selected through the -O0..-O3 flags, or
/// std::nullopt when none of them is set. When several flags are set, the
/// lowest level wins because the flags are scanned in ascending order.
static std::optional<unsigned> getCommandLineOptLevel(Options &options) {
  llvm::cl::opt<bool> *levelFlags[] = {&options.optO0, &options.optO1,
                                       &options.optO2, &options.optO3};
  unsigned level = 0;
  for (llvm::cl::opt<bool> *levelFlag : levelFlags) {
    if (*levelFlag)
      return level;
    ++level;
  }
  return std::nullopt;
}
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
//
// The execution engine is configured from `options` (codegen optimization
// level taken from the -On flags, shared libraries to load) and from `config`
// (LLVM module builder, optional LLVM IR transformer, optional runtime symbol
// map). `tm`, when provided, is handed over to the execution engine.
//
// `args` is forwarded unchanged to the packed entry point returned by
// `lookupPacked`.
//
// Returns the error produced while creating the engine or looking up the
// entry point, and success once the entry point has been invoked.
static Error
compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
                  CompileAndExecuteConfig config, void **args,
                  std::unique_ptr<llvm::TargetMachine> tm = nullptr) {
  std::optional<llvm::CodeGenOptLevel> jitCodeGenOptLevel;
  if (auto clOptLevel = getCommandLineOptLevel(options))
    jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(*clOptLevel);

  SmallVector<StringRef, 4> sharedLibs(options.clSharedLibs.begin(),
                                       options.clSharedLibs.end());

  mlir::ExecutionEngineOptions engineOptions;
  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
  if (config.transformer)
    engineOptions.transformer = config.transformer;
  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
  engineOptions.sharedLibPaths = sharedLibs;
  // Object dumping is always enabled on the engine; whether a file is actually
  // written is decided below from the -dump-object-file flag.
  engineOptions.enableObjectDump = true;

  auto expectedEngine =
      mlir::ExecutionEngine::create(module, engineOptions, std::move(tm));
  if (!expectedEngine)
    return expectedEngine.takeError();
  auto engine = std::move(*expectedEngine);

  // Make the caller-provided runtime symbols visible to the JIT before the
  // entry point is looked up; without this the configured symbol map would be
  // silently ignored.
  if (config.runtimeSymbolMap)
    engine->registerSymbols(config.runtimeSymbolMap);

  auto expectedFPtr = engine->lookupPacked(entryPoint);
  if (!expectedFPtr)
    return expectedFPtr.takeError();

  if (options.dumpObjectFile)
    engine->dumpToObjectFile(options.objectFilename.empty()
                                 ? options.inputFilename + ".o"
                                 : options.objectFilename);

  void (*fptr)(void **) = *expectedFPtr;
  (*fptr)(args);

  return Error::success();
}
/// Looks up `entryPoint` in `module`, verifies that it is a defined LLVM
/// function whose result type is void, then JIT-compiles and runs it.
///
/// Returns "entry point not found" when the symbol is missing, is not an
/// LLVM function, or is empty, and "expected void function" when the result
/// type is not void. Otherwise returns whatever compileAndExecute returns.
static Error compileAndExecuteVoidFunction(
    Options &options, Operation *module, StringRef entryPoint,
    CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
  Operation *symbol = SymbolTable::lookupSymbolIn(module, entryPoint);
  auto entryFunc = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbol);
  if (!entryFunc || entryFunc.empty())
    return makeStringError("entry point not found");

  if (!isa<LLVM::LLVMVoidType>(entryFunc.getFunctionType().getReturnType()))
    return makeStringError("expected void function");

  // The packed interface still expects an argument array; pass a single null
  // slot since there is nothing to receive.
  void *noArgs = nullptr;
  return compileAndExecute(options, module, entryPoint, std::move(config),
                           &noArgs, std::move(tm));
}
/// Checks that the result type of `mainFunction` matches the C++ type `Type`
/// and returns an error describing the mismatch otherwise. Only declared
/// here; this file provides explicit specializations for int32_t, int64_t and
/// float.
template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
/// int32_t specialization: succeeds only when `mainFunction` returns a 32-bit
/// integer type.
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
  auto funcType = cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType());
  auto intType = dyn_cast<IntegerType>(funcType.getReturnType());
  if (intType && intType.getWidth() == 32)
    return Error::success();
  return makeStringError("only single i32 function result supported");
}
/// int64_t specialization: succeeds only when `mainFunction` returns a 64-bit
/// integer type.
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
  auto funcType = cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType());
  auto intType = dyn_cast<IntegerType>(funcType.getReturnType());
  if (intType && intType.getWidth() == 64)
    return Error::success();
  return makeStringError("only single i64 function result supported");
}
/// float specialization: succeeds only when `mainFunction` returns the f32
/// type.
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
  auto funcType = cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType());
  if (isa<Float32Type>(funcType.getReturnType()))
    return Error::success();
  return makeStringError("only single f32 function result supported");
}
/// Looks up `entryPoint` in `module`, checks that it is a defined LLVM
/// function taking no parameters and returning a value compatible with
/// `Type`, then JIT-compiles it, runs it, and prints the result followed by a
/// newline on stdout.
///
/// Returns "entry point not found" when the symbol is missing, is not an LLVM
/// function, or is external; "function inputs not supported" when it has
/// parameters; the error from checkCompatibleReturnType<Type> on a result type
/// mismatch; or the error from compileAndExecute.
template <typename Type>
Error compileAndExecuteSingleReturnFunction(
    Options &options, Operation *module, StringRef entryPoint,
    CompileAndExecuteConfig config, std::unique_ptr<llvm::TargetMachine> tm) {
  Operation *symbol = SymbolTable::lookupSymbolIn(module, entryPoint);
  auto entryFunc = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbol);
  if (!entryFunc || entryFunc.isExternal())
    return makeStringError("entry point not found");

  auto funcType = cast<LLVM::LLVMFunctionType>(entryFunc.getFunctionType());
  if (funcType.getNumParams() != 0)
    return makeStringError("function inputs not supported");

  if (Error typeError = checkCompatibleReturnType<Type>(entryFunc))
    return typeError;

  // The packed calling convention takes an array of pointers; the only slot
  // needed here points at the storage that receives the result.
  Type res;
  void *packedArgs[1] = {&res};
  if (Error runError =
          compileAndExecute(options, module, entryPoint, std::move(config),
                            packedArgs, std::move(tm)))
    return runError;

  // Intentional printing of the output so we can test.
  llvm::outs() << res << '\n';
  return Error::success();
}
/// Entry point for all CPU runners. Expects the common argc/argv arguments for
/// standard C++ main functions.
///
/// Parses the command line, loads the input IR into a context built from
/// `registry`, applies the optional `config.mlirTransformer`, builds a host
/// TargetMachine adjusted by -mattr/-march, and finally JIT-compiles and runs
/// the entry point selected by -e / -entry-point-result.
///
/// Returns EXIT_SUCCESS when execution succeeded and a non-zero exit code when
/// parsing, transformation, target setup or execution failed. With the hidden
/// -host-supports-jit flag it only prints "true" or "false" on stdout; on
/// "false" the process exits through ExitOnError.
int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
                        JitRunnerConfig config) {
  llvm::ExitOnError exitOnErr;

  // Create the options struct containing the command line options for the
  // runner. This must come before the command line options are parsed.
  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");

  if (options.hostSupportsJit) {
    auto j = llvm::orc::LLJITBuilder().create();
    if (j)
      llvm::outs() << "true\n";
    else {
      llvm::outs() << "false\n";
      exitOnErr(j.takeError());
    }
    return 0;
  }

  std::optional<unsigned> optLevel = getCommandLineOptLevel(options);

  MLIRContext context(registry);

  auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
                          &context);
  if (!m) {
    llvm::errs() << "could not parse the input IR\n";
    return 1;
  }

  JitRunnerOptions runnerOptions{options.mainFuncName, options.mainFuncType};
  if (config.mlirTransformer)
    if (failed(config.mlirTransformer(m.get(), runnerOptions)))
      return EXIT_FAILURE;

  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!tmBuilderOrError) {
    llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
    // An llvm::Expected in the error state must have its error taken before
    // it is destroyed; otherwise builds with ABI-breaking checks abort here.
    llvm::consumeError(tmBuilderOrError.takeError());
    return EXIT_FAILURE;
  }

  // Configure TargetMachine builder based on the command line options
  llvm::SubtargetFeatures features;
  if (!options.mAttrs.empty()) {
    for (StringRef attr : options.mAttrs)
      features.AddFeature(attr);
    tmBuilderOrError->addFeatures(features.getFeatures());
  }

  if (!options.mArch.empty()) {
    tmBuilderOrError->getTargetTriple().setArchName(options.mArch);
  }

  // Build TargetMachine
  auto tmOrError = tmBuilderOrError->createTargetMachine();
  if (!tmOrError) {
    llvm::errs() << "Failed to create a TargetMachine for the host\n";
    exitOnErr(tmOrError.takeError());
  }

  LLVM_DEBUG({
    llvm::dbgs() << " JITTargetMachineBuilder is "
                 << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError,
                                                              "\n");
  });

  CompileAndExecuteConfig compileAndExecuteConfig;
  if (optLevel) {
    // The transformer only borrows the TargetMachine here; ownership is
    // transferred to the execution helper further down.
    compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
        *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
  }
  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;

  // Get the function used to compile and execute the module.
  using CompileAndExecuteFnT =
      Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig,
                std::unique_ptr<llvm::TargetMachine> tm);
  auto compileAndExecuteFn =
      StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
          .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
          .Case("void", compileAndExecuteVoidFunction)
          .Default(nullptr);

  Error error = compileAndExecuteFn
                    ? compileAndExecuteFn(
                          options, m.get(), options.mainFuncName.getValue(),
                          compileAndExecuteConfig, std::move(tmOrError.get()))
                    : makeStringError("unsupported function type");

  // Report every error on stderr and turn it into a failing exit code.
  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });

  return exitCode;
}