| //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines various operation fold utilities. These utilities are |
| // intended to be used by passes to unify and simplify their logic. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Transforms/FoldUtils.h" |
| |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/Operation.h" |
| |
| using namespace mlir; |
| |
| /// Given an operation, find the parent region that folded constants should be |
| /// inserted into. |
| static Region * |
| getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, |
| Block *insertionBlock) { |
| while (Region *region = insertionBlock->getParent()) { |
| // Insert in this region for any of the following scenarios: |
| // * The parent is unregistered, or is known to be isolated from above. |
| // * The parent is a top-level operation. |
| auto *parentOp = region->getParentOp(); |
| if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || |
| !parentOp->getBlock()) |
| return region; |
| |
| // Otherwise, check if this region is a desired insertion region. |
| auto *interface = interfaces.getInterfaceFor(parentOp); |
| if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) |
| return region; |
| |
| // Traverse up the parent looking for an insertion region. |
| insertionBlock = parentOp->getBlock(); |
| } |
| llvm_unreachable("expected valid insertion region"); |
| } |
| |
| /// A utility function used to materialize a constant for a given attribute and |
| /// type. On success, a valid constant value is returned. Otherwise, null is |
| /// returned |
| static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, |
| Attribute value, Type type, |
| Location loc) { |
| auto insertPt = builder.getInsertionPoint(); |
| (void)insertPt; |
| |
| // Ask the dialect to materialize a constant operation for this value. |
| if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { |
| assert(insertPt == builder.getInsertionPoint()); |
| assert(matchPattern(constOp, m_Constant())); |
| return constOp; |
| } |
| |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OperationFolder |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) { |
| if (inPlaceUpdate) |
| *inPlaceUpdate = false; |
| |
| // If this is a unique'd constant, return failure as we know that it has |
| // already been folded. |
| if (isFolderOwnedConstant(op)) { |
| // Check to see if we should rehoist, i.e. if a non-constant operation was |
| // inserted before this one. |
| Block *opBlock = op->getBlock(); |
| if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) { |
| op->moveBefore(&opBlock->front()); |
| op->setLoc(erasedFoldedLocation); |
| } |
| return failure(); |
| } |
| |
| // Try to fold the operation. |
| SmallVector<Value, 8> results; |
| if (failed(tryToFold(op, results))) |
| return failure(); |
| |
| // Check to see if the operation was just updated in place. |
| if (results.empty()) { |
| if (inPlaceUpdate) |
| *inPlaceUpdate = true; |
| if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>( |
| rewriter.getListener())) { |
| // Folding API does not notify listeners, so we have to notify manually. |
| rewriteListener->notifyOperationModified(op); |
| } |
| return success(); |
| } |
| |
| // Constant folding succeeded. Replace all of the result values and erase the |
| // operation. |
| notifyRemoval(op); |
| rewriter.replaceOp(op, results); |
| return success(); |
| } |
| |
| bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) { |
| Block *opBlock = op->getBlock(); |
| |
| // If this is a constant we unique'd, we don't need to insert, but we can |
| // check to see if we should rehoist it. |
| if (isFolderOwnedConstant(op)) { |
| if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) { |
| op->moveBefore(&opBlock->front()); |
| op->setLoc(erasedFoldedLocation); |
| } |
| return true; |
| } |
| |
| // Get the constant value of the op if necessary. |
| if (!constValue) { |
| matchPattern(op, m_Constant(&constValue)); |
| assert(constValue && "expected `op` to be a constant"); |
| } else { |
| // Ensure that the provided constant was actually correct. |
| #ifndef NDEBUG |
| Attribute expectedValue; |
| matchPattern(op, m_Constant(&expectedValue)); |
| assert( |
| expectedValue == constValue && |
| "provided constant value was not the expected value of the constant"); |
| #endif |
| } |
| |
| // Check for an existing constant operation for the attribute value. |
| Region *insertRegion = getInsertionRegion(interfaces, opBlock); |
| auto &uniquedConstants = foldScopes[insertRegion]; |
| Operation *&folderConstOp = uniquedConstants[std::make_tuple( |
| op->getDialect(), constValue, *op->result_type_begin())]; |
| |
| // If there is an existing constant, replace `op`. |
| if (folderConstOp) { |
| notifyRemoval(op); |
| rewriter.replaceOp(op, folderConstOp->getResults()); |
| folderConstOp->setLoc(erasedFoldedLocation); |
| return false; |
| } |
| |
| // Otherwise, we insert `op`. If `op` is in the insertion block and is either |
| // already at the front of the block, or the previous operation is already a |
| // constant we unique'd (i.e. one we inserted), then we don't need to do |
| // anything. Otherwise, we move the constant to the insertion block. |
| // The location info is erased if the constant is moved to a different block. |
| Block *insertBlock = &insertRegion->front(); |
| if (opBlock != insertBlock) { |
| op->moveBefore(&insertBlock->front()); |
| op->setLoc(erasedFoldedLocation); |
| } else if (&insertBlock->front() != op && |
| !isFolderOwnedConstant(op->getPrevNode())) { |
| op->moveBefore(&insertBlock->front()); |
| } |
| |
| folderConstOp = op; |
| referencedDialects[op].push_back(op->getDialect()); |
| return true; |
| } |
| |
| /// Notifies that the given constant `op` should be remove from this |
| /// OperationFolder's internal bookkeeping. |
| void OperationFolder::notifyRemoval(Operation *op) { |
| // Check to see if this operation is uniqued within the folder. |
| auto it = referencedDialects.find(op); |
| if (it == referencedDialects.end()) |
| return; |
| |
| // Get the constant value for this operation, this is the value that was used |
| // to unique the operation internally. |
| Attribute constValue; |
| matchPattern(op, m_Constant(&constValue)); |
| assert(constValue); |
| |
| // Get the constant map that this operation was uniqued in. |
| auto &uniquedConstants = |
| foldScopes[getInsertionRegion(interfaces, op->getBlock())]; |
| |
| // Erase all of the references to this operation. |
| auto type = op->getResult(0).getType(); |
| for (auto *dialect : it->second) |
| uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); |
| referencedDialects.erase(it); |
| } |
| |
| /// Clear out any constants cached inside of the folder. |
| void OperationFolder::clear() { |
| foldScopes.clear(); |
| referencedDialects.clear(); |
| } |
| |
| /// Get or create a constant using the given builder. On success this returns |
| /// the constant operation, nullptr otherwise. |
| Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect, |
| Attribute value, Type type) { |
| // Find an insertion point for the constant. |
| auto *insertRegion = getInsertionRegion(interfaces, block); |
| auto &entry = insertRegion->front(); |
| rewriter.setInsertionPointToStart(&entry); |
| |
| // Get the constant map for the insertion region of this operation. |
| // Use erased location since the op is being built at the front of block. |
| auto &uniquedConstants = foldScopes[insertRegion]; |
| Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value, |
| type, erasedFoldedLocation); |
| return constOp ? constOp->getResult(0) : Value(); |
| } |
| |
| bool OperationFolder::isFolderOwnedConstant(Operation *op) const { |
| return referencedDialects.count(op); |
| } |
| |
| /// Tries to perform folding on the given `op`. If successful, populates |
| /// `results` with the results of the folding. |
| LogicalResult OperationFolder::tryToFold(Operation *op, |
| SmallVectorImpl<Value> &results) { |
| SmallVector<OpFoldResult, 8> foldResults; |
| if (failed(op->fold(foldResults)) || |
| failed(processFoldResults(op, results, foldResults))) |
| return failure(); |
| return success(); |
| } |
| |
| LogicalResult |
| OperationFolder::processFoldResults(Operation *op, |
| SmallVectorImpl<Value> &results, |
| ArrayRef<OpFoldResult> foldResults) { |
| // Check to see if the operation was just updated in place. |
| if (foldResults.empty()) |
| return success(); |
| assert(foldResults.size() == op->getNumResults()); |
| |
| // Create a builder to insert new operations into the entry block of the |
| // insertion region. |
| auto *insertRegion = getInsertionRegion(interfaces, op->getBlock()); |
| auto &entry = insertRegion->front(); |
| rewriter.setInsertionPointToStart(&entry); |
| |
| // Get the constant map for the insertion region of this operation. |
| auto &uniquedConstants = foldScopes[insertRegion]; |
| |
| // Create the result constants and replace the results. |
| auto *dialect = op->getDialect(); |
| for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { |
| assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); |
| |
| // Check if the result was an SSA value. |
| if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) { |
| results.emplace_back(repl); |
| continue; |
| } |
| |
| // Check to see if there is a canonicalized version of this constant. |
| auto res = op->getResult(i); |
| Attribute attrRepl = cast<Attribute>(foldResults[i]); |
| if (auto *constOp = |
| tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl, |
| res.getType(), erasedFoldedLocation)) { |
| // Ensure that this constant dominates the operation we are replacing it |
| // with. This may not automatically happen if the operation being folded |
| // was inserted before the constant within the insertion block. |
| Block *opBlock = op->getBlock(); |
| if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) |
| constOp->moveBefore(&opBlock->front()); |
| |
| results.push_back(constOp->getResult(0)); |
| continue; |
| } |
| // If materialization fails, cleanup any operations generated for the |
| // previous results and return failure. |
| for (Operation &op : llvm::make_early_inc_range( |
| llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) { |
| notifyRemoval(&op); |
| rewriter.eraseOp(&op); |
| } |
| |
| results.clear(); |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| /// Try to get or create a new constant entry. On success this returns the |
| /// constant operation value, nullptr otherwise. |
| Operation * |
| OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants, |
| Dialect *dialect, Attribute value, |
| Type type, Location loc) { |
| // Check if an existing mapping already exists. |
| auto constKey = std::make_tuple(dialect, value, type); |
| Operation *&constOp = uniquedConstants[constKey]; |
| if (constOp) { |
| if (loc != constOp->getLoc()) |
| constOp->setLoc(erasedFoldedLocation); |
| return constOp; |
| } |
| |
| // If one doesn't exist, try to materialize one. |
| if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc))) |
| return nullptr; |
| |
| // Check to see if the generated constant is in the expected dialect. |
| auto *newDialect = constOp->getDialect(); |
| if (newDialect == dialect) { |
| referencedDialects[constOp].push_back(dialect); |
| return constOp; |
| } |
| |
| // If it isn't, then we also need to make sure that the mapping for the new |
| // dialect is valid. |
| auto newKey = std::make_tuple(newDialect, value, type); |
| |
| // If an existing operation in the new dialect already exists, delete the |
| // materialized operation in favor of the existing one. |
| if (auto *existingOp = uniquedConstants.lookup(newKey)) { |
| notifyRemoval(constOp); |
| rewriter.eraseOp(constOp); |
| referencedDialects[existingOp].push_back(dialect); |
| if (loc != existingOp->getLoc()) |
| existingOp->setLoc(erasedFoldedLocation); |
| return constOp = existingOp; |
| } |
| |
| // Otherwise, update the new dialect to the materialized operation. |
| referencedDialects[constOp].assign({dialect, newDialect}); |
| auto newIt = uniquedConstants.insert({newKey, constOp}); |
| return newIt.first->second; |
| } |