| //===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Transform/IR/Utils.h" |
| #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| #include "mlir/IR/Verifier.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "llvm/Support/Debug.h" |
| |
| using namespace mlir; |
| |
// Debug category used by the `LLVM_DEBUG` statements in this file.
#define DEBUG_TYPE "transform-dialect-utils"
// Debug stream whose output is prefixed with "[<DEBUG_TYPE>]: ".
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
| |
| /// Return whether `func1` can be merged into `func2`. For that to work |
| /// `func1` has to be a declaration (aka has to be external) and `func2` |
| /// either has to be a declaration as well, or it has to be public (otherwise, |
| /// it wouldn't be visible by `func1`). |
| static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { |
| return func1.isExternal() && (func2.isPublic() || func2.isExternal()); |
| } |
| |
| /// Merge `func1` into `func2`. The two ops must be inside the same parent op |
| /// and mergable according to `canMergeInto`. The function erases `func1` such |
| /// that only `func2` exists when the function returns. |
| static InFlightDiagnostic mergeInto(FunctionOpInterface func1, |
| FunctionOpInterface func2) { |
| assert(canMergeInto(func1, func2)); |
| assert(func1->getParentOp() == func2->getParentOp() && |
| "expected func1 and func2 to be in the same parent op"); |
| |
| // Check that function signatures match. |
| if (func1.getFunctionType() != func2.getFunctionType()) { |
| return func1.emitError() |
| << "external definition has a mismatching signature (" |
| << func2.getFunctionType() << ")"; |
| } |
| |
| // Check and merge argument attributes. |
| MLIRContext *context = func1->getContext(); |
| auto *td = context->getLoadedDialect<transform::TransformDialect>(); |
| StringAttr consumedName = td->getConsumedAttrName(); |
| StringAttr readOnlyName = td->getReadOnlyAttrName(); |
| for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { |
| bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; |
| bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; |
| bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr; |
| bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr; |
| if (!isExternalConsumed && !isExternalReadonly) { |
| if (isConsumed) |
| func2.setArgAttr(i, consumedName, UnitAttr::get(context)); |
| else if (isReadonly) |
| func2.setArgAttr(i, readOnlyName, UnitAttr::get(context)); |
| continue; |
| } |
| |
| if ((isExternalConsumed && !isConsumed) || |
| (isExternalReadonly && !isReadonly)) { |
| return func1.emitError() |
| << "external definition has mismatching consumption " |
| "annotations for argument #" |
| << i; |
| } |
| } |
| |
| // `func1` is the external one, so we can remove it. |
| assert(func1.isExternal()); |
| func1->erase(); |
| |
| return InFlightDiagnostic(); |
| } |
| |
| InFlightDiagnostic |
| transform::detail::mergeSymbolsInto(Operation *target, |
| OwningOpRef<Operation *> other) { |
| assert(target->hasTrait<OpTrait::SymbolTable>() && |
| "requires target to implement the 'SymbolTable' trait"); |
| assert(other->hasTrait<OpTrait::SymbolTable>() && |
| "requires target to implement the 'SymbolTable' trait"); |
| |
| SymbolTable targetSymbolTable(target); |
| SymbolTable otherSymbolTable(*other); |
| |
| // Step 1: |
| // |
| // Rename private symbols in both ops in order to resolve conflicts that can |
| // be resolved that way. |
| LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); |
| // TODO: Do we *actually* need to test in both directions? |
| for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( |
| SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable}, |
| SmallVector<SymbolTable *, 2>{&otherSymbolTable, |
| &targetSymbolTable})) { |
| Operation *symbolTableOp = symbolTable->getOp(); |
| for (Operation &op : symbolTableOp->getRegion(0).front()) { |
| auto symbolOp = dyn_cast<SymbolOpInterface>(op); |
| if (!symbolOp) |
| continue; |
| StringAttr name = symbolOp.getNameAttr(); |
| LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n"); |
| |
| // Check if there is a colliding op in the other module. |
| auto collidingOp = |
| cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name)); |
| if (!collidingOp) |
| continue; |
| |
| LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); |
| |
| // Collisions are fine if both opt are functions and can be merged. |
| if (auto funcOp = dyn_cast<FunctionOpInterface>(op), |
| collidingFuncOp = |
| dyn_cast<FunctionOpInterface>(collidingOp.getOperation()); |
| funcOp && collidingFuncOp) { |
| if (canMergeInto(funcOp, collidingFuncOp) || |
| canMergeInto(collidingFuncOp, funcOp)) { |
| LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " |
| "will be merged\n"); |
| continue; |
| } |
| |
| // If they can't be merged, proceed like any other collision. |
| LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); |
| } |
| |
| // Collision can be resolved by renaming if one of the ops is private. |
| auto renameToUnique = |
| [&](SymbolOpInterface op, SymbolOpInterface otherOp, |
| SymbolTable &symbolTable, |
| SymbolTable &otherSymbolTable) -> InFlightDiagnostic { |
| LLVM_DEBUG(llvm::dbgs() << ", renaming\n"); |
| FailureOr<StringAttr> maybeNewName = |
| symbolTable.renameToUnique(op, {&otherSymbolTable}); |
| if (failed(maybeNewName)) { |
| InFlightDiagnostic diag = op->emitError("failed to rename symbol"); |
| diag.attachNote(otherOp->getLoc()) |
| << "attempted renaming due to collision with this op"; |
| return diag; |
| } |
| LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() |
| << "\n"); |
| return InFlightDiagnostic(); |
| }; |
| |
| if (symbolOp.isPrivate()) { |
| InFlightDiagnostic diag = renameToUnique( |
| symbolOp, collidingOp, *symbolTable, *otherSymbolTable); |
| if (failed(diag)) |
| return diag; |
| continue; |
| } |
| if (collidingOp.isPrivate()) { |
| InFlightDiagnostic diag = renameToUnique( |
| collidingOp, symbolOp, *otherSymbolTable, *symbolTable); |
| if (failed(diag)) |
| return diag; |
| continue; |
| } |
| LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); |
| InFlightDiagnostic diag = symbolOp.emitError() |
| << "doubly defined symbol @" << name.getValue(); |
| diag.attachNote(collidingOp->getLoc()) << "previously defined here"; |
| return diag; |
| } |
| } |
| |
| // TODO: This duplicates pass infrastructure. We should split this pass into |
| // several and let the pass infrastructure do the verification. |
| for (auto *op : SmallVector<Operation *>{target, *other}) { |
| if (failed(mlir::verify(op))) |
| return op->emitError() << "failed to verify input op after renaming"; |
| } |
| |
| // Step 2: |
| // |
| // Move all ops from `other` into target and merge public symbols. |
| LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); |
| { |
| SmallVector<SymbolOpInterface> opsToMove; |
| for (Operation &op : other->getRegion(0).front()) { |
| if (auto symbol = dyn_cast<SymbolOpInterface>(op)) |
| opsToMove.push_back(symbol); |
| } |
| |
| for (SymbolOpInterface op : opsToMove) { |
| // Remember potentially colliding op in the target module. |
| auto collidingOp = cast_or_null<SymbolOpInterface>( |
| targetSymbolTable.lookup(op.getNameAttr())); |
| |
| // Move op even if we get a collision. |
| LLVM_DEBUG(DBGS() << " moving @" << op.getName()); |
| op->moveBefore(&target->getRegion(0).front(), |
| target->getRegion(0).front().end()); |
| |
| // If there is no collision, we are done. |
| if (!collidingOp) { |
| LLVM_DEBUG(llvm::dbgs() << " without collision\n"); |
| continue; |
| } |
| |
| // The two colliding ops must both be functions because we have already |
| // emitted errors otherwise earlier. |
| auto funcOp = cast<FunctionOpInterface>(op.getOperation()); |
| auto collidingFuncOp = |
| cast<FunctionOpInterface>(collidingOp.getOperation()); |
| |
| // Both ops are in the target module now and can be treated |
| // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into |
| // `collidingFuncOp`. |
| if (!canMergeInto(funcOp, collidingFuncOp)) { |
| std::swap(funcOp, collidingFuncOp); |
| } |
| assert(canMergeInto(funcOp, collidingFuncOp)); |
| |
| LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " |
| << collidingFuncOp.getLoc() << ":\n" |
| << collidingFuncOp << "\n"); |
| |
| // Update symbol table. This works with or without the previous `swap`. |
| targetSymbolTable.remove(funcOp); |
| targetSymbolTable.insert(collidingFuncOp); |
| assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp); |
| |
| // Do the actual merging. |
| { |
| InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp); |
| if (failed(diag)) |
| return diag; |
| } |
| } |
| } |
| |
| if (failed(mlir::verify(target))) |
| return target->emitError() |
| << "failed to verify target op after merging symbols"; |
| |
| LLVM_DEBUG(DBGS() << "done merging ops\n"); |
| return InFlightDiagnostic(); |
| } |