| //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "llvm/ADT/SmallPtrSet.h" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // ControlFlowInterfaces |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" |
| |
| //===----------------------------------------------------------------------===// |
| // BranchOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some |
| /// successor if 'operandIndex' is within the range of 'operands', or None if |
| /// `operandIndex` isn't a successor operand index. |
| Optional<BlockArgument> |
| detail::getBranchSuccessorArgument(Optional<OperandRange> operands, |
| unsigned operandIndex, Block *successor) { |
| // Check that the operands are valid. |
| if (!operands || operands->empty()) |
| return llvm::None; |
| |
| // Check to ensure that this operand is within the range. |
| unsigned operandsStart = operands->getBeginOperandIndex(); |
| if (operandIndex < operandsStart || |
| operandIndex >= (operandsStart + operands->size())) |
| return llvm::None; |
| |
| // Index the successor. |
| unsigned argIndex = operandIndex - operandsStart; |
| return successor->getArgument(argIndex); |
| } |
| |
| /// Verify that the given operands match those of the given successor block. |
| LogicalResult |
| detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, |
| Optional<OperandRange> operands) { |
| if (!operands) |
| return success(); |
| |
| // Check the count. |
| unsigned operandCount = operands->size(); |
| Block *destBB = op->getSuccessor(succNo); |
| if (operandCount != destBB->getNumArguments()) |
| return op->emitError() << "branch has " << operandCount |
| << " operands for successor #" << succNo |
| << ", but target block has " |
| << destBB->getNumArguments(); |
| |
| // Check the types. |
| auto operandIt = operands->begin(); |
| for (unsigned i = 0; i != operandCount; ++i, ++operandIt) { |
| if ((*operandIt).getType() != destBB->getArgument(i).getType()) |
| return op->emitError() << "type mismatch for bb argument #" << i |
| << " of successor #" << succNo; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RegionBranchOpInterface |
| //===----------------------------------------------------------------------===// |
| |
// Sentinel value used to represent an unknown number of region invocations.
const int64_t mlir::kUnknownNumRegionInvocations = -1;
| |
| /// Verify that types match along all region control flow edges originating from |
| /// `sourceNo` (region # if source is a region, llvm::None if source is parent |
| /// op). `getInputsTypesForRegion` is a function that returns the types of the |
| /// inputs that flow from `sourceIndex' to the given region, or llvm::None if |
| /// the exact type match verification is not necessary (e.g., if the Op verifies |
| /// the match itself). |
| static LogicalResult |
| verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo, |
| function_ref<Optional<TypeRange>(Optional<unsigned>)> |
| getInputsTypesForRegion) { |
| auto regionInterface = cast<RegionBranchOpInterface>(op); |
| |
| SmallVector<RegionSuccessor, 2> successors; |
| unsigned numInputs; |
| if (sourceNo) { |
| Region &srcRegion = op->getRegion(sourceNo.getValue()); |
| numInputs = srcRegion.getNumArguments(); |
| } else { |
| numInputs = op->getNumOperands(); |
| } |
| SmallVector<Attribute, 2> operands(numInputs, nullptr); |
| regionInterface.getSuccessorRegions(sourceNo, operands, successors); |
| |
| for (RegionSuccessor &succ : successors) { |
| Optional<unsigned> succRegionNo; |
| if (!succ.isParent()) |
| succRegionNo = succ.getSuccessor()->getRegionNumber(); |
| |
| auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { |
| diag << "from "; |
| if (sourceNo) |
| diag << "Region #" << sourceNo.getValue(); |
| else |
| diag << "parent operands"; |
| |
| diag << " to "; |
| if (succRegionNo) |
| diag << "Region #" << succRegionNo.getValue(); |
| else |
| diag << "parent results"; |
| return diag; |
| }; |
| |
| Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); |
| if (!sourceTypes.hasValue()) |
| continue; |
| |
| TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); |
| if (sourceTypes->size() != succInputsTypes.size()) { |
| InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); |
| return printEdgeName(diag) << ": source has " << sourceTypes->size() |
| << " operands, but target successor needs " |
| << succInputsTypes.size(); |
| } |
| |
| for (auto typesIdx : |
| llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { |
| Type sourceType = std::get<0>(typesIdx.value()); |
| Type inputType = std::get<1>(typesIdx.value()); |
| if (sourceType != inputType) { |
| InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); |
| return printEdgeName(diag) |
| << ": source type #" << typesIdx.index() << " " << sourceType |
| << " should match input type #" << typesIdx.index() << " " |
| << inputType; |
| } |
| } |
| } |
| return success(); |
| } |
| |
| /// Verify that types match along control flow edges described the given op. |
| LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { |
| auto regionInterface = cast<RegionBranchOpInterface>(op); |
| |
| auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { |
| if (regionNo.hasValue()) { |
| return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) |
| .getTypes(); |
| } |
| |
| // If the successor of a parent op is the parent itself |
| // RegionBranchOpInterface does not have an API to query what the entry |
| // operands will be in that case. Vend out the result types of the op in |
| // that case so that type checking succeeds for this case. |
| return op->getResultTypes(); |
| }; |
| |
| // Verify types along control flow edges originating from the parent. |
| if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) |
| return failure(); |
| |
| // RegionBranchOpInterface should not be implemented by Ops that do not have |
| // attached regions. |
| assert(op->getNumRegions() != 0); |
| |
| // Verify types along control flow edges originating from each region. |
| for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { |
| Region ®ion = op->getRegion(regionNo); |
| |
| // Since the interface cannot distinguish between different ReturnLike |
| // ops within the region branching to different successors, all ReturnLike |
| // ops in this region should have the same operand types. We will then use |
| // one of them as the representative for type matching. |
| |
| Operation *regionReturn = nullptr; |
| for (Block &block : region) { |
| Operation *terminator = block.getTerminator(); |
| if (!terminator->hasTrait<OpTrait::ReturnLike>()) |
| continue; |
| |
| if (!regionReturn) { |
| regionReturn = terminator; |
| continue; |
| } |
| |
| // Found more than one ReturnLike terminator. Make sure the operand types |
| // match with the first one. |
| if (regionReturn->getOperandTypes() != terminator->getOperandTypes()) |
| return op->emitOpError("Region #") |
| << regionNo |
| << " operands mismatch between return-like terminators"; |
| } |
| |
| auto inputTypesFromRegion = |
| [&](Optional<unsigned> regionNo) -> Optional<TypeRange> { |
| // If there is no return-like terminator, the op itself should verify |
| // type consistency. |
| if (!regionReturn) |
| return llvm::None; |
| |
| // All successors get the same set of operands. |
| return TypeRange(regionReturn->getOperands().getTypes()); |
| }; |
| |
| if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) |
| return failure(); |
| } |
| |
| return success(); |
| } |