//===- TestUseListOrders.cpp - Test use-list order preservation -----------===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"

#include <algorithm>
#include <numeric>
#include <random>
#include <utility>
| |
| using namespace mlir; |
| |
| namespace { |
| /// This pass tests that: |
| /// 1) we can shuffle use-lists correctly; |
| /// 2) use-list orders are preserved after a roundtrip to bytecode. |
| class TestPreserveUseListOrders |
| : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> { |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders) |
| |
| TestPreserveUseListOrders() = default; |
| TestPreserveUseListOrders(const TestPreserveUseListOrders &pass) |
| : PassWrapper(pass) {} |
| StringRef getArgument() const final { return "test-verify-uselistorder"; } |
| StringRef getDescription() const final { |
| return "Verify that roundtripping the IR to bytecode preserves the order " |
| "of the uselists"; |
| } |
| Option<unsigned> rngSeed{*this, "rng-seed", |
| llvm::cl::desc("Specify an input random seed"), |
| llvm::cl::init(1)}; |
| |
| LogicalResult initialize(MLIRContext *context) override { |
| rng.seed(static_cast<unsigned>(rngSeed)); |
| return success(); |
| } |
| |
| void runOnOperation() override { |
| // Clone the module so that we can plug in this pass to any other |
| // independently. |
| OwningOpRef<ModuleOp> cloneModule = getOperation().clone(); |
| |
| // 1. Compute the op numbering of the module. |
| computeOpNumbering(*cloneModule); |
| |
| // 2. Loop over all the values and shuffle the uses. While doing so, check |
| // that each shuffle is correct. |
| if (failed(shuffleUses(*cloneModule))) |
| return signalPassFailure(); |
| |
| // 3. Do a bytecode roundtrip to version 3, which supports use-list order |
| // preservation. |
| auto roundtripModuleOr = doRoundtripToBytecode(*cloneModule, 3); |
| // If the bytecode roundtrip failed, try to roundtrip the original module |
| // to version 2, which does not support use-list. If this also fails, the |
| // original module had an issue unrelated to uselists. |
| if (failed(roundtripModuleOr)) { |
| auto testModuleOr = doRoundtripToBytecode(getOperation(), 2); |
| if (failed(testModuleOr)) |
| return; |
| |
| return signalPassFailure(); |
| } |
| |
| // 4. Recompute the op numbering on the new module. The numbering should be |
| // the same as (1), but on the new operation pointers. |
| computeOpNumbering(roundtripModuleOr->get()); |
| |
| // 5. Loop over all the values and verify that the use-list is consistent |
| // with the post-shuffle order of step (2). |
| if (failed(verifyUseListOrders(roundtripModuleOr->get()))) |
| return signalPassFailure(); |
| } |
| |
| private: |
| FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module, |
| uint32_t version) { |
| std::string str; |
| llvm::raw_string_ostream m(str); |
| BytecodeWriterConfig config; |
| config.setDesiredBytecodeVersion(version); |
| if (failed(writeBytecodeToFile(module, m, config))) |
| return failure(); |
| |
| ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true); |
| auto newModuleOp = parseSourceString(StringRef(str), parseConfig); |
| if (!newModuleOp.get()) |
| return failure(); |
| return newModuleOp; |
| } |
| |
| /// Compute an ordered numbering for all the operations in the IR. |
| void computeOpNumbering(Operation *topLevelOp) { |
| uint32_t operationID = 0; |
| opNumbering.clear(); |
| topLevelOp->walk<mlir::WalkOrder::PreOrder>( |
| [&](Operation *op) { opNumbering.try_emplace(op, operationID++); }); |
| } |
| |
| template <typename ValueT> |
| SmallVector<uint64_t> getUseIDs(ValueT val) { |
| return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) { |
| return bytecode::getUseID(use, opNumbering.at(use.getOwner())); |
| })); |
| } |
| |
| LogicalResult shuffleUses(Operation *topLevelOp) { |
| uint32_t valueID = 0; |
| /// Permute randomly the use-list of each value. It is guaranteed that at |
| /// least one pair of the use list is permuted. |
| auto doShuffleForRange = [&](ValueRange range) -> LogicalResult { |
| for (auto val : range) { |
| if (val.use_empty() || val.hasOneUse()) |
| continue; |
| |
| /// Get a valid index permutation for the uses of value. |
| SmallVector<unsigned> permutation = getRandomPermutation(val); |
| |
| /// Store original order and verify that the shuffle was applied |
| /// correctly. |
| auto useIDs = getUseIDs(val); |
| |
| /// Apply shuffle to the uselist. |
| val.shuffleUseList(permutation); |
| |
| /// Get the new order and verify the shuffle happened correctly. |
| auto permutedIDs = getUseIDs(val); |
| if (permutedIDs.size() != useIDs.size()) |
| return failure(); |
| for (size_t idx = 0; idx < permutation.size(); idx++) |
| if (useIDs[idx] != permutedIDs[permutation[idx]]) |
| return failure(); |
| |
| referenceUseListOrder.try_emplace( |
| valueID++, llvm::map_range(val.getUses(), [&](auto &use) { |
| return bytecode::getUseID(use, opNumbering.at(use.getOwner())); |
| })); |
| } |
| return success(); |
| }; |
| |
| return walkOverValues(topLevelOp, doShuffleForRange); |
| } |
| |
| LogicalResult verifyUseListOrders(Operation *topLevelOp) { |
| uint32_t valueID = 0; |
| /// Check that the use-list for the value range matches the one stored in |
| /// the reference. |
| auto doValidationForRange = [&](ValueRange range) -> LogicalResult { |
| for (auto val : range) { |
| if (val.use_empty() || val.hasOneUse()) |
| continue; |
| auto referenceOrder = referenceUseListOrder.at(valueID++); |
| for (auto [use, referenceID] : |
| llvm::zip(val.getUses(), referenceOrder)) { |
| uint64_t uniqueID = |
| bytecode::getUseID(use, opNumbering.at(use.getOwner())); |
| if (uniqueID != referenceID) { |
| use.getOwner()->emitError() |
| << "found use-list order mismatch for value: " << val; |
| return failure(); |
| } |
| } |
| } |
| return success(); |
| }; |
| |
| return walkOverValues(topLevelOp, doValidationForRange); |
| } |
| |
| /// Walk over blocks and operations and execute a callable over the ranges of |
| /// operands/results respectively. |
| template <typename FuncT> |
| LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) { |
| auto blockWalk = topLevelOp->walk([&](Block *block) { |
| if (failed(callable(block->getArguments()))) |
| return WalkResult::interrupt(); |
| return WalkResult::advance(); |
| }); |
| |
| if (blockWalk.wasInterrupted()) |
| return failure(); |
| |
| auto resultsWalk = topLevelOp->walk([&](Operation *op) { |
| if (failed(callable(op->getResults()))) |
| return WalkResult::interrupt(); |
| return WalkResult::advance(); |
| }); |
| |
| return failure(resultsWalk.wasInterrupted()); |
| } |
| |
| /// Creates a random permutation of the uselist order chain of the provided |
| /// value. |
| SmallVector<unsigned> getRandomPermutation(Value value) { |
| size_t numUses = std::distance(value.use_begin(), value.use_end()); |
| SmallVector<unsigned> permutation(numUses); |
| unsigned zero = 0; |
| std::iota(permutation.begin(), permutation.end(), zero); |
| std::shuffle(permutation.begin(), permutation.end(), rng); |
| return permutation; |
| } |
| |
| /// Map each value to its use-list order encoded with unique use IDs. |
| DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder; |
| |
| /// Map each operation to its global ID. |
| DenseMap<Operation *, uint32_t> opNumbering; |
| |
| std::default_random_engine rng; |
| }; |
| } // namespace |
| |
| namespace mlir { |
| void registerTestPreserveUseListOrders() { |
| PassRegistration<TestPreserveUseListOrders>(); |
| } |
| } // namespace mlir |