|  | //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// | 
|  | // | 
|  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | #include <utility> | 
|  |  | 
|  | #include "mlir/IR/BuiltinTypes.h" | 
|  | #include "mlir/Interfaces/ControlFlowInterfaces.h" | 
|  | #include "llvm/ADT/SmallPtrSet.h" | 
|  |  | 
|  | using namespace mlir; | 
|  |  | 
|  | //===----------------------------------------------------------------------===// | 
|  | // ControlFlowInterfaces | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" | 
|  |  | 
|  | SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) | 
|  | : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) { | 
|  | } | 
|  |  | 
|  | SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, | 
|  | MutableOperandRange forwardedOperands) | 
|  | : producedOperandCount(producedOperandCount), | 
|  | forwardedOperands(std::move(forwardedOperands)) {} | 
|  |  | 
|  | //===----------------------------------------------------------------------===// | 
|  | // BranchOpInterface | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some | 
|  | /// successor if 'operandIndex' is within the range of 'operands', or | 
|  | /// std::nullopt if `operandIndex` isn't a successor operand index. | 
|  | std::optional<BlockArgument> | 
|  | detail::getBranchSuccessorArgument(const SuccessorOperands &operands, | 
|  | unsigned operandIndex, Block *successor) { | 
|  | OperandRange forwardedOperands = operands.getForwardedOperands(); | 
|  | // Check that the operands are valid. | 
|  | if (forwardedOperands.empty()) | 
|  | return std::nullopt; | 
|  |  | 
|  | // Check to ensure that this operand is within the range. | 
|  | unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); | 
|  | if (operandIndex < operandsStart || | 
|  | operandIndex >= (operandsStart + forwardedOperands.size())) | 
|  | return std::nullopt; | 
|  |  | 
|  | // Index the successor. | 
|  | unsigned argIndex = | 
|  | operands.getProducedOperandCount() + operandIndex - operandsStart; | 
|  | return successor->getArgument(argIndex); | 
|  | } | 
|  |  | 
|  | /// Verify that the given operands match those of the given successor block. | 
|  | LogicalResult | 
|  | detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, | 
|  | const SuccessorOperands &operands) { | 
|  | // Check the count. | 
|  | unsigned operandCount = operands.size(); | 
|  | Block *destBB = op->getSuccessor(succNo); | 
|  | if (operandCount != destBB->getNumArguments()) | 
|  | return op->emitError() << "branch has " << operandCount | 
|  | << " operands for successor #" << succNo | 
|  | << ", but target block has " | 
|  | << destBB->getNumArguments(); | 
|  |  | 
|  | // Check the types. | 
|  | for (unsigned i = operands.getProducedOperandCount(); i != operandCount; | 
|  | ++i) { | 
|  | if (!cast<BranchOpInterface>(op).areTypesCompatible( | 
|  | operands[i].getType(), destBB->getArgument(i).getType())) | 
|  | return op->emitError() << "type mismatch for bb argument #" << i | 
|  | << " of successor #" << succNo; | 
|  | } | 
|  | return success(); | 
|  | } | 
|  |  | 
|  | //===----------------------------------------------------------------------===// | 
|  | // RegionBranchOpInterface | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, | 
|  | RegionBranchPoint sourceNo, | 
|  | RegionBranchPoint succRegionNo) { | 
|  | diag << "from "; | 
|  | if (Region *region = sourceNo.getRegionOrNull()) | 
|  | diag << "Region #" << region->getRegionNumber(); | 
|  | else | 
|  | diag << "parent operands"; | 
|  |  | 
|  | diag << " to "; | 
|  | if (Region *region = succRegionNo.getRegionOrNull()) | 
|  | diag << "Region #" << region->getRegionNumber(); | 
|  | else | 
|  | diag << "parent results"; | 
|  | return diag; | 
|  | } | 
|  |  | 
|  | /// Verify that types match along all region control flow edges originating from | 
|  | /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the | 
|  | /// types of the inputs that flow to a successor region. | 
|  | static LogicalResult | 
|  | verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, | 
|  | function_ref<FailureOr<TypeRange>(RegionBranchPoint)> | 
|  | getInputsTypesForRegion) { | 
|  | auto regionInterface = cast<RegionBranchOpInterface>(op); | 
|  |  | 
|  | SmallVector<RegionSuccessor, 2> successors; | 
|  | regionInterface.getSuccessorRegions(sourcePoint, successors); | 
|  |  | 
|  | for (RegionSuccessor &succ : successors) { | 
|  | FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ); | 
|  | if (failed(sourceTypes)) | 
|  | return failure(); | 
|  |  | 
|  | TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); | 
|  | if (sourceTypes->size() != succInputsTypes.size()) { | 
|  | InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); | 
|  | return printRegionEdgeName(diag, sourcePoint, succ) | 
|  | << ": source has " << sourceTypes->size() | 
|  | << " operands, but target successor needs " | 
|  | << succInputsTypes.size(); | 
|  | } | 
|  |  | 
|  | for (const auto &typesIdx : | 
|  | llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { | 
|  | Type sourceType = std::get<0>(typesIdx.value()); | 
|  | Type inputType = std::get<1>(typesIdx.value()); | 
|  | if (!regionInterface.areTypesCompatible(sourceType, inputType)) { | 
|  | InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); | 
|  | return printRegionEdgeName(diag, sourcePoint, succ) | 
|  | << ": source type #" << typesIdx.index() << " " << sourceType | 
|  | << " should match input type #" << typesIdx.index() << " " | 
|  | << inputType; | 
|  | } | 
|  | } | 
|  | } | 
|  | return success(); | 
|  | } | 
|  |  | 
|  | /// Verify that types match along control flow edges described the given op. | 
|  | LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { | 
|  | auto regionInterface = cast<RegionBranchOpInterface>(op); | 
|  |  | 
|  | auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange { | 
|  | return regionInterface.getEntrySuccessorOperands(point).getTypes(); | 
|  | }; | 
|  |  | 
|  | // Verify types along control flow edges originating from the parent. | 
|  | if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(), | 
|  | inputTypesFromParent))) | 
|  | return failure(); | 
|  |  | 
|  | auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { | 
|  | if (lhs.size() != rhs.size()) | 
|  | return false; | 
|  | for (auto types : llvm::zip(lhs, rhs)) { | 
|  | if (!regionInterface.areTypesCompatible(std::get<0>(types), | 
|  | std::get<1>(types))) { | 
|  | return false; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | }; | 
|  |  | 
|  | // Verify types along control flow edges originating from each region. | 
|  | for (Region ®ion : op->getRegions()) { | 
|  |  | 
|  | // Since there can be multiple terminators implementing the | 
|  | // `RegionBranchTerminatorOpInterface`, all should have the same operand | 
|  | // types when passing them to the same region. | 
|  |  | 
|  | SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps; | 
|  | for (Block &block : region) | 
|  | if (!block.empty()) | 
|  | if (auto terminator = | 
|  | dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) | 
|  | regionReturnOps.push_back(terminator); | 
|  |  | 
|  | // If there is no return-like terminator, the op itself should verify | 
|  | // type consistency. | 
|  | if (regionReturnOps.empty()) | 
|  | continue; | 
|  |  | 
|  | auto inputTypesForRegion = | 
|  | [&](RegionBranchPoint point) -> FailureOr<TypeRange> { | 
|  | std::optional<OperandRange> regionReturnOperands; | 
|  | for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { | 
|  | auto terminatorOperands = regionReturnOp.getSuccessorOperands(point); | 
|  |  | 
|  | if (!regionReturnOperands) { | 
|  | regionReturnOperands = terminatorOperands; | 
|  | continue; | 
|  | } | 
|  |  | 
|  | // Found more than one ReturnLike terminator. Make sure the operand | 
|  | // types match with the first one. | 
|  | if (!areTypesCompatible(regionReturnOperands->getTypes(), | 
|  | terminatorOperands.getTypes())) { | 
|  | InFlightDiagnostic diag = op->emitOpError(" along control flow edge"); | 
|  | return printRegionEdgeName(diag, region, point) | 
|  | << " operands mismatch between return-like terminators"; | 
|  | } | 
|  | } | 
|  |  | 
|  | // All successors get the same set of operand types. | 
|  | return TypeRange(regionReturnOperands->getTypes()); | 
|  | }; | 
|  |  | 
|  | if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion))) | 
|  | return failure(); | 
|  | } | 
|  |  | 
|  | return success(); | 
|  | } | 
|  |  | 
|  | /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if | 
|  | /// this function returns "true" for a successor region. The first parameter is | 
|  | /// the successor region. The second parameter indicates all already visited | 
|  | /// regions. | 
|  | using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>; | 
|  |  | 
|  | /// Traverse the region graph starting at `begin`. The traversal is interrupted | 
|  | /// if `stopCondition` evaluates to "true" for a successor region. In that case, | 
|  | /// this function returns "true". Otherwise, if the traversal was not | 
|  | /// interrupted, this function returns "false". | 
|  | static bool traverseRegionGraph(Region *begin, | 
|  | StopConditionFn stopConditionFn) { | 
|  | auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); | 
|  | SmallVector<bool> visited(op->getNumRegions(), false); | 
|  | visited[begin->getRegionNumber()] = true; | 
|  |  | 
|  | // Retrieve all successors of the region and enqueue them in the worklist. | 
|  | SmallVector<Region *> worklist; | 
|  | auto enqueueAllSuccessors = [&](Region *region) { | 
|  | SmallVector<RegionSuccessor> successors; | 
|  | op.getSuccessorRegions(region, successors); | 
|  | for (RegionSuccessor successor : successors) | 
|  | if (!successor.isParent()) | 
|  | worklist.push_back(successor.getSuccessor()); | 
|  | }; | 
|  | enqueueAllSuccessors(begin); | 
|  |  | 
|  | // Process all regions in the worklist via DFS. | 
|  | while (!worklist.empty()) { | 
|  | Region *nextRegion = worklist.pop_back_val(); | 
|  | if (stopConditionFn(nextRegion, visited)) | 
|  | return true; | 
|  | if (visited[nextRegion->getRegionNumber()]) | 
|  | continue; | 
|  | visited[nextRegion->getRegionNumber()] = true; | 
|  | enqueueAllSuccessors(nextRegion); | 
|  | } | 
|  |  | 
|  | return false; | 
|  | } | 
|  |  | 
|  | /// Return `true` if region `r` is reachable from region `begin` according to | 
|  | /// the RegionBranchOpInterface (by taking a branch). | 
|  | static bool isRegionReachable(Region *begin, Region *r) { | 
|  | assert(begin->getParentOp() == r->getParentOp() && | 
|  | "expected that both regions belong to the same op"); | 
|  | return traverseRegionGraph(begin, | 
|  | [&](Region *nextRegion, ArrayRef<bool> visited) { | 
|  | // Interrupt traversal if `r` was reached. | 
|  | return nextRegion == r; | 
|  | }); | 
|  | } | 
|  |  | 
|  | /// Return `true` if `a` and `b` are in mutually exclusive regions. | 
|  | /// | 
|  | /// 1. Find the first common of `a` and `b` (ancestor) that implements | 
|  | ///    RegionBranchOpInterface. | 
|  | /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are | 
|  | ///    contained. | 
|  | /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are | 
|  | ///    mutually exclusive if they are not reachable from each other as per | 
|  | ///    RegionBranchOpInterface::getSuccessorRegions. | 
|  | bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { | 
|  | assert(a && "expected non-empty operation"); | 
|  | assert(b && "expected non-empty operation"); | 
|  |  | 
|  | auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); | 
|  | while (branchOp) { | 
|  | // Check if b is inside branchOp. (We already know that a is.) | 
|  | if (!branchOp->isProperAncestor(b)) { | 
|  | // Check next enclosing RegionBranchOpInterface. | 
|  | branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | // b is contained in branchOp. Retrieve the regions in which `a` and `b` | 
|  | // are contained. | 
|  | Region *regionA = nullptr, *regionB = nullptr; | 
|  | for (Region &r : branchOp->getRegions()) { | 
|  | if (r.findAncestorOpInRegion(*a)) { | 
|  | assert(!regionA && "already found a region for a"); | 
|  | regionA = &r; | 
|  | } | 
|  | if (r.findAncestorOpInRegion(*b)) { | 
|  | assert(!regionB && "already found a region for b"); | 
|  | regionB = &r; | 
|  | } | 
|  | } | 
|  | assert(regionA && regionB && "could not find region of op"); | 
|  |  | 
|  | // `a` and `b` are in mutually exclusive regions if both regions are | 
|  | // distinct and neither region is reachable from the other region. | 
|  | return regionA != regionB && !isRegionReachable(regionA, regionB) && | 
|  | !isRegionReachable(regionB, regionA); | 
|  | } | 
|  |  | 
|  | // Could not find a common RegionBranchOpInterface among a's and b's | 
|  | // ancestors. | 
|  | return false; | 
|  | } | 
|  |  | 
|  | bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { | 
|  | Region *region = &getOperation()->getRegion(index); | 
|  | return isRegionReachable(region, region); | 
|  | } | 
|  |  | 
|  | bool RegionBranchOpInterface::hasLoop() { | 
|  | SmallVector<RegionSuccessor> entryRegions; | 
|  | getSuccessorRegions(RegionBranchPoint::parent(), entryRegions); | 
|  | for (RegionSuccessor successor : entryRegions) | 
|  | if (!successor.isParent() && | 
|  | traverseRegionGraph(successor.getSuccessor(), | 
|  | [](Region *nextRegion, ArrayRef<bool> visited) { | 
|  | // Interrupt traversal if the region was already | 
|  | // visited. | 
|  | return visited[nextRegion->getRegionNumber()]; | 
|  | })) | 
|  | return true; | 
|  | return false; | 
|  | } | 
|  |  | 
|  | Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { | 
|  | while (Region *region = op->getParentRegion()) { | 
|  | op = region->getParentOp(); | 
|  | if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) | 
|  | if (branchOp.isRepetitiveRegion(region->getRegionNumber())) | 
|  | return region; | 
|  | } | 
|  | return nullptr; | 
|  | } | 
|  |  | 
|  | Region *mlir::getEnclosingRepetitiveRegion(Value value) { | 
|  | Region *region = value.getParentRegion(); | 
|  | while (region) { | 
|  | Operation *op = region->getParentOp(); | 
|  | if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) | 
|  | if (branchOp.isRepetitiveRegion(region->getRegionNumber())) | 
|  | return region; | 
|  | region = op->getParentRegion(); | 
|  | } | 
|  | return nullptr; | 
|  | } |