| //===- RootOrderingTest.cpp - unit tests for optimal branching ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v[1].0 with LLVM |
| // Exceptions. See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "gtest/gtest.h" |
| |
| using namespace mlir; |
| using namespace mlir::pdl_to_pdl_interp; |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Test Fixture |
| //===----------------------------------------------------------------------===// |
| |
| /// The test fixture for constructing root ordering tests and verifying results. |
| /// This fixture constructs the test values v. The test populates the graph |
| /// with the desired costs and then calls check(), passing the expeted optimal |
| /// cost and the list of edges in the preorder traversal of the optimal |
| /// branching. |
| class RootOrderingTest : public ::testing::Test { |
| protected: |
| RootOrderingTest() { |
| context.loadDialect<StandardOpsDialect>(); |
| createValues(); |
| } |
| |
| /// Creates the test values. |
| void createValues() { |
| OpBuilder builder(&context); |
| for (int i = 0; i < 4; ++i) |
| v[i] = builder.create<ConstantOp>(builder.getUnknownLoc(), |
| builder.getI32IntegerAttr(i)); |
| } |
| |
| /// Checks that optimal branching on graph has the given cost and |
| /// its preorder traversal results in the specified edges. |
| void check(unsigned cost, OptimalBranching::EdgeList edges) { |
| OptimalBranching opt(graph, v[0]); |
| EXPECT_EQ(opt.solve(), cost); |
| EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges); |
| for (std::pair<Value, Value> edge : edges) |
| EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second); |
| } |
| |
| protected: |
| /// The context for creating the values. |
| MLIRContext context; |
| |
| /// Values used in the graph definition. We always use leading `n` values. |
| Value v[4]; |
| |
| /// The graph being tested on. |
| RootOrderingGraph graph; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Simple 3-node graphs |
| //===----------------------------------------------------------------------===// |
| |
| TEST_F(RootOrderingTest, simpleA) { |
| graph[v[1]][v[0]].cost = {1, 10}; |
| graph[v[2]][v[0]].cost = {1, 11}; |
| graph[v[1]][v[2]].cost = {2, 12}; |
| graph[v[2]][v[1]].cost = {2, 13}; |
| check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}}); |
| } |
| |
| TEST_F(RootOrderingTest, simpleB) { |
| graph[v[1]][v[0]].cost = {1, 10}; |
| graph[v[2]][v[0]].cost = {2, 11}; |
| graph[v[1]][v[2]].cost = {1, 12}; |
| graph[v[2]][v[1]].cost = {1, 13}; |
| check(2, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); |
| } |
| |
| TEST_F(RootOrderingTest, simpleC) { |
| graph[v[1]][v[0]].cost = {2, 10}; |
| graph[v[2]][v[0]].cost = {2, 11}; |
| graph[v[1]][v[2]].cost = {1, 12}; |
| graph[v[2]][v[1]].cost = {1, 13}; |
| check(3, {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Graph for testing contraction |
| //===----------------------------------------------------------------------===// |
| |
| TEST_F(RootOrderingTest, contraction) { |
| graph[v[1]][v[0]].cost = {10, 0}; |
| graph[v[2]][v[0]].cost = {5, 0}; |
| graph[v[2]][v[1]].cost = {1, 0}; |
| graph[v[3]][v[2]].cost = {2, 0}; |
| graph[v[1]][v[3]].cost = {3, 0}; |
| check(10, {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}}); |
| } |
| |
| } // end namespace |