| //===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===// |
| // |
| // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| #include "mlir/Dialect/Transform/IR/Utils.h" |
| #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" |
| #include "mlir/IR/AsmState.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/DialectRegistry.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/InitAllDialects.h" |
| #include "mlir/InitAllExtensions.h" |
| #include "mlir/InitAllPasses.h" |
| #include "mlir/Parser/Parser.h" |
| #include "mlir/Support/FileUtilities.h" |
| #include "mlir/Tools/mlir-opt/MlirOptMain.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/InitLLVM.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include "llvm/Support/ToolOutputFile.h" |
| #include <cstdlib> |
| |
| namespace { |
| |
| using namespace llvm; |
| |
| /// Structure containing command line options for the tool, these will get |
| /// initialized when an instance is created. |
| struct MlirTransformOptCLOptions { |
| cl::opt<bool> allowUnregisteredDialects{ |
| "allow-unregistered-dialect", |
| cl::desc("Allow operations coming from an unregistered dialect"), |
| cl::init(false)}; |
| |
| cl::opt<bool> verifyDiagnostics{ |
| "verify-diagnostics", |
| cl::desc("Check that emitted diagnostics match expected-* lines " |
| "on the corresponding line"), |
| cl::init(false)}; |
| |
| cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"), |
| cl::init("-")}; |
| |
| cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"), |
| cl::value_desc("filename"), |
| cl::init("-")}; |
| |
| cl::opt<std::string> transformMainFilename{ |
| "transform", |
| cl::desc("File containing entry point of the transform script, if " |
| "different from the input file"), |
| cl::value_desc("filename"), cl::init("")}; |
| |
| cl::list<std::string> transformLibraryFilenames{ |
| "transform-library", cl::desc("File(s) containing definitions of " |
| "additional transform script symbols")}; |
| |
| cl::opt<std::string> transformEntryPoint{ |
| "transform-entry-point", |
| cl::desc("Name of the entry point transform symbol"), |
| cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName |
| .str())}; |
| |
| cl::opt<bool> disableExpensiveChecks{ |
| "disable-expensive-checks", |
| cl::desc("Disables potentially expensive checks in the transform " |
| "interpreter, providing more speed at the expense of " |
| "potential memory problems and silent corruptions"), |
| cl::init(false)}; |
| |
| cl::opt<bool> dumpLibraryModule{ |
| "dump-library-module", |
| cl::desc("Prints the combined library module before the output"), |
| cl::init(false)}; |
| }; |
| } // namespace |
| |
| /// "Managed" static instance of the command-line options structure. This makes |
| /// them locally-scoped and explicitly initialized/deinitialized. While this is |
| /// not strictly necessary in the tool source file that is not being used as a |
| /// library (where the options would pollute the global list of options), it is |
| /// good practice to follow this. |
| static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions; |
| |
| /// Explicitly registers command-line options. |
| static void registerCLOptions() { *clOptions; } |
| |
| namespace { |
| /// A wrapper class for source managers diagnostic. This provides both unique |
| /// ownership and virtual function-like overload for a pair of |
| /// inheritance-related classes that do not use virtual functions. |
| class DiagnosticHandlerWrapper { |
| public: |
| /// Kind of the diagnostic handler to use. |
| enum class Kind { EmitDiagnostics, VerifyDiagnostics }; |
| |
| /// Constructs the diagnostic handler of the specified kind of the given |
| /// source manager and context. |
| DiagnosticHandlerWrapper(Kind kind, llvm::SourceMgr &mgr, |
| mlir::MLIRContext *context) { |
| if (kind == Kind::EmitDiagnostics) |
| handler = new mlir::SourceMgrDiagnosticHandler(mgr, context); |
| else |
| handler = new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context); |
| } |
| |
| /// This object is non-copyable but movable. |
| DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete; |
| DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default; |
| DiagnosticHandlerWrapper & |
| operator=(const DiagnosticHandlerWrapper &) = delete; |
| DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default; |
| |
| /// Verifies the captured "expected-*" diagnostics if required. |
| mlir::LogicalResult verify() const { |
| if (auto *ptr = |
| handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) { |
| return ptr->verify(); |
| } |
| return mlir::success(); |
| } |
| |
| /// Destructs the object of the same type as allocated. |
| ~DiagnosticHandlerWrapper() { |
| if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) { |
| delete ptr; |
| } else { |
| delete handler.get<mlir::SourceMgrDiagnosticVerifierHandler *>(); |
| } |
| } |
| |
| private: |
| /// Internal storage is a type-safe union. |
| llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *, |
| mlir::SourceMgrDiagnosticVerifierHandler *> |
| handler; |
| }; |
| |
| /// MLIR has deeply rooted expectations that the LLVM source manager contains |
| /// exactly one buffer, until at least the lexer level. This class wraps |
| /// multiple LLVM source managers each managing a buffer to match MLIR's |
| /// expectations while still providing a centralized handling mechanism. |
| class TransformSourceMgr { |
| public: |
| /// Constructs the source manager indicating whether diagnostic messages will |
| /// be verified later on. |
| explicit TransformSourceMgr(bool verifyDiagnostics) |
| : verifyDiagnostics(verifyDiagnostics) {} |
| |
| /// Deconstructs the source manager. Note that `checkResults` must have been |
| /// called on this instance before deconstructing it. |
| ~TransformSourceMgr() { |
| assert(resultChecked && "must check the result of diagnostic handlers by " |
| "running TransformSourceMgr::checkResult"); |
| } |
| |
| /// Parses the given buffer and creates the top-level operation of the kind |
| /// specified as template argument in the given context. Additional parsing |
| /// options may be provided. |
| template <typename OpTy = mlir::Operation *> |
| mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer, |
| mlir::MLIRContext &context, |
| const mlir::ParserConfig &config) { |
| // Create a single-buffer LLVM source manager. Note that `unique_ptr` allows |
| // the code below to capture a reference to the source manager in such a way |
| // that it is not invalidated when the vector contents is eventually |
| // reallocated. |
| llvm::SourceMgr &mgr = |
| *sourceMgrs.emplace_back(std::make_unique<llvm::SourceMgr>()); |
| mgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); |
| |
| // Choose the type of diagnostic handler depending on whether diagnostic |
| // verification needs to happen and store it. |
| if (verifyDiagnostics) { |
| diagHandlers.emplace_back( |
| DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context); |
| } else { |
| diagHandlers.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics, |
| mgr, &context); |
| } |
| |
| // Defer to MLIR's parser. |
| return mlir::parseSourceFile<OpTy>(mgr, config); |
| } |
| |
| /// If diagnostic message verification has been requested upon construction of |
| /// this source manager, performs the verification, reports errors and returns |
| /// the result of the verification. Otherwise passes through the given value. |
| mlir::LogicalResult checkResult(mlir::LogicalResult result) { |
| resultChecked = true; |
| if (!verifyDiagnostics) |
| return result; |
| |
| return mlir::failure(llvm::any_of(diagHandlers, [](const auto &handler) { |
| return mlir::failed(handler.verify()); |
| })); |
| } |
| |
| private: |
| /// Indicates whether diagnostic message verification is requested. |
| const bool verifyDiagnostics; |
| |
| /// Indicates that diagnostic message verification has taken place, and the |
| /// deconstruction is therefore safe. |
| bool resultChecked = false; |
| |
| /// Storage for per-buffer source managers and diagnostic handlers. These are |
| /// wrapped into unique pointers in order to make it safe to capture |
| /// references to these objects: if the vector is reallocated, the unique |
| /// pointer objects are moved by the pointer addresses won't change. Also, for |
| /// handlers, this allows to store the pointer to the base class. |
| SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs; |
| SmallVector<DiagnosticHandlerWrapper> diagHandlers; |
| }; |
| } // namespace |
| |
| /// Trivial wrapper around `applyTransforms` that doesn't support extra mapping |
| /// and doesn't enforce the entry point transform ops being top-level. |
| static mlir::LogicalResult |
| applyTransforms(mlir::Operation *payloadRoot, |
| mlir::transform::TransformOpInterface transformRoot, |
| const mlir::transform::TransformOptions &options) { |
| return applyTransforms(payloadRoot, transformRoot, {}, options, |
| /*enforceToplevelTransformOp=*/false); |
| } |
| |
| /// Applies transforms indicated in the transform dialect script to the input |
| /// buffer. The transform script may be embedded in the input buffer or as a |
| /// separate buffer. The transform script may have external symbols, the |
| /// definitions of which must be provided in transform library buffers. If the |
| /// application is successful, prints the transformed input buffer into the |
| /// given output stream. Additional configuration options are derived from |
| /// command-line options. |
| static mlir::LogicalResult processPayloadBuffer( |
| raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer, |
| std::unique_ptr<llvm::MemoryBuffer> transformBuffer, |
| MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries, |
| mlir::DialectRegistry ®istry) { |
| |
| // Initialize the MLIR context, and various configurations. |
| mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED); |
| context.allowUnregisteredDialects(clOptions->allowUnregisteredDialects); |
| mlir::ParserConfig config(&context); |
| TransformSourceMgr sourceMgr( |
| /*verifyDiagnostics=*/clOptions->verifyDiagnostics); |
| |
| // Parse the input buffer that will be used as transform payload. |
| mlir::OwningOpRef<mlir::Operation *> payloadRoot = |
| sourceMgr.parseBuffer(std::move(inputBuffer), context, config); |
| if (!payloadRoot) |
| return sourceMgr.checkResult(mlir::failure()); |
| |
| // Identify the module containing the transform script entry point. This may |
| // be the same module as the input or a separate module. In the former case, |
| // make a copy of the module so it can be modified freely. Modification may |
| // happen in the script itself (at which point it could be rewriting itself |
| // during interpretation, leading to tricky memory errors) or by embedding |
| // library modules in the script. |
| mlir::OwningOpRef<mlir::ModuleOp> transformRoot; |
| if (transformBuffer) { |
| transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>( |
| std::move(transformBuffer), context, config); |
| if (!transformRoot) |
| return sourceMgr.checkResult(mlir::failure()); |
| } else { |
| transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone()); |
| } |
| |
| // Parse and merge the libraries into the main transform module. |
| for (auto &&transformLibrary : transformLibraries) { |
| mlir::OwningOpRef<mlir::ModuleOp> libraryModule = |
| sourceMgr.parseBuffer<mlir::ModuleOp>(std::move(transformLibrary), |
| context, config); |
| |
| if (!libraryModule || |
| mlir::failed(mlir::transform::detail::mergeSymbolsInto( |
| *transformRoot, std::move(libraryModule)))) |
| return sourceMgr.checkResult(mlir::failure()); |
| } |
| |
| // If requested, dump the combined transform module. |
| if (clOptions->dumpLibraryModule) |
| transformRoot->dump(); |
| |
| // Find the entry point symbol. Even if it had originally been in the payload |
| // module, it was cloned into the transform module so only look there. |
| mlir::transform::TransformOpInterface entryPoint = |
| mlir::transform::detail::findTransformEntryPoint( |
| *transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint); |
| if (!entryPoint) |
| return sourceMgr.checkResult(mlir::failure()); |
| |
| // Apply the requested transformations. |
| mlir::transform::TransformOptions transformOptions; |
| transformOptions.enableExpensiveChecks(!clOptions->disableExpensiveChecks); |
| if (mlir::failed(applyTransforms(*payloadRoot, entryPoint, transformOptions))) |
| return sourceMgr.checkResult(mlir::failure()); |
| |
| // Print the transformed result and check the captured diagnostics if |
| // requested. |
| payloadRoot->print(os); |
| return sourceMgr.checkResult(mlir::success()); |
| } |
| |
| /// Tool entry point. |
| static mlir::LogicalResult runMain(int argc, char **argv) { |
| // Register all upstream dialects and extensions. Specific uses are advised |
| // not to register all dialects indiscriminately but rather hand-pick what is |
| // necessary for their use case. |
| mlir::DialectRegistry registry; |
| mlir::registerAllDialects(registry); |
| mlir::registerAllExtensions(registry); |
| mlir::registerAllPasses(); |
| |
| // Explicitly register the transform dialect. This is not strictly necessary |
| // since it has been already registered as part of the upstream dialect list, |
| // but useful for example purposes for cases when dialects to register are |
| // hand-picked. The transform dialect must be registered. |
| registry.insert<mlir::transform::TransformDialect>(); |
| |
| // Register various command-line options. Note that the LLVM initializer |
| // object is a RAII that ensures correct deconstruction of command-line option |
| // objects inside ManagedStatic. |
| llvm::InitLLVM y(argc, argv); |
| mlir::registerAsmPrinterCLOptions(); |
| mlir::registerMLIRContextCLOptions(); |
| registerCLOptions(); |
| llvm::cl::ParseCommandLineOptions(argc, argv, |
| "Minimal Transform dialect driver\n"); |
| |
| // Try opening the main input file. |
| std::string errorMessage; |
| std::unique_ptr<llvm::MemoryBuffer> payloadFile = |
| mlir::openInputFile(clOptions->payloadFilename, &errorMessage); |
| if (!payloadFile) { |
| llvm::errs() << errorMessage << "\n"; |
| return mlir::failure(); |
| } |
| |
| // Try opening the output file. |
| std::unique_ptr<llvm::ToolOutputFile> outputFile = |
| mlir::openOutputFile(clOptions->outputFilename, &errorMessage); |
| if (!outputFile) { |
| llvm::errs() << errorMessage << "\n"; |
| return mlir::failure(); |
| } |
| |
| // Try opening the main transform file if provided. |
| std::unique_ptr<llvm::MemoryBuffer> transformRootFile; |
| if (!clOptions->transformMainFilename.empty()) { |
| if (clOptions->transformMainFilename == clOptions->payloadFilename) { |
| llvm::errs() << "warning: " << clOptions->payloadFilename |
| << " is provided as both payload and transform file\n"; |
| } else { |
| transformRootFile = |
| mlir::openInputFile(clOptions->transformMainFilename, &errorMessage); |
| if (!transformRootFile) { |
| llvm::errs() << errorMessage << "\n"; |
| return mlir::failure(); |
| } |
| } |
| } |
| |
| // Try opening transform library files if provided. |
| SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries; |
| transformLibraries.reserve(clOptions->transformLibraryFilenames.size()); |
| for (llvm::StringRef filename : clOptions->transformLibraryFilenames) { |
| transformLibraries.emplace_back( |
| mlir::openInputFile(filename, &errorMessage)); |
| if (!transformLibraries.back()) { |
| llvm::errs() << errorMessage << "\n"; |
| return mlir::failure(); |
| } |
| } |
| |
| return processPayloadBuffer(outputFile->os(), std::move(payloadFile), |
| std::move(transformRootFile), transformLibraries, |
| registry); |
| } |
| |
| int main(int argc, char **argv) { |
| return mlir::asMainReturnCode(runMain(argc, argv)); |
| } |