//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// ControlFlowInterfaces
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"

/// Creates successor operands made up solely of forwarded operands: no
/// operands are produced internally by the operation.
SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
    : SuccessorOperands(/*producedOperandCount=*/0,
                        std::move(forwardedOperands)) {}

/// Creates successor operands whose first `producedCount` entries are
/// produced internally by the operation, followed by the operands in
/// `operands`, which are forwarded to the successor.
SuccessorOperands::SuccessorOperands(unsigned int producedCount,
                                     MutableOperandRange operands)
    : producedOperandCount(producedCount),
      forwardedOperands(std::move(operands)) {}

//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//

/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if 'operandIndex' is within the range of 'operands', or
/// std::nullopt if `operandIndex` isn't a successor operand index.
std::optional<BlockArgument>
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
                                   unsigned operandIndex, Block *successor) {
  OperandRange forwardedOperands = operands.getForwardedOperands();
  // Check that the operands are valid.
  if (forwardedOperands.empty())
    return std::nullopt;

  // Check to ensure that this operand is within the range.
  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
  if (operandIndex < operandsStart ||
      operandIndex >= (operandsStart + forwardedOperands.size()))
    return std::nullopt;

  // Index the successor.
  unsigned argIndex =
      operands.getProducedOperandCount() + operandIndex - operandsStart;
  return successor->getArgument(argIndex);
}

