blob: 2466a80d0b6411dd1712e3576d0df28efbc07e36 [file] [log] [blame]
//===- PredicateTree.cpp - Predicate tree merging -------------------------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "PredicateTree.h"
#include "RootOrdering.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <queue>
#define DEBUG_TYPE "pdl-predicate-tree"
using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
// Predicate List Building
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
Position *pos);
/// Compares the depths of two positions.
static bool comparePosDepth(Position *lhs, Position *rhs) {
return lhs->getOperationDepth() < rhs->getOperationDepth();
/// Returns the number of non-range elements within `values`.
static unsigned getNumNonRangeValues(ValueRange values) {
return llvm::count_if(values.getTypes(),
[](Type type) { return !type.isa<pdl::RangeType>(); });
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
AttributePosition *pos) {
assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
// If the attribute has a type or value, add a constraint.
if (Value type = attr.type())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.valueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
/// Collect all of the predicates for the given operand position.
static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
Position *pos) {
Type valueType = val.getType();
bool isVariadic = valueType.isa<pdl::RangeType>();
// If this is a typed operand, add a type constraint.
TypeSwitch<Operation *>(val.getDefiningOp())
.Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
// Prevent traversal into a null value if the operand has a proper
// index.
if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
predList.emplace_back(pos, builder.getIsNotNull());
if (Value type = op.type())
getTreePredicates(predList, type, builder, inputs,
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
Optional<unsigned> index = op.index();
// Prevent traversal into a null value if the result has a proper index.
if (index)
predList.emplace_back(pos, builder.getIsNotNull());
// Get the parent operation of this operand.
OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
predList.emplace_back(parentPos, builder.getIsNotNull());
// Ensure that the operands match the corresponding results of the
// parent operation.
Position *resultPos = nullptr;
if (std::is_same<pdl::ResultOp, decltype(op)>::value)
resultPos = builder.getResult(parentPos, *index);
resultPos = builder.getResultGroup(parentPos, index, isVariadic);
predList.emplace_back(resultPos, builder.getEqualTo(pos));
// Collect the predicates of the parent operation.
getTreePredicates(predList, op.parent(), builder, inputs,
(Position *)parentPos);
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
OperationPosition *pos,
Optional<unsigned> ignoreOperand = llvm::None) {
assert(val.getType().isa<pdl::OperationType>() && "expected operation");
pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
OperationPosition *opPos = cast<OperationPosition>(pos);
// Ensure getDefiningOp returns a non-null operation.
if (!opPos->isRoot())
predList.emplace_back(pos, builder.getIsNotNull());
// Check that this is the correct root operation.
if (Optional<StringRef> opName =
predList.emplace_back(pos, builder.getOperationName(*opName));
// Check that the operation has the proper number of operands. If there are
// any variable length operands, we check a minimum instead of an exact count.
OperandRange operands = op.operands();
unsigned minOperands = getNumNonRangeValues(operands);
if (minOperands != operands.size()) {
if (minOperands)
predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
} else {
predList.emplace_back(pos, builder.getOperandCount(minOperands));
// Check that the operation has the proper number of results. If there are
// any variable length results, we check a minimum instead of an exact count.
OperandRange types = op.types();
unsigned minResults = getNumNonRangeValues(types);
if (minResults == types.size())
predList.emplace_back(pos, builder.getResultCount(types.size()));
else if (minResults)
predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
// Recurse into any attributes, operands, or results.
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
predList, std::get<1>(it), builder, inputs,
// Process the operands and results of the operation. For all values up to
// the first variable length value, we use the concrete operand/result
// number. After that, we use the "group" given that we can't know the
// concrete indices until runtime. If there is only one variadic operand
// group, we treat it as all of the operands/results of the operation.
/// Operands.
if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
getTreePredicates(predList, operands.front(), builder, inputs,
} else {
bool foundVariableLength = false;
for (auto operandIt : llvm::enumerate(operands)) {
bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
foundVariableLength |= isVariadic;
// Ignore the specified operand, usually because this position was
// visited in an upward traversal via an iterative choice.
if (ignoreOperand && *ignoreOperand == operandIt.index())
Position *pos =
? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
: builder.getOperand(opPos, operandIt.index());
getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
/// Results.
if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
getTreePredicates(predList, types.front(), builder, inputs,
} else {
bool foundVariableLength = false;
for (auto &resultIt : llvm::enumerate(types)) {
bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
foundVariableLength |= isVariadic;
auto *resultPos =
? builder.getResultGroup(pos, resultIt.index(), isVariadic)
: builder.getResult(pos, resultIt.index());
predList.emplace_back(resultPos, builder.getIsNotNull());
getTreePredicates(predList, resultIt.value(), builder, inputs,
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
TypePosition *pos) {
// Check for a constraint on a constant type.
if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
if (Attribute type = typeOp.typeAttr())
predList.emplace_back(pos, builder.getTypeConstraint(type));
} else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
if (Attribute typeAttr = typeOp.typesAttr())
predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
/// Collect the tree predicates anchored at the given value.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
Position *pos) {
// Make sure this input value is accessible to the rewrite.
auto it = inputs.try_emplace(val, pos);
if (!it.second) {
// If this is an input value that has been visited in the tree, add a
// constraint to ensure that both instances refer to the same value.
if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
pdl::TypeOp>(val.getDefiningOp())) {
auto minMaxPositions =
std::minmax(pos, it.first->second, comparePosDepth);
TypeSwitch<Position *>(pos)
.Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
getTreePredicates(predList, val, builder, inputs, pos);
.Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
getOperandTreePredicates(predList, val, builder, inputs, pos);
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
/// Collect all of the predicates related to constraints within the given
/// pattern operation.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
OperandRange arguments = op.args();
ArrayAttr parameters = op.constParamsAttr();
std::vector<Position *> allPositions;
for (Value arg : arguments)
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
PredicateBuilder::Predicate pred =
builder.getConstraint(, std::move(allPositions), parameters);
predList.emplace_back(pos, pred);
static void getResultPredicates(pdl::ResultOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
Position *&resultPos = inputs[op];
if (resultPos)
// Ensure that the result isn't null.
auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
resultPos = builder.getResult(parentPos, op.index());
predList.emplace_back(resultPos, builder.getIsNotNull());
static void getResultPredicates(pdl::ResultsOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
Position *&resultPos = inputs[op];
if (resultPos)
// Ensure that the result isn't null if the result has an index.
auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
bool isVariadic = op.getType().isa<pdl::RangeType>();
Optional<unsigned> index = op.index();
resultPos = builder.getResultGroup(parentPos, index, isVariadic);
if (index)
predList.emplace_back(resultPos, builder.getIsNotNull());
/// Collect all of the predicates that cannot be determined via walking the
/// tree.
static void getNonTreePredicates(pdl::PatternOp pattern,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
TypeSwitch<Operation *>(&op)
.Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
getConstraintPredicates(constraintOp, predList, builder, inputs);
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
getResultPredicates(resultOp, predList, builder, inputs);
namespace {
/// An op accepting a value at an optional index.
struct OpIndex {
Value parent;
Optional<unsigned> index;
/// The parent and operand index of each operation for each root, stored
/// as a nested map [root][operation].
using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
} // namespace
/// Given a pattern, determines the set of roots present in this pattern.
/// These are the operations whose results are not consumed by other operations.
static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
// First, collect all the operations that are used as operands
// to other operations. These are not roots by default.
DenseSet<Value> used;
for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
for (Value operand : operationOp.operands())
TypeSwitch<Operation *>(operand.getDefiningOp())
.Case<pdl::ResultOp, pdl::ResultsOp>(
[&used](auto resultOp) { used.insert(resultOp.parent()); });
// Remove the specified root from the use set, so that we can
// always select it as a root, even if it is used by other operations.
if (Value root = pattern.getRewriter().root())
// Finally, collect all the unused operations.
SmallVector<Value> roots;
for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
if (!used.contains(operationOp))
return roots;
/// Given a list of candidate roots, builds the cost graph for connecting them.
/// The graph is formed by traversing the DAG of operations starting from each
/// root and marking the depth of each connector value (operand). Then we join
/// the candidate roots based on the common connector values, taking the one
/// with the minimum depth. Along the way, we compute, for each candidate root,
/// a mapping from each operation (in the DAG underneath this root) to its
/// parent operation and the corresponding operand index.
static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
ParentMaps &parentMaps) {
// The entry of a queue. The entry consists of the following items:
// * the value in the DAG underneath the root;
// * the parent of the value;
// * the operand index of the value in its parent;
// * the depth of the visited value.
struct Entry {
Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth)
: value(value), parent(parent), index(index), depth(depth) {}
Value value;
Value parent;
Optional<unsigned> index;
unsigned depth;
// A root of a value and its depth (distance from root to the value).
struct RootDepth {
Value root;
unsigned depth = 0;
// Map from candidate connector values to their roots and depths. Using a
// small vector with 1 entry because most values belong to a single root.
llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
// Perform a breadth-first traversal of the op DAG rooted at each root.
for (Value root : roots) {
// The queue of visited values. A value may be present multiple times in
// the queue, for multiple parents. We only accept the first occurrence,
// which is guaranteed to have the lowest depth.
std::queue<Entry> toVisit;
toVisit.emplace(root, Value(), 0, 0);
// The map from value to its parent for the current root.
DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
while (!toVisit.empty()) {
Entry entry = toVisit.front();
// Skip if already visited.
if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
// Mark the root and depth of the value.
connectorsRootsDepths[entry.value].push_back({root, entry.depth});
// Traverse the operands of an operation and result ops.
// We intentionally do not traverse attributes and types, because those
// are expensive to join on.
TypeSwitch<Operation *>(entry.value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
OperandRange operands = operationOp.operands();
// Special case when we pass all the operands in one range.
// For those, the index is empty.
if (operands.size() == 1 &&
operands[0].getType().isa<pdl::RangeType>()) {
toVisit.emplace(operands[0], entry.value, llvm::None,
entry.depth + 1);
// Default case: visit all the operands.
for (auto p : llvm::enumerate(operationOp.operands()))
toVisit.emplace(p.value(), entry.value, p.index(),
entry.depth + 1);
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
// Now build the cost graph.
// This is simply a minimum over all depths for the target root.
unsigned nextID = 0;
for (const auto &connectorRootsDepths : connectorsRootsDepths) {
Value value = connectorRootsDepths.first;
ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
// If there is only one root for this value, this will not trigger
// any edges in the cost graph (a perf optimization).
if (rootsDepths.size() == 1)
for (const RootDepth &p : rootsDepths) {
for (const RootDepth &q : rootsDepths) {
if (&p == &q)
// Insert or retrieve the property of edge from p to q.
RootOrderingCost &cost = graph[q.root][p.root];
if (!cost.connector /* new edge */ || cost.cost.first > q.depth) {
if (!cost.connector)
cost.cost.second = nextID++;
cost.cost.first = q.depth;
cost.connector = value;
assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
"the pattern contains a candidate root disconnected from the others");
/// Visit a node during upward traversal.
void visitUpward(std::vector<PositionalPredicate> &predList, OpIndex opIndex,
PredicateBuilder &builder,
DenseMap<Value, Position *> &valueToPosition, Position *&pos,
bool &first) {
Value value = opIndex.parent;
TypeSwitch<Operation *>(value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
OperationPosition *opPos = builder.getUsersOp(pos, opIndex.index);
// Guard against traversing back to where we came from.
if (first) {
Position *parent = pos->getParent();
predList.emplace_back(opPos, builder.getNotEqualTo(parent));
first = false;
// Guard against duplicate upward visits. These are not possible,
// because if this value was already visited, it would have been
// cheaper to start the traversal at this value rather than at the
// `connector`, violating the optimality of our spanning tree.
bool inserted = valueToPosition.try_emplace(value, opPos).second;
assert(inserted && "duplicate upward visit");
// Obtain the tree predicates at the current value.
getTreePredicates(predList, value, builder, valueToPosition, opPos,
// Update the position
pos = opPos;
.Case<pdl::ResultOp>([&](auto resultOp) {
// Traverse up an individual result.
auto *opPos = dyn_cast<OperationPosition>(pos);
assert(opPos && "operations and results must be interleaved");
pos = builder.getResult(opPos, *opIndex.index);
.Case<pdl::ResultsOp>([&](auto resultOp) {
// Traverse up a group of results.
auto *opPos = dyn_cast<OperationPosition>(pos);
assert(opPos && "operations and results must be interleaved");
bool isVariadic = value.getType().isa<pdl::RangeType>();
if (opIndex.index)
pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
pos = builder.getAllResults(opPos);
/// Given a pattern operation, build the set of matcher predicates necessary to
/// match this pattern.
static Value buildPredicateList(pdl::PatternOp pattern,
PredicateBuilder &builder,
std::vector<PositionalPredicate> &predList,
DenseMap<Value, Position *> &valueToPosition) {
SmallVector<Value> roots = detectRoots(pattern);
// Build the root ordering graph and compute the parent maps.
RootOrderingGraph graph;
ParentMaps parentMaps;
buildCostGraph(roots, graph, parentMaps);
llvm::dbgs() << "Graph:\n";
for (auto &target : graph) {
llvm::dbgs() << " * " << target.first << "\n";
for (auto &source : target.second) {
RootOrderingCost c = source.second;
llvm::dbgs() << " <- " << source.first << ": " << c.cost.first
<< ":" << c.cost.second << " via " << c.connector.getLoc()
<< "\n";
// Solve the optimal branching problem for each candidate root, or use the
// provided one.
Value bestRoot = pattern.getRewriter().root();
OptimalBranching::EdgeList bestEdges;
if (!bestRoot) {
unsigned bestCost = 0;
LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
for (Value root : roots) {
OptimalBranching solver(graph, root);
unsigned cost = solver.solve();
LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
if (!bestRoot || bestCost > cost) {
bestCost = cost;
bestRoot = root;
bestEdges = solver.preOrderTraversal(roots);
} else {
OptimalBranching solver(graph, bestRoot);
bestEdges = solver.preOrderTraversal(roots);
LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");
// The best root is the starting point for the traversal. Get the tree
// predicates for the DAG rooted at bestRoot.
getTreePredicates(predList, bestRoot, builder, valueToPosition,
// Traverse the selected optimal branching. For all edges in order, traverse
// up starting from the connector, until the candidate root is reached, and
// call getTreePredicates at every node along the way.
for (const std::pair<Value, Value> &edge : bestEdges) {
Value target = edge.first;
Value source = edge.second;
// Check if we already visited the target root. This happens in two cases:
// 1) the initial root (bestRoot);
// 2) a root that is dominated by (contained in the subtree rooted at) an
// already visited root.
if (valueToPosition.count(target))
// Determine the connector.
Value connector = graph[target][source].connector;
assert(connector && "invalid edge");
LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");
DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
Position *pos = valueToPosition.lookup(connector);
assert(pos && "The value has not been traversed yet");
bool first = true;
// Traverse from the connector upwards towards the target root.
for (Value value = connector; value != target;) {
OpIndex opIndex = parentMap.lookup(value);
assert(opIndex.parent && "missing parent");
visitUpward(predList, opIndex, builder, valueToPosition, pos, first);
value = opIndex.parent;
getNonTreePredicates(pattern, predList, builder, valueToPosition);
return bestRoot;
// Pattern Predicate Tree Merging
namespace {
/// This class represents a specific predicate applied to a position, and
/// provides hashing and ordering operators. This class allows for computing a
/// frequence sum and ordering predicates based on a cost model.
struct OrderedPredicate {
OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
: position(ip.first), question(ip.second) {}
OrderedPredicate(const PositionalPredicate &ip)
: position(ip.position), question(ip.question) {}
/// The position this predicate is applied to.
Position *position;
/// The question that is applied by this predicate onto the position.
Qualifier *question;
/// The first and second order benefit sums.
/// The primary sum is the number of occurrences of this predicate among all
/// of the patterns.
unsigned primary = 0;
/// The secondary sum is a squared summation of the primary sum of all of the
/// predicates within each pattern that contains this predicate. This allows
/// for favoring predicates that are more commonly shared within a pattern, as
/// opposed to those shared across patterns.
unsigned secondary = 0;
/// A map between a pattern operation and the answer to the predicate question
/// within that pattern.
DenseMap<Operation *, Qualifier *> patternToAnswer;
/// Returns true if this predicate is ordered before `rhs`, based on the cost
/// model.
bool operator<(const OrderedPredicate &rhs) const {
// Sort by:
// * higher first and secondary order sums
// * lower depth
// * lower position dependency
// * lower predicate dependency
auto *rhsPos = rhs.position;
return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
rhsPos->getKind(), rhs.question->getKind()) >
std::make_tuple(rhs.primary, rhs.secondary,
position->getOperationDepth(), position->getKind(),
/// A DenseMapInfo for OrderedPredicate based solely on the position and
/// question.
struct OrderedPredicateDenseInfo {
using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
static bool isEqual(const OrderedPredicate &lhs,
const OrderedPredicate &rhs) {
return lhs.position == rhs.position && lhs.question == rhs.question;
static unsigned getHashValue(const OrderedPredicate &p) {
return llvm::hash_combine(p.position, p.question);
/// This class wraps a set of ordered predicates that are used within a specific
/// pattern operation.
struct OrderedPredicateList {
OrderedPredicateList(pdl::PatternOp pattern, Value root)
: pattern(pattern), root(root) {}
pdl::PatternOp pattern;
Value root;
DenseSet<OrderedPredicate *> predicates;
} // end anonymous namespace
/// Returns true if the given matcher refers to the same predicate as the given
/// ordered predicate. This means that the position and questions of the two
/// match.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
return node->getPosition() == predicate->position &&
node->getQuestion() == predicate->question;
/// Get or insert a child matcher for the given parent switch node, given a
/// predicate and parent pattern.
std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
OrderedPredicate *predicate,
pdl::PatternOp pattern) {
assert(isSamePredicate(node, predicate) &&
"expected matcher to equal the given predicate");
auto it = predicate->patternToAnswer.find(pattern);
assert(it != predicate->patternToAnswer.end() &&
"expected pattern to exist in predicate");
return node->getChildren().insert({it->second, nullptr}).first->second;
/// Build the matcher CFG by "pushing" patterns through by sorted predicate
/// order. A pattern will traverse as far as possible using common predicates
/// and then either diverge from the CFG or reach the end of a branch and start
/// creating new nodes.
static void propagatePattern(std::unique_ptr<MatcherNode> &node,
OrderedPredicateList &list,
std::vector<OrderedPredicate *>::iterator current,
std::vector<OrderedPredicate *>::iterator end) {
if (current == end) {
// We've hit the end of a pattern, so create a successful result node.
node =
std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
// If the pattern doesn't contain this predicate, ignore it.
} else if (list.predicates.find(*current) == list.predicates.end()) {
propagatePattern(node, list, std::next(current), end);
// If the current matcher node is invalid, create a new one for this
// position and continue propagation.
} else if (!node) {
// Create a new node at this position and continue
node = std::make_unique<SwitchNode>((*current)->position,
getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
list, std::next(current), end);
// If the matcher has already been created, and it is for this predicate we
// continue propagation to the child.
} else if (isSamePredicate(node.get(), *current)) {
getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
list, std::next(current), end);
// If the matcher doesn't match the current predicate, insert a branch as
// the common set of matchers has diverged.
} else {
propagatePattern(node->getFailureNode(), list, current, end);
/// Fold any switch nodes nested under `node` to boolean nodes when possible.
/// `node` is updated in-place if it is a switch.
static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
if (!node)
if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
SwitchNode::ChildMapT &children = switchNode->getChildren();
for (auto &it : children)
// If the node only contains one child, collapse it into a boolean predicate
// node.
if (children.size() == 1) {
auto childIt = children.begin();
node = std::make_unique<BoolNode>(
node->getPosition(), node->getQuestion(), childIt->first,
std::move(childIt->second), std::move(node->getFailureNode()));
} else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
/// Insert an exit node at the end of the failure path of the `root`.
static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
while (*root)
root = &(*root)->getFailureNode();
*root = std::make_unique<ExitNode>();
/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher node.
MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
DenseMap<Value, Position *> &valueToPosition) {
// The set of predicates contained within the pattern operations of the
// module.
struct PatternPredicates {
PatternPredicates(pdl::PatternOp pattern, Value root,
std::vector<PositionalPredicate> predicates)
: pattern(pattern), root(root), predicates(std::move(predicates)) {}
/// A pattern.
pdl::PatternOp pattern;
/// A root of the pattern chosen among the candidate roots in pdl.rewrite.
Value root;
/// The extracted predicates for this pattern and root.
std::vector<PositionalPredicate> predicates;
SmallVector<PatternPredicates, 16> patternsAndPredicates;
for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
std::vector<PositionalPredicate> predicateList;
Value root =
buildPredicateList(pattern, builder, predicateList, valueToPosition);
patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
// Associate a pattern result with each unique predicate.
DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
for (auto &patternAndPredList : patternsAndPredicates) {
for (auto &predicate : patternAndPredList.predicates) {
auto it = uniqued.insert(predicate);
// Associate each pattern to a set of its ordered predicates for later lookup.
std::vector<OrderedPredicateList> lists;
for (auto &patternAndPredList : patternsAndPredicates) {
OrderedPredicateList list(patternAndPredList.pattern,
for (auto &predicate : patternAndPredList.predicates) {
OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
// Increment the primary sum for each reference to a particular predicate.
// For a particular pattern, get the total primary sum and add it to the
// secondary sum of each predicate. Square the primary sums to emphasize
// shared predicates within rather than across patterns.
for (auto &list : lists) {
unsigned total = 0;
for (auto *predicate : list.predicates)
total += predicate->primary * predicate->primary;
for (auto *predicate : list.predicates)
predicate->secondary += total;
// Sort the set of predicates now that the cost primary and secondary sums
// have been computed.
std::vector<OrderedPredicate *> ordered;
for (auto &ip : uniqued)
ordered.begin(), ordered.end(),
[](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;
for (OrderedPredicateList &list : lists)
propagatePattern(root, list, ordered.begin(), ordered.end());
// Collapse the graph and insert the exit node.
return root;
// MatcherNode
MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
std::unique_ptr<MatcherNode> failureNode)
: position(p), question(q), failureNode(std::move(failureNode)),
matcherTypeID(matcherTypeID) {}
// BoolNode
BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
std::unique_ptr<MatcherNode> successNode,
std::unique_ptr<MatcherNode> failureNode)
: MatcherNode(TypeID::get<BoolNode>(), position, question,
answer(answer), successNode(std::move(successNode)) {}
// SuccessNode
SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
std::unique_ptr<MatcherNode> failureNode)
: MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
/*question=*/nullptr, std::move(failureNode)),
pattern(pattern), root(root) {}
// SwitchNode
SwitchNode::SwitchNode(Position *position, Qualifier *question)
: MatcherNode(TypeID::get<SwitchNode>(), position, question) {}