|  | //===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===// | 
|  | // | 
|  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  | // | 
|  | // This file defines helpers used in the op generators. | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | #include "OpGenHelpers.h" | 
|  | #include "llvm/ADT/StringSet.h" | 
|  | #include "llvm/Support/CommandLine.h" | 
|  | #include "llvm/Support/FormatVariadic.h" | 
|  | #include "llvm/Support/Regex.h" | 
|  | #include "llvm/TableGen/Error.h" | 
|  |  | 
|  | using namespace llvm; | 
|  | using namespace mlir; | 
|  | using namespace mlir::tblgen; | 
|  |  | 
|  | cl::OptionCategory opDefGenCat("Options for op definition generators"); | 
|  |  | 
|  | static cl::opt<std::string> opIncFilter( | 
|  | "op-include-regex", | 
|  | cl::desc("Regex of name of op's to include (no filter if empty)"), | 
|  | cl::cat(opDefGenCat)); | 
|  | static cl::opt<std::string> opExcFilter( | 
|  | "op-exclude-regex", | 
|  | cl::desc("Regex of name of op's to exclude (no filter if empty)"), | 
|  | cl::cat(opDefGenCat)); | 
|  | static cl::opt<unsigned> opShardCount( | 
|  | "op-shard-count", | 
|  | cl::desc("The number of shards into which the op classes will be divided"), | 
|  | cl::cat(opDefGenCat), cl::init(1)); | 
|  |  | 
|  | static std::string getOperationName(const Record &def) { | 
|  | auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name"); | 
|  | auto opName = def.getValueAsString("opName"); | 
|  | if (prefix.empty()) | 
|  | return std::string(opName); | 
|  | return std::string(formatv("{0}.{1}", prefix, opName)); | 
|  | } | 
|  |  | 
|  | std::vector<const Record *> | 
|  | mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records) { | 
|  | const Record *classDef = records.getClass("Op"); | 
|  | if (!classDef) | 
|  | PrintFatalError("ERROR: Couldn't find the 'Op' class!\n"); | 
|  |  | 
|  | Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); | 
|  | std::vector<const Record *> defs; | 
|  | for (const auto &def : records.getDefs()) { | 
|  | if (!def.second->isSubClassOf(classDef)) | 
|  | continue; | 
|  | // Include if no include filter or include filter matches. | 
|  | if (!opIncFilter.empty() && | 
|  | !includeRegex.match(getOperationName(*def.second))) | 
|  | continue; | 
|  | // Unless there is an exclude filter and it matches. | 
|  | if (!opExcFilter.empty() && | 
|  | excludeRegex.match(getOperationName(*def.second))) | 
|  | continue; | 
|  | defs.push_back(def.second.get()); | 
|  | } | 
|  |  | 
|  | return defs; | 
|  | } | 
|  |  | 
|  | bool mlir::tblgen::isPythonReserved(StringRef str) { | 
|  | static StringSet<> reserved({ | 
|  | "False",  "None",   "True",    "and",      "as",       "assert", "async", | 
|  | "await",  "break",  "class",   "continue", "def",      "del",    "elif", | 
|  | "else",   "except", "finally", "for",      "from",     "global", "if", | 
|  | "import", "in",     "is",      "lambda",   "nonlocal", "not",    "or", | 
|  | "pass",   "raise",  "return",  "try",      "while",    "with",   "yield", | 
|  | }); | 
|  | // These aren't Python keywords but builtin functions that shouldn't/can't be | 
|  | // shadowed. | 
|  | reserved.insert("callable"); | 
|  | reserved.insert("issubclass"); | 
|  | reserved.insert("type"); | 
|  | return reserved.contains(str); | 
|  | } | 
|  |  | 
|  | void mlir::tblgen::shardOpDefinitions( | 
|  | ArrayRef<const Record *> defs, | 
|  | SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) { | 
|  | assert(opShardCount > 0 && "expected a positive shard count"); | 
|  | if (opShardCount == 1) { | 
|  | shardedDefs.push_back(defs); | 
|  | return; | 
|  | } | 
|  |  | 
|  | unsigned minShardSize = defs.size() / opShardCount; | 
|  | unsigned numMissing = defs.size() - minShardSize * opShardCount; | 
|  | shardedDefs.reserve(opShardCount); | 
|  | for (unsigned i = 0, start = 0; i < opShardCount; ++i) { | 
|  | unsigned size = minShardSize + (i < numMissing); | 
|  | shardedDefs.push_back(defs.slice(start, size)); | 
|  | start += size; | 
|  | } | 
|  | } |