| //===- TransformInterpreterPassBase.cpp -----------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Base class with shared implementation for transform dialect interpreter |
| // passes. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" |
| #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| #include "mlir/Dialect/Transform/IR/TransformOps.h" |
| #include "mlir/Dialect/Transform/IR/Utils.h" |
| #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Verifier.h" |
| #include "mlir/IR/Visitors.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "mlir/Parser/Parser.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/FileUtilities.h" |
| #include "llvm/ADT/ScopeExit.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/FileSystem.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/Mutex.h" |
| #include "llvm/Support/Path.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| using namespace mlir; |
| |
| #define DEBUG_TYPE "transform-dialect-interpreter" |
| #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
| #define DEBUG_TYPE_DUMP_STDERR "transform-dialect-dump-repro" |
| #define DEBUG_TYPE_DUMP_FILE "transform-dialect-save-repro" |
| |
| /// Name of the attribute used for targeting the transform dialect interpreter |
| /// at specific operations. |
| constexpr static llvm::StringLiteral kTransformDialectTagAttrName = |
| "transform.target_tag"; |
| /// Value of the attribute indicating the root payload operation. |
| constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue = |
| "payload_root"; |
| /// Value of the attribute indicating the container of transform operations |
| /// (containing the top-level transform operation). |
| constexpr static llvm::StringLiteral |
| kTransformDialectTagTransformContainerValue = "transform_container"; |
| |
| /// Finds the single top-level transform operation with `root` as ancestor. |
| /// Reports an error if there is more than one such operation and returns the |
| /// first one found. Reports an error returns nullptr if no such operation |
| /// found. |
| static Operation * |
| findTopLevelTransform(Operation *root, StringRef filenameOption, |
| mlir::transform::TransformOptions options) { |
| ::mlir::transform::TransformOpInterface topLevelTransform = nullptr; |
| root->walk<WalkOrder::PreOrder>( |
| [&](::mlir::transform::TransformOpInterface transformOp) { |
| if (!transformOp |
| ->hasTrait<transform::PossibleTopLevelTransformOpTrait>()) |
| return WalkResult::skip(); |
| if (!topLevelTransform) { |
| topLevelTransform = transformOp; |
| return WalkResult::skip(); |
| } |
| if (options.getEnforceSingleToplevelTransformOp()) { |
| auto diag = transformOp.emitError() |
| << "more than one top-level transform op"; |
| diag.attachNote(topLevelTransform.getLoc()) |
| << "previous top-level transform op"; |
| return WalkResult::interrupt(); |
| } |
| return WalkResult::skip(); |
| }); |
| if (!topLevelTransform) { |
| auto diag = root->emitError() |
| << "could not find a nested top-level transform op"; |
| diag.attachNote() << "use the '" << filenameOption |
| << "' option to provide transform as external file"; |
| return nullptr; |
| } |
| return topLevelTransform; |
| } |
| |
| /// Finds an operation nested in `root` that has the transform dialect tag |
| /// attribute with the value specified as `tag`. Assumes only one operation |
| /// may have the tag. Returns nullptr if there is no such operation. |
| static Operation *findOpWithTag(Operation *root, StringRef tagKey, |
| StringRef tagValue) { |
| Operation *found = nullptr; |
| WalkResult walkResult = root->walk<WalkOrder::PreOrder>( |
| [tagKey, tagValue, &found, root](Operation *op) { |
| auto attr = op->getAttrOfType<StringAttr>(tagKey); |
| if (!attr || attr.getValue() != tagValue) |
| return WalkResult::advance(); |
| |
| if (found) { |
| InFlightDiagnostic diag = root->emitError() |
| << "more than one operation with " << tagKey |
| << "=\"" << tagValue << "\" attribute"; |
| diag.attachNote(found->getLoc()) << "first operation"; |
| diag.attachNote(op->getLoc()) << "other operation"; |
| return WalkResult::interrupt(); |
| } |
| |
| found = op; |
| return WalkResult::advance(); |
| }); |
| if (walkResult.wasInterrupted()) |
| return nullptr; |
| |
| if (!found) { |
| root->emitError() << "could not find the operation with " << tagKey << "=\"" |
| << tagValue << "\" attribute"; |
| } |
| return found; |
| } |
| |
| /// Returns the ancestor of `target` that doesn't have a parent. |
| static Operation *getRootOperation(Operation *target) { |
| Operation *root = target; |
| while (root->getParentOp()) |
| root = root->getParentOp(); |
| return root; |
| } |
| |
| /// Prints the CLI command running the repro with the current path. |
| // TODO: make binary name optional by querying LLVM command line API for the |
| // name of the current binary. |
| static llvm::raw_ostream & |
| printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, |
| const Pass::Option<std::string> &debugPayloadRootTag, |
| const Pass::Option<std::string> &debugTransformRootTag, |
| StringRef binaryName) { |
| os << llvm::formatv( |
| "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName, |
| passName, debugPayloadRootTag.getArgStr(), |
| debugPayloadRootTag.empty() |
| ? StringRef(kTransformDialectTagPayloadRootValue) |
| : debugPayloadRootTag, |
| debugTransformRootTag.getArgStr(), |
| debugTransformRootTag.empty() |
| ? StringRef(kTransformDialectTagTransformContainerValue) |
| : debugTransformRootTag, |
| binaryName); |
| return os; |
| } |
| |
| /// Prints the module rooted at `root` to `os` and appends |
| /// `transformContainer` if it is not nested in `root`. |
| static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, |
| Operation *root, |
| Operation *transform) { |
| root->print(os); |
| if (!root->isAncestor(transform)) |
| transform->print(os); |
| return os; |
| } |
| |
| /// Saves the payload and the transform IR into a temporary file and reports |
| /// the file name to `os`. |
| [[maybe_unused]] static void |
| saveReproToTempFile(llvm::raw_ostream &os, Operation *target, |
| Operation *transform, StringRef passName, |
| const Pass::Option<std::string> &debugPayloadRootTag, |
| const Pass::Option<std::string> &debugTransformRootTag, |
| const Pass::ListOption<std::string> &transformLibraryPaths, |
| StringRef binaryName) { |
| using llvm::sys::fs::TempFile; |
| Operation *root = getRootOperation(target); |
| |
| SmallVector<char, 128> tmpPath; |
| llvm::sys::path::system_temp_directory(/*erasedOnReboot=*/true, tmpPath); |
| llvm::sys::path::append(tmpPath, "transform_dialect_%%%%%%.mlir"); |
| llvm::Expected<TempFile> tempFile = TempFile::create(tmpPath); |
| if (!tempFile) { |
| os << "could not open temporary file to save the repro\n"; |
| return; |
| } |
| |
| llvm::raw_fd_ostream fout(tempFile->FD, /*shouldClose=*/false); |
| printModuleForRepro(fout, root, transform); |
| fout.flush(); |
| std::string filename = tempFile->TmpName; |
| |
| if (tempFile->keep()) { |
| os << "could not preserve the temporary file with the repro\n"; |
| return; |
| } |
| |
| os << "=== Transform Interpreter Repro ===\n"; |
| printReproCall(os, root->getName().getStringRef(), passName, |
| debugPayloadRootTag, debugTransformRootTag, binaryName) |
| << " " << filename << "\n"; |
| os << "===================================\n"; |
| } |
| |
| // Optionally perform debug actions requested by the user to dump IR and a |
| // repro to stderr and/or a file. |
| static void performOptionalDebugActions( |
| Operation *target, Operation *transform, StringRef passName, |
| const Pass::Option<std::string> &debugPayloadRootTag, |
| const Pass::Option<std::string> &debugTransformRootTag, |
| const Pass::ListOption<std::string> &transformLibraryPaths, |
| StringRef binaryName) { |
| MLIRContext *context = target->getContext(); |
| |
| // If we are not planning to print, bail early. |
| bool hasDebugFlags = false; |
| DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { hasDebugFlags = true; }); |
| DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { hasDebugFlags = true; }); |
| if (!hasDebugFlags) |
| return; |
| |
| // We will be mutating the IR to set attributes. If this is running |
| // concurrently on several parts of a container or using a shared transform |
| // script, this would create a race. Bail in multithreaded mode and require |
| // the user to disable threading to dump repros. |
| static llvm::sys::SmartMutex<true> dbgStreamMutex; |
| if (target->getContext()->isMultithreadingEnabled()) { |
| llvm::sys::SmartScopedLock<true> lock(dbgStreamMutex); |
| llvm::dbgs() << "=======================================================\n"; |
| llvm::dbgs() << "| Transform reproducers cannot be produced |\n"; |
| llvm::dbgs() << "| in multi-threaded mode! |\n"; |
| llvm::dbgs() << "=======================================================\n"; |
| return; |
| } |
| |
| Operation *root = getRootOperation(target); |
| |
| // Add temporary debug / repro attributes, these must never leak out. |
| if (debugPayloadRootTag.empty()) { |
| target->setAttr( |
| kTransformDialectTagAttrName, |
| StringAttr::get(context, kTransformDialectTagPayloadRootValue)); |
| } |
| if (debugTransformRootTag.empty()) { |
| transform->setAttr( |
| kTransformDialectTagAttrName, |
| StringAttr::get(context, kTransformDialectTagTransformContainerValue)); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_STDERR, { |
| llvm::dbgs() << "=== Transform Interpreter Repro ===\n"; |
| printReproCall(llvm::dbgs() << "cat <<EOF | ", |
| root->getName().getStringRef(), passName, |
| debugPayloadRootTag, debugTransformRootTag, binaryName) |
| << "\n"; |
| printModuleForRepro(llvm::dbgs(), root, transform); |
| llvm::dbgs() << "\nEOF\n"; |
| llvm::dbgs() << "===================================\n"; |
| }); |
| (void)root; |
| DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, { |
| saveReproToTempFile(llvm::dbgs(), target, transform, passName, |
| debugPayloadRootTag, debugTransformRootTag, |
| transformLibraryPaths, binaryName); |
| }); |
| |
| // Remove temporary attributes if they were set. |
| if (debugPayloadRootTag.empty()) |
| target->removeAttr(kTransformDialectTagAttrName); |
| if (debugTransformRootTag.empty()) |
| transform->removeAttr(kTransformDialectTagAttrName); |
| } |
| |
| LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( |
| Operation *target, StringRef passName, |
| const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule, |
| const std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule, |
| const RaggedArray<MappedValue> &extraMappings, |
| const TransformOptions &options, |
| const Pass::Option<std::string> &transformFileName, |
| const Pass::ListOption<std::string> &transformLibraryPaths, |
| const Pass::Option<std::string> &debugPayloadRootTag, |
| const Pass::Option<std::string> &debugTransformRootTag, |
| StringRef binaryName) { |
| bool hasSharedTransformModule = |
| sharedTransformModule && *sharedTransformModule; |
| bool hasTransformLibraryModule = |
| transformLibraryModule && *transformLibraryModule; |
| assert((!hasSharedTransformModule || !hasTransformLibraryModule) && |
| "at most one of shared or library transform module can be set"); |
| |
| // Step 1 |
| // ------ |
| // If debugPayloadRootTag was passed, then we are in user-specified selection |
| // of the transformed IR. This corresponds to REPL debug mode. Otherwise, just |
| // apply to `target`. |
| Operation *payloadRoot = target; |
| if (!debugPayloadRootTag.empty()) { |
| payloadRoot = findOpWithTag(target, kTransformDialectTagAttrName, |
| debugPayloadRootTag); |
| if (!payloadRoot) |
| return failure(); |
| } |
| |
| // Step 2 |
| // ------ |
| // If a shared transform was specified separately, use it. Otherwise, the |
| // transform is embedded in the payload IR. If debugTransformRootTag was |
| // passed, then we are in user-specified selection of the transforming IR. |
| // This corresponds to REPL debug mode. |
| Operation *transformContainer = |
| hasSharedTransformModule ? sharedTransformModule->get() : target; |
| Operation *transformRoot = |
| debugTransformRootTag.empty() |
| ? findTopLevelTransform(transformContainer, |
| transformFileName.getArgStr(), options) |
| : findOpWithTag(transformContainer, kTransformDialectTagAttrName, |
| debugTransformRootTag); |
| if (!transformRoot) |
| return failure(); |
| |
| if (!transformRoot->hasTrait<PossibleTopLevelTransformOpTrait>()) { |
| return emitError(transformRoot->getLoc()) |
| << "expected the transform entry point to be a top-level transform " |
| "op"; |
| } |
| |
| // Step 3 |
| // ------ |
| // Copy external defintions for symbols if provided. Be aware of potential |
| // concurrent execution (normally, the error shouldn't be triggered unless the |
| // transform IR modifies itself in a pass, which is also forbidden elsewhere). |
| if (hasTransformLibraryModule) { |
| if (!target->isProperAncestor(transformRoot)) { |
| InFlightDiagnostic diag = |
| transformRoot->emitError() |
| << "cannot inject transform definitions next to pass anchor op"; |
| diag.attachNote(target->getLoc()) << "pass anchor op"; |
| return diag; |
| } |
| InFlightDiagnostic diag = detail::mergeSymbolsInto( |
| SymbolTable::getNearestSymbolTable(transformRoot), |
| transformLibraryModule->get()->clone()); |
| if (failed(diag)) { |
| diag.attachNote(transformRoot->getLoc()) |
| << "failed to merge library symbols into transform root"; |
| return diag; |
| } |
| } |
| |
| // Step 4 |
| // ------ |
| // Optionally perform debug actions requested by the user to dump IR and a |
| // repro to stderr and/or a file. |
| performOptionalDebugActions(target, transformRoot, passName, |
| debugPayloadRootTag, debugTransformRootTag, |
| transformLibraryPaths, binaryName); |
| |
| // Step 5 |
| // ------ |
| // Apply the transform to the IR |
| return applyTransforms(payloadRoot, cast<TransformOpInterface>(transformRoot), |
| extraMappings, options); |
| } |
| |
| LogicalResult transform::detail::interpreterBaseInitializeImpl( |
| MLIRContext *context, StringRef transformFileName, |
| ArrayRef<std::string> transformLibraryPaths, |
| std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule, |
| std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule, |
| function_ref<std::optional<LogicalResult>(OpBuilder &, Location)> |
| moduleBuilder) { |
| auto unknownLoc = UnknownLoc::get(context); |
| |
| // Parse module from file. |
| OwningOpRef<ModuleOp> moduleFromFile; |
| { |
| auto loc = FileLineColLoc::get(context, transformFileName, 0, 0); |
| if (failed(detail::parseTransformModuleFromFile(context, transformFileName, |
| moduleFromFile))) |
| return emitError(loc) << "failed to parse transform module"; |
| if (moduleFromFile && failed(mlir::verify(*moduleFromFile))) |
| return emitError(loc) << "failed to verify transform module"; |
| } |
| |
| // Assemble list of library files. |
| SmallVector<std::string> libraryFileNames; |
| if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context, |
| libraryFileNames))) |
| return failure(); |
| |
| // Parse modules from library files. |
| SmallVector<OwningOpRef<ModuleOp>> parsedLibraries; |
| for (const std::string &libraryFileName : libraryFileNames) { |
| OwningOpRef<ModuleOp> parsedLibrary; |
| auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0); |
| if (failed(detail::parseTransformModuleFromFile(context, libraryFileName, |
| parsedLibrary))) |
| return emitError(loc) << "failed to parse transform library module"; |
| if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) |
| return emitError(loc) << "failed to verify transform library module"; |
| parsedLibraries.push_back(std::move(parsedLibrary)); |
| } |
| |
| // Build shared transform module. |
| if (moduleFromFile) { |
| sharedTransformModule = |
| std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile)); |
| } else if (moduleBuilder) { |
| auto loc = FileLineColLoc::get(context, "<shared-transform-module>", 0, 0); |
| auto localModule = std::make_shared<OwningOpRef<ModuleOp>>( |
| ModuleOp::create(unknownLoc, "__transform")); |
| |
| OpBuilder b(context); |
| b.setInsertionPointToEnd(localModule->get().getBody()); |
| if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) { |
| if (failed(*result)) |
| return (*localModule)->emitError() |
| << "failed to create shared transform module"; |
| sharedTransformModule = std::move(localModule); |
| } |
| } |
| |
| if (parsedLibraries.empty()) |
| return success(); |
| |
| // Merge parsed libraries into one module. |
| auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0); |
| OwningOpRef<ModuleOp> mergedParsedLibraries = |
| ModuleOp::create(loc, "__transform"); |
| { |
| mergedParsedLibraries.get()->setAttr("transform.with_named_sequence", |
| UnitAttr::get(context)); |
| IRRewriter rewriter(context); |
| // TODO: extend `mergeSymbolsInto` to support multiple `other` modules. |
| for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) { |
| if (failed(detail::mergeSymbolsInto(mergedParsedLibraries.get(), |
| std::move(parsedLibrary)))) |
| return mergedParsedLibraries->emitError() |
| << "failed to verify merged transform module"; |
| } |
| } |
| |
| // Use parsed libaries to resolve symbols in shared transform module or return |
| // as separate library module. |
| if (sharedTransformModule && *sharedTransformModule) { |
| if (failed(detail::mergeSymbolsInto(sharedTransformModule->get(), |
| std::move(mergedParsedLibraries)))) |
| return (*sharedTransformModule)->emitError() |
| << "failed to merge symbols from library files " |
| "into shared transform module"; |
| } else { |
| transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>( |
| std::move(mergedParsedLibraries)); |
| } |
| return success(); |
| } |