| //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines the Reduction Tree Pass class. It provides a framework for |
| // the implementation of different reduction passes in the MLIR Reduce tool. It |
| // allows for custom specification of the variant generation behavior. It |
| // implements methods that define the different possible traversals of the |
| // reduction tree. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/IR/DialectInterface.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/Reducer/PassDetail.h" |
| #include "mlir/Reducer/Passes.h" |
| #include "mlir/Reducer/ReductionNode.h" |
| #include "mlir/Reducer/ReductionPatternInterface.h" |
| #include "mlir/Reducer/Tester.h" |
| #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/Allocator.h" |
| #include "llvm/Support/ManagedStatic.h" |
| |
| using namespace mlir; |
| |
| /// We implicitly number each operation in the region and if an operation's |
| /// number falls into rangeToKeep, we need to keep it and apply the given |
| /// rewrite patterns on it. |
| static void applyPatterns(Region ®ion, |
| const FrozenRewritePatternSet &patterns, |
| ArrayRef<ReductionNode::Range> rangeToKeep, |
| bool eraseOpNotInRange) { |
| std::vector<Operation *> opsNotInRange; |
| std::vector<Operation *> opsInRange; |
| size_t keepIndex = 0; |
| for (auto op : enumerate(region.getOps())) { |
| int index = op.index(); |
| if (keepIndex < rangeToKeep.size() && |
| index == rangeToKeep[keepIndex].second) |
| ++keepIndex; |
| if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) |
| opsNotInRange.push_back(&op.value()); |
| else |
| opsInRange.push_back(&op.value()); |
| } |
| |
| // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern |
| // matching in above iteration. Besides, erase op not-in-range may end up in |
| // invalid module, so `applyOpPatternsAndFold` should come before that |
| // transform. |
| for (Operation *op : opsInRange) |
| // `applyOpPatternsAndFold` returns whether the op is convered. Omit it |
| // because we don't have expectation this reduction will be success or not. |
| (void)applyOpPatternsAndFold(op, patterns); |
| |
| if (eraseOpNotInRange) |
| for (Operation *op : opsNotInRange) { |
| op->dropAllUses(); |
| op->erase(); |
| } |
| } |
| |
| /// We will apply the reducer patterns to the operations in the ranges specified |
| /// by ReductionNode. Note that we are not able to remove an operation without |
| /// replacing it with another valid operation. However, The validity of module |
| /// reduction is based on the Tester provided by the user and that means certain |
| /// invalid module is still interested by the use. Thus we provide an |
| /// alternative way to remove operations, which is using `eraseOpNotInRange` to |
| /// erase the operations not in the range specified by ReductionNode. |
| template <typename IteratorType> |
| static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
| const FrozenRewritePatternSet &patterns, |
| const Tester &test, bool eraseOpNotInRange) { |
| std::pair<Tester::Interestingness, size_t> initStatus = |
| test.isInteresting(module); |
| // While exploring the reduction tree, we always branch from an interesting |
| // node. Thus the root node must be interesting. |
| if (initStatus.first != Tester::Interestingness::True) |
| return module.emitWarning() << "uninterested module will not be reduced"; |
| |
| llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; |
| |
| std::vector<ReductionNode::Range> ranges{ |
| {0, std::distance(region.op_begin(), region.op_end())}}; |
| |
| ReductionNode *root = allocator.Allocate(); |
| new (root) ReductionNode(nullptr, std::move(ranges), allocator); |
| // Duplicate the module for root node and locate the region in the copy. |
| if (failed(root->initialize(module, region))) |
| llvm_unreachable("unexpected initialization failure"); |
| root->update(initStatus); |
| |
| ReductionNode *smallestNode = root; |
| IteratorType iter(root); |
| |
| while (iter != IteratorType::end()) { |
| ReductionNode ¤tNode = *iter; |
| Region &curRegion = currentNode.getRegion(); |
| |
| applyPatterns(curRegion, patterns, currentNode.getRanges(), |
| eraseOpNotInRange); |
| currentNode.update(test.isInteresting(currentNode.getModule())); |
| |
| if (currentNode.isInteresting() == Tester::Interestingness::True && |
| currentNode.getSize() < smallestNode->getSize()) |
| smallestNode = ¤tNode; |
| |
| ++iter; |
| } |
| |
| // At here, we have found an optimal path to reduce the given region. Retrieve |
| // the path and apply the reducer to it. |
| SmallVector<ReductionNode *> trace; |
| ReductionNode *curNode = smallestNode; |
| trace.push_back(curNode); |
| while (curNode != root) { |
| curNode = curNode->getParent(); |
| trace.push_back(curNode); |
| } |
| |
| // Reduce the region through the optimal path. |
| while (!trace.empty()) { |
| ReductionNode *top = trace.pop_back_val(); |
| applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange); |
| } |
| |
| if (test.isInteresting(module).first != Tester::Interestingness::True) |
| llvm::report_fatal_error("Reduced module is not interesting"); |
| if (test.isInteresting(module).second != smallestNode->getSize()) |
| llvm::report_fatal_error( |
| "Reduced module doesn't have consistent size with smallestNode"); |
| return success(); |
| } |
| |
| template <typename IteratorType> |
| static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
| const FrozenRewritePatternSet &patterns, |
| const Tester &test) { |
| // We separate the reduction process into 2 steps, the first one is to erase |
| // redundant operations and the second one is to apply the reducer patterns. |
| |
| // In the first phase, we don't apply any patterns so that we only select the |
| // range of operations to keep to the module stay interesting. |
| if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, |
| /*eraseOpNotInRange=*/true))) |
| return failure(); |
| // In the second phase, we suppose that no operation is redundant, so we try |
| // to rewrite the operation into simpler form. |
| return findOptimal<IteratorType>(module, region, patterns, test, |
| /*eraseOpNotInRange=*/false); |
| } |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Reduction Pattern Interface Collection |
| //===----------------------------------------------------------------------===// |
| |
| class ReductionPatternInterfaceCollection |
| : public DialectInterfaceCollection<DialectReductionPatternInterface> { |
| public: |
| using Base::Base; |
| |
| // Collect the reduce patterns defined by each dialect. |
| void populateReductionPatterns(RewritePatternSet &pattern) const { |
| for (const DialectReductionPatternInterface &interface : *this) |
| interface.populateReductionPatterns(pattern); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ReductionTreePass |
| //===----------------------------------------------------------------------===// |
| |
| /// This class defines the Reduction Tree Pass. It provides a framework to |
| /// to implement a reduction pass using a tree structure to keep track of the |
| /// generated reduced variants. |
| class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> { |
| public: |
| ReductionTreePass() = default; |
| ReductionTreePass(const ReductionTreePass &pass) = default; |
| |
| LogicalResult initialize(MLIRContext *context) override; |
| |
| /// Runs the pass instance in the pass pipeline. |
| void runOnOperation() override; |
| |
| private: |
| LogicalResult reduceOp(ModuleOp module, Region ®ion); |
| |
| FrozenRewritePatternSet reducerPatterns; |
| }; |
| |
| } // end anonymous namespace |
| |
| LogicalResult ReductionTreePass::initialize(MLIRContext *context) { |
| RewritePatternSet patterns(context); |
| ReductionPatternInterfaceCollection reducePatternCollection(context); |
| reducePatternCollection.populateReductionPatterns(patterns); |
| reducerPatterns = std::move(patterns); |
| return success(); |
| } |
| |
| void ReductionTreePass::runOnOperation() { |
| Operation *topOperation = getOperation(); |
| while (topOperation->getParentOp() != nullptr) |
| topOperation = topOperation->getParentOp(); |
| ModuleOp module = cast<ModuleOp>(topOperation); |
| |
| SmallVector<Operation *, 8> workList; |
| workList.push_back(getOperation()); |
| |
| do { |
| Operation *op = workList.pop_back_val(); |
| |
| for (Region ®ion : op->getRegions()) |
| if (!region.empty()) |
| if (failed(reduceOp(module, region))) |
| return signalPassFailure(); |
| |
| for (Region ®ion : op->getRegions()) |
| for (Operation &op : region.getOps()) |
| if (op.getNumRegions() != 0) |
| workList.push_back(&op); |
| } while (!workList.empty()); |
| } |
| |
| LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { |
| Tester test(testerName, testerArgs); |
| switch (traversalModeId) { |
| case TraversalMode::SinglePath: |
| return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( |
| module, region, reducerPatterns, test); |
| default: |
| return module.emitError() << "unsupported traversal mode detected"; |
| } |
| } |
| |
| std::unique_ptr<Pass> mlir::createReductionTreePass() { |
| return std::make_unique<ReductionTreePass>(); |
| } |