/// Verify that `operands` match the arguments of successor #`succNo` of `op`.
///
/// The total operand count must equal the successor block's argument count.
/// Each operand from index `operands.getProducedOperandCount()` onwards must
/// have a type compatible, per BranchOpInterface::areTypesCompatible, with the
/// corresponding block argument. The leading produced operands are not
/// type-checked here.
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
                                      const SuccessorOperands &operands) {
  Block *successor = op->getSuccessor(succNo);
  unsigned numOperands = operands.size();
  unsigned numBlockArgs = successor->getNumArguments();

  if (numOperands != numBlockArgs)
    return op->emitError() << "branch has " << numOperands
                           << " operands for successor #" << succNo
                           << ", but target block has " << numBlockArgs;

  // Skip the internally produced operands; compare only the remaining ones.
  unsigned idx = operands.getProducedOperandCount();
  for (; idx != numOperands; ++idx) {
    Type operandType = operands[idx].getType();
    Type argType = successor->getArgument(idx).getType();
    if (cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
      continue;
    return op->emitError() << "type mismatch for bb argument #" << idx
                           << " of successor #" << succNo;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// WeightedBranchOpInterface
//===----------------------------------------------------------------------===//

/// Shared verifier for weight lists attached to `op`.
///
/// An empty `weights` list is always accepted. Otherwise the list must:
///  - contain exactly `expectedWeightsNum` entries,
///  - contain no negative entry, and
///  - contain at least one non-zero entry.
/// `weightAnchorName` and `weightRefName` are used only to phrase the size
/// mismatch diagnostic.
static LogicalResult verifyWeights(Operation *op,
                                   llvm::ArrayRef<int32_t> weights,
                                   std::size_t expectedWeightsNum,
                                   llvm::StringRef weightAnchorName,
                                   llvm::StringRef weightRefName) {
  // Weights are optional; nothing to check when none are present.
  if (weights.empty())
    return success();

  std::size_t numWeights = weights.size();
  if (numWeights != expectedWeightsNum)
    return op->emitError() << "expects number of " << weightAnchorName
                           << " weights to match number of " << weightRefName
                           << ": " << numWeights << " vs "
                           << expectedWeightsNum;

  // One pass: reject the first negative weight, and remember whether any
  // weight was non-zero for the all-zero check below.
  bool sawNonZeroWeight = false;
  for (std::size_t idx = 0; idx != numWeights; ++idx) {
    int32_t weight = weights[idx];
    if (weight < 0)
      return op->emitError() << "weight #" << idx << " must be non-negative";
    if (weight != 0)
      sawNonZeroWeight = true;
  }

  if (!sawNonZeroWeight)
    return op->emitError() << "branch weights cannot all be zero";

  return success();
}

/// Verifies the weights of a WeightedBranchOpInterface op against its
/// successor count, using the shared `verifyWeights` checks.
LogicalResult detail::verifyBranchWeights(Operation *op) {
  auto weightedOp = cast<WeightedBranchOpInterface>(op);
  return verifyWeights(op, weightedOp.getWeights(), op->getNumSuccessors(),
                       "branch", "successors");
}

//===----------------------------------------------------------------------===//
// WeightedRegionBranchOpInterface
//===----------------------------------------------------------------------===//

/// Verifies the weights of a WeightedRegionBranchOpInterface op against its
/// region count, using the shared `verifyWeights` checks.
LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
  auto weightedOp = cast<WeightedRegionBranchOpInterface>(op);
  return verifyWeights(op, weightedOp.getWeights(), op->getNumRegions(),
                       "region", "regions");
}

//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//

/// Appends a description of the edge `sourceNo` -> `succRegionNo` to `diag`.
/// The result has the form "from <source> to <target>". A region endpoint is
/// printed as "Region #N". The parent op is printed as "parent operands" when
/// it is the source and as "parent results" when it is the target.
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
                                               RegionBranchPoint sourceNo,
                                               RegionBranchPoint succRegionNo) {
  // Prints a single endpoint; `parentText` is used when it is the parent op.
  auto printEndpoint = [&](RegionBranchPoint point, const char *parentText) {
    if (Region *region = point.getRegionOrNull()) {
      diag << "Region #" << region->getRegionNumber();
      return;
    }
    diag << parentText;
  };

  diag << "from ";
  printEndpoint(sourceNo, "parent operands");
  diag << " to ";
  printEndpoint(succRegionNo, "parent results");
  return diag;
}

/// Verify the types on every region control flow edge that leaves
/// `sourcePoint`.
///
/// For each successor reported by the op's RegionBranchOpInterface,
/// `getInputsTypesForRegion` supplies the types flowing along that edge. If it
/// fails, this function fails without emitting a diagnostic of its own.
/// Otherwise the supplied types must:
///  - match the successor inputs in number, and
///  - be pairwise compatible according to the op's `areTypesCompatible`.
static LogicalResult
verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
                         function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
                             getInputsTypesForRegion) {
  auto branchOp = cast<RegionBranchOpInterface>(op);

  SmallVector<RegionSuccessor, 2> successors;
  branchOp.getSuccessorRegions(sourcePoint, successors);

  for (RegionSuccessor &successor : successors) {
    FailureOr<TypeRange> edgeTypes = getInputsTypesForRegion(successor);
    if (failed(edgeTypes))
      return failure();

    TypeRange inputTypes = successor.getSuccessorInputs().getTypes();
    std::size_t numEdgeTypes = edgeTypes->size();
    std::size_t numInputTypes = inputTypes.size();
    if (numEdgeTypes != numInputTypes) {
      InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
      return printRegionEdgeName(diag, sourcePoint, successor)
             << ": source has " << numEdgeTypes
             << " operands, but target successor needs " << numInputTypes;
    }

    // Sizes are equal here, so one index walks both ranges in lockstep.
    for (std::size_t idx = 0; idx != numEdgeTypes; ++idx) {
      Type sourceType = (*edgeTypes)[idx];
      Type inputType = inputTypes[idx];
      if (branchOp.areTypesCompatible(sourceType, inputType))
        continue;
      InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
      return printRegionEdgeName(diag, sourcePoint, successor)
             << ": source type #" << idx << " " << sourceType
             << " should match input type #" << idx << " " << inputType;
    }
  }
  return success();
}

