| //===- Predicate.cpp - Predicate class ------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Wrapper around predicates defined in TableGen. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/TableGen/Predicate.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallPtrSet.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Record.h" |
| |
| using namespace mlir; |
| using namespace tblgen; |
| |
| // Construct a Predicate from a record. |
| Pred::Pred(const llvm::Record *record) : def(record) { |
| assert(def->isSubClassOf("Pred") && |
| "must be a subclass of TableGen 'Pred' class"); |
| } |
| |
| // Construct a Predicate from an initializer. |
| Pred::Pred(const llvm::Init *init) : def(nullptr) { |
| if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init)) |
| def = defInit->getDef(); |
| } |
| |
| std::string Pred::getCondition() const { |
| // Static dispatch to subclasses. |
| if (def->isSubClassOf("CombinedPred")) |
| return static_cast<const CombinedPred *>(this)->getConditionImpl(); |
| if (def->isSubClassOf("CPred")) |
| return static_cast<const CPred *>(this)->getConditionImpl(); |
| llvm_unreachable("Pred::getCondition must be overridden in subclasses"); |
| } |
| |
| bool Pred::isCombined() const { |
| return def && def->isSubClassOf("CombinedPred"); |
| } |
| |
| ArrayRef<llvm::SMLoc> Pred::getLoc() const { return def->getLoc(); } |
| |
| CPred::CPred(const llvm::Record *record) : Pred(record) { |
| assert(def->isSubClassOf("CPred") && |
| "must be a subclass of Tablegen 'CPred' class"); |
| } |
| |
| CPred::CPred(const llvm::Init *init) : Pred(init) { |
| assert((!def || def->isSubClassOf("CPred")) && |
| "must be a subclass of Tablegen 'CPred' class"); |
| } |
| |
| // Get condition of the C Predicate. |
| std::string CPred::getConditionImpl() const { |
| assert(!isNull() && "null predicate does not have a condition"); |
| return std::string(def->getValueAsString("predExpr")); |
| } |
| |
| CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { |
| assert(def->isSubClassOf("CombinedPred") && |
| "must be a subclass of Tablegen 'CombinedPred' class"); |
| } |
| |
| CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { |
| assert((!def || def->isSubClassOf("CombinedPred")) && |
| "must be a subclass of Tablegen 'CombinedPred' class"); |
| } |
| |
| const llvm::Record *CombinedPred::getCombinerDef() const { |
| assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); |
| return def->getValueAsDef("kind"); |
| } |
| |
| const std::vector<llvm::Record *> CombinedPred::getChildren() const { |
| assert(def->getValue("children") && |
| "CombinedPred must have a value 'children'"); |
| return def->getValueAsListOfDefs("children"); |
| } |
| |
| namespace { |
| // Kinds of nodes in a logical predicate tree. |
| enum class PredCombinerKind { |
| Leaf, |
| And, |
| Or, |
| Not, |
| SubstLeaves, |
| Concat, |
| // Special kinds that are used in simplification. |
| False, |
| True |
| }; |
| |
| // A node in a logical predicate tree. |
| struct PredNode { |
| PredCombinerKind kind; |
| const Pred *predicate; |
| SmallVector<PredNode *, 4> children; |
| std::string expr; |
| |
| // Prefix and suffix are used by ConcatPred. |
| std::string prefix; |
| std::string suffix; |
| }; |
| } // end anonymous namespace |
| |
| // Get a predicate tree node kind based on the kind used in the predicate |
| // TableGen record. |
| static PredCombinerKind getPredCombinerKind(const Pred &pred) { |
| if (!pred.isCombined()) |
| return PredCombinerKind::Leaf; |
| |
| const auto &combinedPred = static_cast<const CombinedPred &>(pred); |
| return StringSwitch<PredCombinerKind>( |
| combinedPred.getCombinerDef()->getName()) |
| .Case("PredCombinerAnd", PredCombinerKind::And) |
| .Case("PredCombinerOr", PredCombinerKind::Or) |
| .Case("PredCombinerNot", PredCombinerKind::Not) |
| .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves) |
| .Case("PredCombinerConcat", PredCombinerKind::Concat); |
| } |
| |
| namespace { |
| // Substitution<pattern, replacement>. |
| using Subst = std::pair<StringRef, StringRef>; |
| } // end anonymous namespace |
| |
| /// Perform the given substitutions on 'str' in-place. |
| static void performSubstitutions(std::string &str, |
| ArrayRef<Subst> substitutions) { |
| // Apply all parent substitutions from innermost to outermost. |
| for (const auto &subst : llvm::reverse(substitutions)) { |
| auto pos = str.find(std::string(subst.first)); |
| while (pos != std::string::npos) { |
| str.replace(pos, subst.first.size(), std::string(subst.second)); |
| // Skip the newly inserted substring, which itself may consider the |
| // pattern to match. |
| pos += subst.second.size(); |
| // Find the next possible match position. |
| pos = str.find(std::string(subst.first), pos); |
| } |
| } |
| } |
| |
| // Build the predicate tree starting from the top-level predicate, which may |
| // have children, and perform leaf substitutions inplace. Note that after |
| // substitution, nodes are still pointing to the original TableGen record. |
| // All nodes are created within "allocator". |
| static PredNode * |
| buildPredicateTree(const Pred &root, |
| llvm::SpecificBumpPtrAllocator<PredNode> &allocator, |
| ArrayRef<Subst> substitutions) { |
| auto *rootNode = allocator.Allocate(); |
| new (rootNode) PredNode; |
| rootNode->kind = getPredCombinerKind(root); |
| rootNode->predicate = &root; |
| if (!root.isCombined()) { |
| rootNode->expr = root.getCondition(); |
| performSubstitutions(rootNode->expr, substitutions); |
| return rootNode; |
| } |
| |
| // If the current combined predicate is a leaf substitution, append it to the |
| // list before continuing. |
| auto allSubstitutions = llvm::to_vector<4>(substitutions); |
| if (rootNode->kind == PredCombinerKind::SubstLeaves) { |
| const auto &substPred = static_cast<const SubstLeavesPred &>(root); |
| allSubstitutions.push_back( |
| {substPred.getPattern(), substPred.getReplacement()}); |
| |
| // If the current predicate is a ConcatPred, record the prefix and suffix. |
| } else if (rootNode->kind == PredCombinerKind::Concat) { |
| const auto &concatPred = static_cast<const ConcatPred &>(root); |
| rootNode->prefix = std::string(concatPred.getPrefix()); |
| performSubstitutions(rootNode->prefix, substitutions); |
| rootNode->suffix = std::string(concatPred.getSuffix()); |
| performSubstitutions(rootNode->suffix, substitutions); |
| } |
| |
| // Build child subtrees. |
| auto combined = static_cast<const CombinedPred &>(root); |
| for (const auto *record : combined.getChildren()) { |
| auto childTree = |
| buildPredicateTree(Pred(record), allocator, allSubstitutions); |
| rootNode->children.push_back(childTree); |
| } |
| return rootNode; |
| } |
| |
| // Simplify a predicate tree rooted at "node" using the predicates that are |
| // known to be true(false). For AND(OR) combined predicates, if any of the |
| // children is known to be false(true), the result is also false(true). |
| // Furthermore, for AND(OR) combined predicates, children that are known to be |
| // true(false) don't have to be checked dynamically. |
| static PredNode * |
| propagateGroundTruth(PredNode *node, |
| const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds, |
| const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) { |
| // If the current predicate is known to be true or false, change the kind of |
| // the node and return immediately. |
| if (knownTruePreds.count(node->predicate) != 0) { |
| node->kind = PredCombinerKind::True; |
| node->children.clear(); |
| return node; |
| } |
| if (knownFalsePreds.count(node->predicate) != 0) { |
| node->kind = PredCombinerKind::False; |
| node->children.clear(); |
| return node; |
| } |
| |
| // If the current node is a substitution, stop recursion now. |
| // The expressions in the leaves below this node were rewritten, but the nodes |
| // still point to the original predicate records. While the original |
| // predicate may be known to be true or false, it is not necessarily the case |
| // after rewriting. |
| // TODO: we can support ground truth for rewritten |
| // predicates by either (a) having our own unique'ing of the predicates |
| // instead of relying on TableGen record pointers or (b) taking ground truth |
| // values optionally prefixed with a list of substitutions to apply, e.g. |
| // "predX is true by itself as well as predSubY leaf substitution had been |
| // applied to it". |
| if (node->kind == PredCombinerKind::SubstLeaves) { |
| return node; |
| } |
| |
| // Otherwise, look at child nodes. |
| |
| // Move child nodes into some local variable so that they can be optimized |
| // separately and re-added if necessary. |
| llvm::SmallVector<PredNode *, 4> children; |
| std::swap(node->children, children); |
| |
| for (auto &child : children) { |
| // First, simplify the child. This maintains the predicate as it was. |
| auto simplifiedChild = |
| propagateGroundTruth(child, knownTruePreds, knownFalsePreds); |
| |
| // Just add the child if we don't know how to simplify the current node. |
| if (node->kind != PredCombinerKind::And && |
| node->kind != PredCombinerKind::Or) { |
| node->children.push_back(simplifiedChild); |
| continue; |
| } |
| |
| // Second, based on the type define which known values of child predicates |
| // immediately collapse this predicate to a known value, and which others |
| // may be safely ignored. |
| // OR(..., True, ...) = True |
| // OR(..., False, ...) = OR(..., ...) |
| // AND(..., False, ...) = False |
| // AND(..., True, ...) = AND(..., ...) |
| auto collapseKind = node->kind == PredCombinerKind::And |
| ? PredCombinerKind::False |
| : PredCombinerKind::True; |
| auto eraseKind = node->kind == PredCombinerKind::And |
| ? PredCombinerKind::True |
| : PredCombinerKind::False; |
| const auto &collapseList = |
| node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; |
| const auto &eraseList = |
| node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; |
| if (simplifiedChild->kind == collapseKind || |
| collapseList.count(simplifiedChild->predicate) != 0) { |
| node->kind = collapseKind; |
| node->children.clear(); |
| return node; |
| } else if (simplifiedChild->kind == eraseKind || |
| eraseList.count(simplifiedChild->predicate) != 0) { |
| continue; |
| } |
| node->children.push_back(simplifiedChild); |
| } |
| return node; |
| } |
| |
| // Combine a list of predicate expressions using a binary combiner. If a list |
| // is empty, return "init". |
| static std::string combineBinary(ArrayRef<std::string> children, |
| std::string combiner, std::string init) { |
| if (children.empty()) |
| return init; |
| |
| auto size = children.size(); |
| if (size == 1) |
| return children.front(); |
| |
| std::string str; |
| llvm::raw_string_ostream os(str); |
| os << '(' << children.front() << ')'; |
| for (unsigned i = 1; i < size; ++i) { |
| os << ' ' << combiner << " (" << children[i] << ')'; |
| } |
| return os.str(); |
| } |
| |
| // Prepend negation to the only condition in the predicate expression list. |
| static std::string combineNot(ArrayRef<std::string> children) { |
| assert(children.size() == 1 && "expected exactly one child predicate of Neg"); |
| return (Twine("!(") + children.front() + Twine(')')).str(); |
| } |
| |
| // Recursively traverse the predicate tree in depth-first post-order and build |
| // the final expression. |
| static std::string getCombinedCondition(const PredNode &root) { |
| // Immediately return for non-combiner predicates that don't have children. |
| if (root.kind == PredCombinerKind::Leaf) |
| return root.expr; |
| if (root.kind == PredCombinerKind::True) |
| return "true"; |
| if (root.kind == PredCombinerKind::False) |
| return "false"; |
| |
| // Recurse into children. |
| llvm::SmallVector<std::string, 4> childExpressions; |
| childExpressions.reserve(root.children.size()); |
| for (const auto &child : root.children) |
| childExpressions.push_back(getCombinedCondition(*child)); |
| |
| // Combine the expressions based on the predicate node kind. |
| if (root.kind == PredCombinerKind::And) |
| return combineBinary(childExpressions, "&&", "true"); |
| if (root.kind == PredCombinerKind::Or) |
| return combineBinary(childExpressions, "||", "false"); |
| if (root.kind == PredCombinerKind::Not) |
| return combineNot(childExpressions); |
| if (root.kind == PredCombinerKind::Concat) { |
| assert(childExpressions.size() == 1 && |
| "ConcatPred should only have one child"); |
| return root.prefix + childExpressions.front() + root.suffix; |
| } |
| |
| // Substitutions were applied before so just ignore them. |
| if (root.kind == PredCombinerKind::SubstLeaves) { |
| assert(childExpressions.size() == 1 && |
| "substitution predicate must have one child"); |
| return childExpressions[0]; |
| } |
| |
| llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); |
| } |
| |
| std::string CombinedPred::getConditionImpl() const { |
| llvm::SpecificBumpPtrAllocator<PredNode> allocator; |
| auto predicateTree = buildPredicateTree(*this, allocator, {}); |
| predicateTree = |
| propagateGroundTruth(predicateTree, |
| /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(), |
| /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>()); |
| |
| return getCombinedCondition(*predicateTree); |
| } |
| |
| StringRef SubstLeavesPred::getPattern() const { |
| return def->getValueAsString("pattern"); |
| } |
| |
| StringRef SubstLeavesPred::getReplacement() const { |
| return def->getValueAsString("replacement"); |
| } |
| |
| StringRef ConcatPred::getPrefix() const { |
| return def->getValueAsString("prefix"); |
| } |
| |
| StringRef ConcatPred::getSuffix() const { |
| return def->getValueAsString("suffix"); |
| } |