Refactor pass pipeline command line parsing to support explicit pipeline strings.
This allows for explicitly specifying the pipeline to add to the pass manager. This includes the nesting structure, as well as the passes/pipelines to run. A textual pipeline string is defined as a series of names, each of which may in itself recursively contain a nested pipeline description. A name is either the name of a registered pass, or pass pipeline, (e.g. "cse") or the name of an operation type (e.g. "func").
For example, the following pipeline:
$ mlir-opt foo.mlir -cse -canonicalize -lower-to-llvm
Could now be specified as:
$ mlir-opt foo.mlir -pass-pipeline='func(cse, canonicalize), lower-to-llvm'
This will allow for running pipelines on nested operations, like say spirv modules. This does not remove any of the current functionality, and in fact can be used in unison. The new option is available via 'pass-pipeline'.
PiperOrigin-RevId: 268954279
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 8f8217e..27c1050 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -109,7 +109,7 @@
};
/// The main pass manager and pipeline builder.
-class PassManager {
+class PassManager : public OpPassManager {
public:
// If verifyPasses is true, the verifier is run after each pass.
PassManager(MLIRContext *ctx, bool verifyPasses = true);
@@ -123,25 +123,6 @@
void disableMultithreading(bool disable = true);
//===--------------------------------------------------------------------===//
- // Pipeline Building
- //===--------------------------------------------------------------------===//
-
- /// Allow converting to the impl OpPassManager.
- operator OpPassManager &() { return opPassManager; }
-
- /// Add an opaque pass pointer to the current manager. This takes ownership
- /// over the provided pass pointer.
- void addPass(std::unique_ptr<Pass> pass);
-
- /// Allow nesting other operation pass managers.
- OpPassManager &nest(const OperationName &nestedName) {
- return opPassManager.nest(nestedName);
- }
- template <typename OpT> OpPassManager &nest() {
- return opPassManager.nest<OpT>();
- }
-
- //===--------------------------------------------------------------------===//
// Instrumentations
//===--------------------------------------------------------------------===//
@@ -169,9 +150,6 @@
PassTimingDisplayMode displayMode = PassTimingDisplayMode::Pipeline);
private:
- /// The top level pass manager instance.
- OpPassManager opPassManager;
-
/// Flag that specifies if pass timing is enabled.
bool passTiming : 1;
diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index eea3778..3feffa1 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -32,11 +32,11 @@
#include <memory>
namespace mlir {
+class OpPassManager;
class Pass;
-class PassManager;
/// A registry function that adds passes to the given pass manager.
-using PassRegistryFunction = std::function<void(PassManager &)>;
+using PassRegistryFunction = std::function<void(OpPassManager &)>;
using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;
@@ -44,14 +44,18 @@
/// act as a unique identifier during pass registration.
using PassID = ClassID;
+//===----------------------------------------------------------------------===//
+// PassRegistry
+//===----------------------------------------------------------------------===//
+
/// Structure to group information about a passes and pass pipelines (argument
/// to invoke via mlir-opt, description, pass pipeline builder).
class PassRegistryEntry {
public:
/// Adds this pass registry entry to the given pass manager.
- void addToPipeline(PassManager &pm) const {
+ void addToPipeline(OpPassManager &pm) const {
assert(builder &&
- "Cannot call addToPipeline on PassRegistryEntry without builder");
+ "cannot call addToPipeline on PassRegistryEntry without builder");
builder(pm);
}
@@ -95,6 +99,10 @@
PassAllocatorFunction allocator);
};
+//===----------------------------------------------------------------------===//
+// PassRegistration
+//===----------------------------------------------------------------------===//
+
/// Register a specific dialect pipeline registry function with the system,
/// typically used through the PassPipelineRegistration template.
void registerPassPipeline(StringRef arg, StringRef description,
@@ -134,7 +142,7 @@
/// Usage:
///
/// // At namespace scope.
-/// void pipelineBuilder(PassManager &pm) {
+/// void pipelineBuilder(OpPassManager &pm) {
/// pm.addPass(new MyPass());
/// pm.addPass(new MyOtherPass());
/// }
@@ -154,15 +162,40 @@
PassAllocatorFunction allocator);
};
-/// Adds command line option for each registered pass.
-struct PassNameParser : public llvm::cl::parser<const PassRegistryEntry *> {
- PassNameParser(llvm::cl::Option &opt);
+//===----------------------------------------------------------------------===//
+// PassPipelineCLParser
+//===----------------------------------------------------------------------===//
- void initialize();
+namespace detail {
+struct PassPipelineCLParserImpl;
+} // end namespace detail
- void printOptionInfo(const llvm::cl::Option &O,
- size_t GlobalWidth) const override;
+/// This class implements a command-line parser for MLIR passes. It registers a
+/// cl option with a given argument and description. This parser will register
+/// options for each of the passes and pipelines that have been registered with
+/// the pass registry; Meaning that `-cse` will refer to the CSE pass in MLIR.
+/// It also registers an argument, `pass-pipeline`, that supports parsing a
+/// textual description of a pipeline.
+class PassPipelineCLParser {
+public:
+ /// Construct a pass pipeline parser with the given command line description.
+ PassPipelineCLParser(StringRef arg, StringRef description);
+ ~PassPipelineCLParser();
+
+ /// Returns true if this parser contains any valid options to add.
+ bool hasAnyOccurrences() const;
+
+ /// Returns true if the given pass registry entry was registered at the
+ /// top-level of the parser, i.e. not within an explicit textual pipeline.
+ bool contains(const PassRegistryEntry *entry) const;
+
+ /// Adds the passes defined by this parser entry to the given pass manager.
+ void addToPipeline(OpPassManager &pm) const;
+
+private:
+ std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
};
+
} // end namespace mlir
#endif // MLIR_PASS_PASSREGISTRY_H_
diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h
index 00a1e48..66b1a87 100644
--- a/mlir/include/mlir/Support/MlirOptMain.h
+++ b/mlir/include/mlir/Support/MlirOptMain.h
@@ -28,11 +28,12 @@
} // end namespace llvm
namespace mlir {
struct LogicalResult;
-class PassRegistryEntry;
+class PassPipelineCLParser;
-LogicalResult
-MlirOptMain(llvm::raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> buffer,
- const std::vector<const PassRegistryEntry *> &passList,
- bool splitInputFile, bool verifyDiagnostics, bool verifyPasses);
+LogicalResult MlirOptMain(llvm::raw_ostream &os,
+ std::unique_ptr<llvm::MemoryBuffer> buffer,
+ const PassPipelineCLParser &passPipeline,
+ bool splitInputFile, bool verifyDiagnostics,
+ bool verifyPasses);
} // end namespace mlir
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 1de3025..ee5ad2e 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -438,7 +438,7 @@
//===----------------------------------------------------------------------===//
PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
- : opPassManager(OperationName(ModuleOp::getOperationName(), ctx),
+ : OpPassManager(OperationName(ModuleOp::getOperationName(), ctx),
/*disableThreads=*/false, verifyPasses),
passTiming(false) {}
@@ -448,22 +448,16 @@
LogicalResult PassManager::run(ModuleOp module) {
// Before running, make sure to coalesce any adjacent pass adaptors in the
// pipeline.
- opPassManager.getImpl().coalesceAdjacentAdaptorPasses();
+ getImpl().coalesceAdjacentAdaptorPasses();
// Construct an analysis manager for the pipeline and run it.
ModuleAnalysisManager am(module, instrumentor.get());
- return opPassManager.run(module, am);
+ return OpPassManager::run(module, am);
}
/// Disable support for multi-threading within the pass manager.
void PassManager::disableMultithreading(bool disable) {
- opPassManager.getImpl().disableThreads = disable;
-}
-
-/// Add an opaque pass pointer to the current manager. This takes ownership
-/// over the provided pass pointer.
-void PassManager::addPass(std::unique_ptr<Pass> pass) {
- opPassManager.addPass(std::move(pass));
+ getImpl().disableThreads = disable;
}
/// Add the provided instrumentation to the pass manager. This takes ownership
diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index 055e81cb..5701d30 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -25,9 +25,6 @@
namespace {
struct PassManagerOptions {
- typedef llvm::cl::list<const mlir::PassRegistryEntry *, bool, PassNameParser>
- PassOptionList;
-
PassManagerOptions();
//===--------------------------------------------------------------------===//
@@ -38,8 +35,8 @@
//===--------------------------------------------------------------------===//
// IR Printing
//===--------------------------------------------------------------------===//
- PassOptionList printBefore;
- PassOptionList printAfter;
+ PassPipelineCLParser printBefore;
+ PassPipelineCLParser printAfter;
llvm::cl::opt<bool> printBeforeAll;
llvm::cl::opt<bool> printAfterAll;
llvm::cl::opt<bool> printModuleScope;
@@ -72,10 +69,8 @@
//===----------------------------------------------------------------===//
// IR Printing
//===----------------------------------------------------------------===//
- printBefore("print-ir-before",
- llvm::cl::desc("Print IR before specified passes")),
- printAfter("print-ir-after",
- llvm::cl::desc("Print IR after specified passes")),
+ printBefore("print-ir-before", "Print IR before specified passes"),
+ printAfter("print-ir-after", "Print IR after specified passes"),
printBeforeAll("print-ir-before-all",
llvm::cl::desc("Print IR before each pass"),
llvm::cl::init(false)),
@@ -112,12 +107,12 @@
if (printBeforeAll) {
// If we are printing before all, then just return true for the filter.
shouldPrintBeforePass = [](Pass *) { return true; };
- } else if (printBefore.getNumOccurrences() != 0) {
+ } else if (printBefore.hasAnyOccurrences()) {
// Otherwise if there are specific passes to print before, then check to see
// if the pass info for the current pass is included in the list.
shouldPrintBeforePass = [&](Pass *pass) {
auto *passInfo = pass->lookupPassInfo();
- return passInfo && llvm::is_contained(printBefore, passInfo);
+ return passInfo && printBefore.contains(passInfo);
};
}
@@ -125,12 +120,12 @@
if (printAfterAll) {
// If we are printing after all, then just return true for the filter.
shouldPrintAfterPass = [](Pass *) { return true; };
- } else if (printAfter.getNumOccurrences() != 0) {
+ } else if (printAfter.hasAnyOccurrences()) {
// Otherwise if there are specific passes to print after, then check to see
// if the pass info for the current pass is included in the list.
shouldPrintAfterPass = [&](Pass *pass) {
auto *passInfo = pass->lookupPassInfo();
- return passInfo && llvm::is_contained(printAfter, passInfo);
+ return passInfo && printAfter.contains(passInfo);
};
}
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 0d85761..4da4515 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -22,6 +22,7 @@
#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
+using namespace detail;
/// Static mapping of all of the registered passes.
static llvm::ManagedStatic<llvm::DenseMap<const PassID *, PassInfo>>
@@ -34,7 +35,7 @@
/// Utility to create a default registry function from a pass instance.
static PassRegistryFunction
buildDefaultRegistryFn(PassAllocatorFunction allocator) {
- return [=](PassManager &pm) { pm.addPass(allocator()); };
+ return [=](OpPassManager &pm) { pm.addPass(allocator()); };
}
//===----------------------------------------------------------------------===//
@@ -83,14 +84,252 @@
}
//===----------------------------------------------------------------------===//
+// TextualPassPipeline Parser
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents a textual description of a pass pipeline.
+class TextualPipeline {
+public:
+ /// Try to initialize this pipeline with the given pipeline text. An option is
+ /// given to enable accurate error reporting.
+ LogicalResult initialize(StringRef text, llvm::cl::Option &opt);
+
+ /// Add the internal pipeline elements to the provided pass manager.
+ void addToPipeline(OpPassManager &pm) const;
+
+private:
+ /// A struct to capture parsed pass pipeline names.
+ ///
+ /// A pipeline is defined as a series of names, each of which may in itself
+ /// recursively contain a nested pipeline. A name is either the name of a pass
+ /// (e.g. "cse") or the name of an operation type (e.g. "func"). If the name
+ /// is the name of a pass, the InnerPipeline is empty, since passes cannot
+ /// contain inner pipelines.
+ struct PipelineElement {
+ StringRef name;
+ const PassRegistryEntry *registryEntry;
+ std::vector<PipelineElement> innerPipeline;
+ };
+
+ /// Parse the given pipeline text into the internal pipeline vector. This
+ /// function only parses the structure of the pipeline, and does not resolve
+ /// its elements.
+ LogicalResult parsePipelineText(StringRef text, llvm::cl::Option &opt);
+
+ /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
+ /// the corresponding registry entry.
+ LogicalResult
+ resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
+ llvm::cl::Option &opt);
+
+ /// Resolve a single element of the pipeline.
+ LogicalResult resolvePipelineElement(PipelineElement &element,
+ llvm::cl::Option &opt);
+
+ /// Add the given pipeline elements to the provided pass manager.
+ void addToPipeline(ArrayRef<PipelineElement> elements,
+ OpPassManager &pm) const;
+
+ std::vector<PipelineElement> pipeline;
+};
+
+} // end anonymous namespace
+
+/// Try to initialize this pipeline with the given pipeline text. An option is
+/// given to enable accurate error reporting.
+LogicalResult TextualPipeline::initialize(StringRef text,
+ llvm::cl::Option &opt) {
+ // Parse the provided pipeline string.
+ if (failed(parsePipelineText(text, opt)))
+ return failure(opt.error("failed to parse pass pipeline: `" + text + "'"));
+ return resolvePipelineElements(pipeline, opt);
+}
+
+/// Add the internal pipeline elements to the provided pass manager.
+void TextualPipeline::addToPipeline(OpPassManager &pm) const {
+ addToPipeline(pipeline, pm);
+}
+
+/// Parse the given pipeline text into the internal pipeline vector. This
+/// function only parses the structure of the pipeline, and does not resolve
+/// its elements.
+LogicalResult TextualPipeline::parsePipelineText(StringRef text,
+ llvm::cl::Option &opt) {
+ SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
+ for (;;) {
+ std::vector<PipelineElement> &pipeline = *pipelineStack.back();
+ size_t pos = text.find_first_of(",()");
+ pipeline.push_back({text.substr(0, pos).trim(), {}});
+
+ // If we have a single terminating name, we're done.
+ if (pos == text.npos)
+ break;
+
+ char sep = text[pos];
+ text = text.substr(pos + 1);
+
+ // Just a name ending in a comma, continue.
+ if (sep == ',')
+ continue;
+
+ if (sep == '(') {
+ // Push the inner pipeline onto the stack to continue processing.
+ pipelineStack.push_back(&pipeline.back().innerPipeline);
+ continue;
+ }
+
+ // When handling the close parenthesis, we greedily consume them to avoid
+ // empty strings in the pipeline.
+ assert(sep == ')' && "Bogus separator!");
+ do {
+ // If we try to pop the outer pipeline we have unbalanced parentheses.
+ if (pipelineStack.size() == 1)
+ return failure(
+ opt.error("encountered extra closing ')' creating unbalanced "
+ "parentheses while parsing pipeline"));
+
+ pipelineStack.pop_back();
+ } while (text.consume_front(")"));
+
+ // Check if we've finished parsing.
+ if (text.empty())
+ break;
+
+ // Otherwise, the end of an inner pipeline always has to be followed by
+ // a comma, and then we can continue.
+ if (!text.consume_front(","))
+ return failure(opt.error("expected ',' after parsing pipeline near: " +
+ pipeline.back().name));
+ }
+
+ // Check for unbalanced parentheses.
+ if (pipelineStack.size() > 1)
+ return failure(
+ opt.error("encountered unbalanced parentheses while parsing pipeline"));
+
+ assert(pipelineStack.back() == &pipeline &&
+ "wrong pipeline at the bottom of the stack");
+ return success();
+}
+
+/// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
+/// the corresponding registry entry.
+LogicalResult TextualPipeline::resolvePipelineElements(
+ MutableArrayRef<PipelineElement> elements, llvm::cl::Option &opt) {
+ for (auto &elt : elements)
+ if (failed(resolvePipelineElement(elt, opt)))
+ return failure();
+ return success();
+}
+
+/// Resolve a single element of the pipeline.
+LogicalResult TextualPipeline::resolvePipelineElement(PipelineElement &element,
+ llvm::cl::Option &opt) {
+ // If the inner pipeline of this element is not empty, this is an operation
+ // pipeline.
+ if (!element.innerPipeline.empty())
+ return resolvePipelineElements(element.innerPipeline, opt);
+
+ // Otherwise, this must be a pass or pass pipeline.
+ // Check to see if a pipeline was registered with this name.
+ auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
+ if (pipelineRegistryIt != passPipelineRegistry->end()) {
+ element.registryEntry = &pipelineRegistryIt->second;
+ return success();
+ }
+
+ // If not, then this must be a specific pass name.
+ for (auto &passIt : *passRegistry) {
+ if (passIt.second.getPassArgument() == element.name) {
+ element.registryEntry = &passIt.second;
+ return success();
+ }
+ }
+
+ // Emit an error for the unknown pass.
+ opt.error("'" + element.name +
+ "' does not refer to a registered pass or pass pipeline");
+ return failure();
+}
+
+/// Add the given pipeline elements to the provided pass manager.
+void TextualPipeline::addToPipeline(ArrayRef<PipelineElement> elements,
+ OpPassManager &pm) const {
+ for (auto &elt : elements) {
+ if (elt.registryEntry)
+ elt.registryEntry->addToPipeline(pm);
+ else
+ addToPipeline(elt.innerPipeline, pm.nest(elt.name));
+ }
+}
+
+//===----------------------------------------------------------------------===//
// PassNameParser
//===----------------------------------------------------------------------===//
-PassNameParser::PassNameParser(llvm::cl::Option &opt)
- : llvm::cl::parser<const PassRegistryEntry *>(opt) {}
+namespace {
+/// This struct represents the possible data entries in a parsed pass pipeline
+/// list.
+struct PassArgData {
+ PassArgData() : registryEntry(nullptr) {}
+ PassArgData(const PassRegistryEntry *registryEntry)
+ : registryEntry(registryEntry) {}
+
+ /// This field is used when the parsed option corresponds to a registered pass
+ /// or pass pipeline.
+ const PassRegistryEntry *registryEntry;
+
+ /// This field is used when the parsed option corresponds to an explicit
+ /// pipeline.
+ TextualPipeline pipeline;
+};
+} // end anonymous namespace
+
+namespace llvm {
+namespace cl {
+/// Define a valid OptionValue for the command line pass argument.
+template <>
+struct OptionValue<PassArgData> final
+ : OptionValueBase<PassArgData, /*isClass=*/true> {
+ OptionValue(const PassArgData &value) { this->setValue(value); }
+ OptionValue() = default;
+ void anchor() override {}
+
+ bool hasValue() const { return true; }
+ const PassArgData &getValue() const { return value; }
+ void setValue(const PassArgData &value) { this->value = value; }
+
+ PassArgData value;
+};
+} // end namespace cl
+} // end namespace llvm
+
+namespace {
+
+/// The name for the command line option used for parsing the textual pass
+/// pipeline.
+static constexpr llvm::StringLiteral passPipelineArg = "pass-pipeline";
+
+/// Adds command line option for each registered pass or pass pipeline, as well
+/// as textual pass pipelines.
+struct PassNameParser : public llvm::cl::parser<PassArgData> {
+ PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
+
+ void initialize();
+ void printOptionInfo(const llvm::cl::Option &opt,
+ size_t globalWidth) const override;
+ bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
+ PassArgData &value);
+};
+} // namespace
void PassNameParser::initialize() {
- llvm::cl::parser<const PassRegistryEntry *>::initialize();
+ llvm::cl::parser<PassArgData>::initialize();
+
+ /// Add an entry for the textual pass pipeline option.
+ addLiteralOption(passPipelineArg, PassArgData(),
+ "A textual description of a pass pipeline to run");
/// Add the pass entries.
for (const auto &kv : *passRegistry) {
@@ -112,6 +351,62 @@
const PassNameParser::OptionInfo *VT2) {
return VT1->Name.compare(VT2->Name);
});
- using llvm::cl::parser;
- parser<const PassRegistryEntry *>::printOptionInfo(O, GlobalWidth);
+ llvm::cl::parser<PassArgData>::printOptionInfo(O, GlobalWidth);
+}
+
+bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
+ StringRef arg, PassArgData &value) {
+ // Handle the pipeline option explicitly.
+ if (argName == passPipelineArg)
+ return failed(value.pipeline.initialize(arg, opt));
+
+ // Otherwise, default to the base for handling.
+ return llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value);
+}
+
+//===----------------------------------------------------------------------===//
+// PassPipelineCLParser
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace detail {
+struct PassPipelineCLParserImpl {
+ PassPipelineCLParserImpl(StringRef arg, StringRef description)
+ : passList(arg, llvm::cl::desc(description)) {
+ passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
+ }
+
+ /// The set of passes and pass pipelines to run.
+ llvm::cl::list<PassArgData, bool, PassNameParser> passList;
+};
+} // end namespace detail
+} // end namespace mlir
+
+/// Construct a pass pipeline parser with the given command line description.
+PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
+ : impl(std::make_unique<detail::PassPipelineCLParserImpl>(arg,
+ description)) {}
+PassPipelineCLParser::~PassPipelineCLParser() {}
+
+/// Returns true if this parser contains any valid options to add.
+bool PassPipelineCLParser::hasAnyOccurrences() const {
+ return impl->passList.getNumOccurrences() != 0;
+}
+
+/// Returns true if the given pass registry entry was registered at the
+/// top-level of the parser, i.e. not within an explicit textual pipeline.
+bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
+ return llvm::any_of(impl->passList, [&](const PassArgData &data) {
+ return data.registryEntry == entry;
+ });
+}
+
+/// Adds the passes defined by this parser entry to the given pass manager.
+void PassPipelineCLParser::addToPipeline(OpPassManager &pm) const {
+ for (auto &passIt : impl->passList) {
+ if (passIt.registryEntry)
+ passIt.registryEntry->addToPipeline(pm);
+ else
+ passIt.pipeline.addToPipeline(pm);
+ }
}
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 0b234e6..055692d 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -45,10 +45,10 @@
/// This typically parses the main source file, runs zero or more optimization
/// passes, then prints the output.
///
-static LogicalResult
-performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
- SourceMgr &sourceMgr, MLIRContext *context,
- const std::vector<const mlir::PassRegistryEntry *> &passList) {
+static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
+ bool verifyPasses, SourceMgr &sourceMgr,
+ MLIRContext *context,
+ const PassPipelineCLParser &passPipeline) {
OwningModuleRef module(parseSourceFile(sourceMgr, context));
if (!module)
return failure();
@@ -57,9 +57,8 @@
PassManager pm(context, verifyPasses);
applyPassManagerCLOptions(pm);
- // Run each of the passes that were selected.
- for (const auto *passEntry : passList)
- passEntry->addToPipeline(pm);
+ // Build the provided pipeline.
+ passPipeline.addToPipeline(pm);
// Run the pipeline.
if (failed(pm.run(*module)))
@@ -72,10 +71,10 @@
/// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result.
-static LogicalResult
-processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
- bool verifyDiagnostics, bool verifyPasses,
- const std::vector<const mlir::PassRegistryEntry *> &passList) {
+static LogicalResult processBuffer(raw_ostream &os,
+ std::unique_ptr<MemoryBuffer> ownedBuffer,
+ bool verifyDiagnostics, bool verifyPasses,
+ const PassPipelineCLParser &passPipeline) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -88,7 +87,7 @@
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
- &context, passList);
+ &context, passPipeline);
}
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -97,7 +96,7 @@
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
- passList);
+ passPipeline);
// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
@@ -108,10 +107,11 @@
/// according to the normal processBuffer logic. This is primarily used to
/// allow a large number of small independent parser tests to be put into a
/// single test, but could be used for other purposes as well.
-static LogicalResult splitAndProcessFile(
- raw_ostream &os, std::unique_ptr<MemoryBuffer> originalBuffer,
- bool verifyDiagnostics, bool verifyPasses,
- const std::vector<const mlir::PassRegistryEntry *> &passList) {
+static LogicalResult
+splitAndProcessFile(raw_ostream &os,
+ std::unique_ptr<MemoryBuffer> originalBuffer,
+ bool verifyDiagnostics, bool verifyPasses,
+ const PassPipelineCLParser &passPipeline) {
const char marker[] = "// -----";
auto *origMemBuffer = originalBuffer.get();
SmallVector<StringRef, 8> sourceBuffers;
@@ -132,24 +132,24 @@
subBuffer, origMemBuffer->getBufferIdentifier() +
Twine(" split at line #") + Twine(splitLine));
if (failed(processBuffer(os, std::move(subMemBuffer), verifyDiagnostics,
- verifyPasses, passList)))
+ verifyPasses, passPipeline)))
hadUnexpectedResult = true;
}
return failure(hadUnexpectedResult);
}
-LogicalResult
-mlir::MlirOptMain(raw_ostream &os, std::unique_ptr<MemoryBuffer> buffer,
- const std::vector<const mlir::PassRegistryEntry *> &passList,
- bool splitInputFile, bool verifyDiagnostics,
- bool verifyPasses) {
+LogicalResult mlir::MlirOptMain(raw_ostream &os,
+ std::unique_ptr<MemoryBuffer> buffer,
+ const PassPipelineCLParser &passPipeline,
+ bool splitInputFile, bool verifyDiagnostics,
+ bool verifyPasses) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
if (splitInputFile)
return splitAndProcessFile(os, std::move(buffer), verifyDiagnostics,
- verifyPasses, passList);
+ verifyPasses, passPipeline);
return processBuffer(os, std::move(buffer), verifyDiagnostics, verifyPasses,
- passList);
+ passPipeline);
}
diff --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir
new file mode 100644
index 0000000..6a9c52e
--- /dev/null
+++ b/mlir/test/Pass/pipeline-parsing.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass,func(test-function-pass)),func(test-function-pass)' -cse -pass-pipeline="func(canonicalize)" -verify-each=false -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck %s
+// RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s
+// RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
+// RUN: not mlir-opt %s -pass-pipeline='module()(' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
+// RUN: not mlir-opt %s -pass-pipeline=',' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s
+
+// CHECK_ERROR_1: encountered unbalanced parentheses while parsing pipeline
+// CHECK_ERROR_2: encountered extra closing ')' creating unbalanced parentheses while parsing pipeline
+// CHECK_ERROR_3: expected ',' after parsing pipeline
+// CHECK_ERROR_4: does not refer to a registered pass or pass pipeline
+
+func @foo() {
+ return
+}
+
+module {
+ func @foo() {
+ return
+ }
+}
+
+// CHECK: Pipeline Collection : ['func', 'module']
+// CHECK-NEXT: 'func' Pipeline
+// CHECK-NEXT: TestFunctionPass
+// CHECK-NEXT: CSE
+// CHECK-NEXT: DominanceInfo
+// CHECK-NEXT: Canonicalizer
+// CHECK-NEXT: 'module' Pipeline
+// CHECK-NEXT: TestModulePass
+// CHECK-NEXT: 'func' Pipeline
+// CHECK-NEXT: TestFunctionPass
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 5efb2fb..7e4d8a7 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -30,7 +30,7 @@
};
} // namespace
-static void testNestedPipeline(PassManager &pm) {
+static void testNestedPipeline(OpPassManager &pm) {
// Nest a module pipeline that contains:
/// A module pass.
auto &modulePM = pm.nest<ModuleOp>();
@@ -44,6 +44,11 @@
functionPM.addPass(std::make_unique<TestFunctionPass>());
}
+static PassRegistration<TestModulePass>
+ unusedMP("test-module-pass", "Test a module pass in the pass manager");
+static PassRegistration<TestFunctionPass>
+ unusedFP("test-function-pass", "Test a function pass in the pass manager");
+
static PassPipelineRegistration
unused("test-pm-nested-pipeline",
"Test a nested pipeline in the pass manager", testNestedPipeline);
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3f9dbcde..d01f66d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -56,18 +56,14 @@
cl::desc("Run the verifier after each transformation pass"),
cl::init(true));
-static std::vector<const PassRegistryEntry *> *passList;
-
int main(int argc, char **argv) {
InitLLVM y(argc, argv);
// Register any pass manager command line options.
registerPassManagerCLOptions();
+ PassPipelineCLParser passPipeline("", "Compiler passes to run");
// Parse pass names in main to ensure static initialization completed.
- llvm::cl::list<const PassRegistryEntry *, bool, PassNameParser> passList(
- "", llvm::cl::desc("Compiler passes to run"));
- ::passList = &passList;
cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
// Set up the input file.
@@ -84,6 +80,6 @@
exit(1);
}
- return failed(MlirOptMain(output->os(), std::move(file), passList,
+ return failed(MlirOptMain(output->os(), std::move(file), passPipeline,
splitInputFile, verifyDiagnostics, verifyPasses));
}