/// Verify that types match along the control flow edges described by the
/// given RegionBranchOpInterface op. This covers edges that originate from the
/// parent op and edges that originate from each of its regions.
///
/// For edges that leave a region, the forwarded types are taken from the
/// region's terminators that implement RegionBranchTerminatorOpInterface. If a
/// region has several such terminators, all of them must forward compatible
/// types to the same successor. Regions with no such terminator are skipped,
/// and the op itself is expected to verify them.
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
  auto regionInterface = cast<RegionBranchOpInterface>(op);

  auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
    return regionInterface.getEntrySuccessorOperands(point).getTypes();
  };

  // Verify types along control flow edges originating from the parent.
  if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
                                      inputTypesFromParent)))
    return failure();

  // Two type lists are compatible if they have the same length and every pair
  // of types is compatible according to the op's `areTypesCompatible` hook.
  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
    if (lhs.size() != rhs.size())
      return false;
    for (auto types : llvm::zip(lhs, rhs)) {
      if (!regionInterface.areTypesCompatible(std::get<0>(types),
                                              std::get<1>(types))) {
        return false;
      }
    }
    return true;
  };

  // Verify types along control flow edges originating from each region.
  for (Region &region : op->getRegions()) {

    // Since there can be multiple terminators implementing the
    // `RegionBranchTerminatorOpInterface`, all should have the same operand
    // types when passing them to the same region.

    SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
    for (Block &block : region)
      if (!block.empty())
        if (auto terminator =
                dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
          regionReturnOps.push_back(terminator);

    // If there is no return-like terminator, the op itself should verify
    // type consistency.
    if (regionReturnOps.empty())
      continue;

    // Types forwarded to `point`. Fails with a diagnostic if two terminators
    // disagree; otherwise returns the first terminator's operand types.
    auto inputTypesForRegion =
        [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
      std::optional<OperandRange> regionReturnOperands;
      for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
        auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);

        if (!regionReturnOperands) {
          regionReturnOperands = terminatorOperands;
          continue;
        }

        // Found more than one ReturnLike terminator. Make sure the operand
        // types match with the first one.
        if (!areTypesCompatible(regionReturnOperands->getTypes(),
                                terminatorOperands.getTypes())) {
          // The trailing space is required: printRegionEdgeName starts its
          // output with "from ", which would otherwise be glued to "edge".
          InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
          return printRegionEdgeName(diag, region, point)
                 << " operands mismatch between return-like terminators";
        }
      }

      // All successors get the same set of operand types.
      return TypeRange(regionReturnOperands->getTypes());
    };

    if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
      return failure();
  }

  return success();
}

/// Early-exit predicate for `traverseRegionGraph`. It is called with each
/// successor region taken from the worklist. It also receives the per-region
/// visited flags, indexed by region number. Returning "true" interrupts the
/// traversal.
using StopConditionFn =
    function_ref<bool(Region *successor, ArrayRef<bool> visited)>;

/// Depth-first walk of the region graph of `begin`'s parent
/// RegionBranchOpInterface op, starting from the successors of `begin`. Edges
/// back to the parent op are not followed.
///
/// `stopConditionFn` is consulted for every region taken from the worklist,
/// before that region's visited flag is checked. If it returns "true", this
/// function returns "true" immediately. Returns "false" if the walk finishes
/// without being interrupted.
static bool traverseRegionGraph(Region *begin,
                                StopConditionFn stopConditionFn) {
  auto regionOp = cast<RegionBranchOpInterface>(begin->getParentOp());

  // visited[i] is set once region #i has been expanded; `begin` starts set.
  SmallVector<bool> visited(regionOp->getNumRegions(), false);
  visited[begin->getRegionNumber()] = true;

  // Pending regions, processed last-in first-out (depth-first).
  SmallVector<Region *> pending;
  auto pushSuccessorsOf = [&](Region *from) {
    SmallVector<RegionSuccessor> succs;
    regionOp.getSuccessorRegions(from, succs);
    for (RegionSuccessor &succ : succs) {
      // Edges that return to the parent op leave the region graph.
      if (succ.isParent())
        continue;
      pending.push_back(succ.getSuccessor());
    }
  };

  pushSuccessorsOf(begin);
  while (!pending.empty()) {
    Region *current = pending.pop_back_val();
    if (stopConditionFn(current, visited))
      return true;

    unsigned regionNo = current->getRegionNumber();
    if (visited[regionNo])
      continue;
    visited[regionNo] = true;
    pushSuccessorsOf(current);
  }
  return false;
}

/// Returns `true` if region `r` can be reached from region `begin` by
/// following one or more region-to-region edges reported by their shared
/// parent's RegionBranchOpInterface. Edges back to the parent op are not
/// followed.
static bool isRegionReachable(Region *begin, Region *r) {
  assert(begin->getParentOp() == r->getParentOp() &&
         "expected that both regions belong to the same op");
  // Interrupt the walk as soon as `r` shows up as a successor.
  auto reachedTarget = [r](Region *candidate, ArrayRef<bool> /*visited*/) {
    return candidate == r;
  };
  return traverseRegionGraph(begin, reachedTarget);
}

/// Return `true` if `a` and `b` are in mutually exclusive regions.
///
/// 1. Find the closest ancestor of `a` that implements
///    RegionBranchOpInterface and also properly contains `b`.
/// 2. Determine the regions `regionA` and `regionB` of that op in which `a`
///    and `b` are contained.
/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
///    mutually exclusive if they are distinct and neither is reachable from
///    the other as per RegionBranchOpInterface::getSuccessorRegions.
///
/// Returns `false` if no such common ancestor exists.
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
  assert(a && "expected non-empty operation");
  assert(b && "expected non-empty operation");

  // Walk outwards from `a` through enclosing RegionBranchOpInterface ops until
  // one of them also contains `b`.
  for (auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); branchOp;
       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>()) {
    if (!branchOp->isProperAncestor(b))
      continue;

    // Both ops are nested in `branchOp`; find the region holding each one.
    Region *regionA = nullptr;
    Region *regionB = nullptr;
    for (Region &candidate : branchOp->getRegions()) {
      if (candidate.findAncestorOpInRegion(*a)) {
        assert(!regionA && "already found a region for a");
        regionA = &candidate;
      }
      if (candidate.findAncestorOpInRegion(*b)) {
        assert(!regionB && "already found a region for b");
        regionB = &candidate;
      }
    }
    assert(regionA && regionB && "could not find region of op");

    // Ops in the same region are never mutually exclusive.
    if (regionA == regionB)
      return false;
    return !isRegionReachable(regionA, regionB) &&
           !isRegionReachable(regionB, regionA);
  }

  // No RegionBranchOpInterface ancestor of `a` contains `b`.
  return false;
}

/// A region is repetitive if control flow can leave it and come back to it
/// through one or more region-to-region edges.
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
  Region &region = getOperation()->getRegion(index);
  return isRegionReachable(&region, &region);
}

bool RegionBranchOpInterface::hasLoop() {
  SmallVector<RegionSuccessor> entryRegions;
  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
  for (RegionSuccessor successor : entryRegions)
    if (!successor.isParent() &&
        traverseRegionGraph(successor.getSuccessor(),
                            [](Region *nextRegion, ArrayRef<bool> visited) {
                              // Interrupt traversal if the region was already
                              // visited.
                              return visited[nextRegion->getRegionNumber()];
                            }))
      return true;
  return false;
}

/// Walks outwards from the region containing `op`. Returns the first region
/// whose parent op implements RegionBranchOpInterface and reports that region
/// as repetitive. Returns null if no such region is found.
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
  for (Region *region = op->getParentRegion(); region;
       region = op->getParentRegion()) {
    op = region->getParentOp();
    auto branchOp = dyn_cast<RegionBranchOpInterface>(op);
    if (branchOp && branchOp.isRepetitiveRegion(region->getRegionNumber()))
      return region;
  }
  return nullptr;
}

/// Walks outwards from `value.getParentRegion()`. Returns the first region
/// whose parent op implements RegionBranchOpInterface and reports that region
/// as repetitive. Returns null if no such region is found.
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
  for (Region *region = value.getParentRegion(); region != nullptr;
       region = region->getParentOp()->getParentRegion()) {
    auto branchOp = dyn_cast<RegionBranchOpInterface>(region->getParentOp());
    if (branchOp && branchOp.isRepetitiveRegion(region->getRegionNumber()))
      return region;
  }
  return nullptr;
}
