//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Simplex.h"

using namespace mlir;
using namespace presburger;

void MultiAffineFunction::assertIsConsistent() const {
  // The output matrix has one column per non-range variable (domain, symbol
  // and local) plus one column for the constant term.
  assert(output.getNumColumns() ==
             space.getNumVars() - space.getNumRangeVars() + 1 &&
         "Inconsistent number of output columns");
  // Division representations are expressed over the domain and symbol vars.
  assert(divs.getNumNonDivs() ==
             space.getNumDomainVars() + space.getNumSymbolVars() &&
         "Inconsistent number of non-division variables in divs");
  // One output row per range variable.
  assert(output.getNumRows() == space.getNumRangeVars() &&
         "Inconsistent number of output rows");
  // Every local variable of the space corresponds to exactly one division.
  assert(divs.getNumDivs() == space.getNumLocalVars() &&
         "Inconsistent number of divisions.");
  assert(divs.hasAllReprs() && "All divisions should have a representation");
}

// Computes the pointwise difference `vecA - vecB` of two equally sized
// vectors, e.g. [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
static SmallVector<MPInt, 8> subtractExprs(ArrayRef<MPInt> vecA,
                                           ArrayRef<MPInt> vecB) {
  const size_t len = vecA.size();
  assert(len == vecB.size() && "Cannot subtract vectors of differing lengths!");

  SmallVector<MPInt, 8> diff;
  diff.reserve(len);
  size_t idx = 0;
  for (const MPInt &lhs : vecA)
    diff.push_back(lhs - vecB[idx++]);
  return diff;
}

PresburgerSet PWMAFunction::getDomain() const {
  // The domain of a piecewise function is the union of its pieces' domains,
  // accumulated starting from the empty set over the domain space.
  PresburgerSet unionOfDomains = PresburgerSet::getEmpty(getDomainSpace());
  for (const Piece &p : pieces)
    unionOfDomains.unionInPlace(p.domain);
  return unionOfDomains;
}

void MultiAffineFunction::print(raw_ostream &os) const {
  // Layout: the space first, then the division representation, then the
  // output matrix, each of the latter two preceded by a header line.
  space.print(os);

  os << "Division Representation:\n";
  divs.print(os);

  os << "Output:\n";
  output.print(os);
}

SmallVector<MPInt, 8>
MultiAffineFunction::valueAt(ArrayRef<MPInt> point) const {
  assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
         "Point has incorrect dimensionality!");

  // `point` only fixes the domain and symbol variables, but the outputs may
  // also depend on divisions. Compute one assignment of the division values
  // at this point. The values are dereferenced unconditionally below, i.e.
  // every division is expected to have a value here (all divisions have a
  // representation, see assertIsConsistent).
  SmallVector<Optional<MPInt>, 8> divVals = divs.divValuesAt(point);

  // Assemble the homogeneous vector [point | division values | 1]. Row k of
  // `output` is the affine expression for output k with the constant term in
  // the last column, so `output * homogeneous` is the output vector.
  SmallVector<MPInt, 8> homogeneous(point.begin(), point.end());
  homogeneous.reserve(point.size() + divVals.size() + 1);
  for (const Optional<MPInt> &val : divVals)
    homogeneous.push_back(*val);
  homogeneous.emplace_back(1);

  SmallVector<MPInt, 8> values = output.postMultiplyWithColumn(homogeneous);
  assert(values.size() == getNumOutputs());
  return values;
}

bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
  assert(space.isCompatible(other.space) &&
         "Spaces should be compatible for equality check.");
  return getAsRelation().isEqual(other.getAsRelation());
}

bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
                                  const IntegerPolyhedron &domain) const {
  assert(space.isCompatible(other.space) &&
         "Spaces should be compatible for equality check.");
  IntegerRelation restrictedThis = getAsRelation();
  restrictedThis.intersectDomain(domain);

  IntegerRelation restrictedOther = other.getAsRelation();
  restrictedOther.intersectDomain(domain);

  return restrictedThis.isEqual(restrictedOther);
}

bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
                                  const PresburgerSet &domain) const {
  assert(space.isCompatible(other.space) &&
         "Spaces should be compatible for equality check.");
  return llvm::all_of(domain.getAllDisjuncts(),
                      [&](const IntegerRelation &disjunct) {
                        return isEqual(other, IntegerPolyhedron(disjunct));
                      });
}

void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
  assert(end <= getNumOutputs() && "Invalid range");

  // The half-open range [start, end) is empty: nothing to remove.
  if (end <= start)
    return;

  // Outputs are range variables in the space and rows in the output matrix;
  // drop them from both.
  const unsigned count = end - start;
  space.removeVarRange(VarKind::Range, start, end);
  output.removeRows(start, count);
}

/// Merges the division representations of `this` and `other` so that, on
/// return, both functions use the same `divs`. Both functions are modified:
///  - `other` gets the divisions of `this` inserted as its first `nDivs`
///    local variables (`nDivs` is `getNumDivs()` at entry).
///  - A division of `other` that duplicates an earlier division is merged
///    into it: its output coefficients are folded into the earlier division's
///    column and the duplicate is removed. Duplicates among the original
///    divisions of `this` are not merged.
///  - `this` gets the surviving divisions of `other` appended after its own.
/// Columns added to an output matrix come from `insertColumns` (presumably
/// zero-filled), so neither function's value changes.
void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
  assert(space.isCompatible(other.space) && "Functions should be compatible");

  unsigned nDivs = getNumDivs();
  unsigned divOffset = divs.getDivOffset();

  // Make room in `other` for the divisions of `this` at the front.
  other.divs.insertDiv(0, nDivs);

  // Copy each division of `this` into the new slots of `other`. A dividend of
  // `this` is laid out as [non-divs, divs of `this`, constant], while one of
  // `other` is now [non-divs, divs of `this`, divs of `other`, constant], so
  // the non-constant prefix copies over directly and the constant goes last.
  SmallVector<MPInt, 8> div(other.divs.getNumVars() + 1);
  for (unsigned i = 0; i < nDivs; ++i) {
    // Zero fill.
    std::fill(div.begin(), div.end(), 0);
    // Fill div with dividend from `divs`. Do not fill the constant.
    std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
              div.begin());
    // Fill constant.
    div.back() = divs.getDividend(i).back();
    other.divs.setDiv(i, div, divs.getDenom(i));
  }

  // Keep `other`'s space and output matrix in sync with the inserted divs.
  other.space.insertVar(VarKind::Local, 0, nDivs);
  other.output.insertColumns(divOffset, nDivs);

  // Called by `removeDuplicateDivs` when local `j` duplicates local `i`.
  // Returning true means `other` has been updated so that local `j` is
  // replaced by local `i`; returning false declines the merge.
  auto merge = [&](unsigned i, unsigned j) {
    // We only merge from local at pos j to local at pos i, where j > i.
    if (i >= j)
      return false;

    // If j < nDivs (and hence i < nDivs), both locals are divisions of
    // `this`. Since we do not want to merge duplicates in `this`, we ignore
    // this call.
    if (j < nDivs)
      return false;

    // Merge things in space and output. `addToColumn(source, target, scale)`
    // adds the source column into the target column, so fold column `j` into
    // column `i` before dropping column `j`. With the operands the other way
    // round, the coefficients of local `j` would be lost.
    other.space.removeVarRange(VarKind::Local, j, j + 1);
    other.output.addToColumn(divOffset + j, divOffset + i, 1);
    other.output.removeColumn(divOffset + j);
    return true;
  };

  other.divs.removeDuplicateDivs(merge);

  // Divisions of `other` that survived deduplication and are not copies of
  // the divisions of `this`.
  unsigned newDivs = other.divs.getNumDivs() - nDivs;

  space.insertVar(VarKind::Local, nDivs, newDivs);
  output.insertColumns(divOffset + nDivs, newDivs);
  divs = other.divs;

  // Check consistency.
  assertIsConsistent();
  other.assertIsConsistent();
}

/// Two PWMAFunctions are equal if they have the same dimensionalities,
/// the same domain, and take the same value at every point in the domain.
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
  if (!space.isCompatible(other.space))
    return false;

  if (!this->getDomain().isEqual(other.getDomain()))
    return false;

  // Check if, whenever the domains of a piece of `this` and a piece of `other`
  // overlap, they take the same output value. If `this` and `other` have the
  // same domain (checked above), then this check passes iff the two functions
  // have the same output at every point in the domain.
  return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
    return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
      PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
      return pieceA.output.isEqual(pieceB.output, commonDomain);
    });
  });
}

void PWMAFunction::addPiece(const Piece &piece) {
  // Only consistent pieces may be stored; the new piece is copied in.
  assert(piece.isConsistent() && "Piece should be consistent");
  pieces.emplace_back(piece);
}

void PWMAFunction::print(raw_ostream &os) const {
  // Header: the space followed by the number of pieces.
  space.print(os);
  os << getNumPieces() << " pieces:\n";

  // Then, for each piece, its domain followed by its output function.
  for (unsigned idx = 0, e = pieces.size(); idx < e; ++idx) {
    const Piece &current = pieces[idx];
    os << "Domain of piece:\n";
    current.domain.print(os);
    os << "Output of piece\n";
    current.output.print(os);
  }
}

void PWMAFunction::dump() const {
  // Debugging helper: print the function to stderr.
  print(llvm::errs());
}

PWMAFunction PWMAFunction::unionFunction(
    const PWMAFunction &func,
    llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
  assert(getNumOutputs() == func.getNumOutputs() &&
         "Ranges of functions should be same.");
  assert(getSpace().isCompatible(func.getSpace()) &&
         "Space is not compatible.");

  // Call a piece of `this` "A" and a piece of `func` "B". The result is
  // assembled from three kinds of pieces:
  //  1. B's output on the part of the common domain where
  //     `tiebreak(B, A)` selects B.
  //  2. A's output on the rest of A's domain (where B is undefined or not
  //     selected).
  //  3. B's output where no piece of `this` is defined.
  //
  // The pieces are expected to be pairwise disjoint. The pieces of `this` are
  // assumed disjoint, and `tiebreak` presumably returns a subset of both
  // input domains. So pieces of kind 1 for different A's do not overlap.
  // Each kind-2 piece is A's domain minus every kind-1 set produced for A,
  // so it is disjoint from those as well.
  PWMAFunction result(getSpace());

  for (const Piece &thisPiece : pieces) {
    // Part of thisPiece's domain not yet claimed by a piece of `func`.
    PresburgerSet remaining(thisPiece.domain);
    for (const Piece &otherPiece : func.pieces) {
      PresburgerSet otherWins = tiebreak(otherPiece, thisPiece);
      result.addPiece({otherWins, otherPiece.output});
      remaining = remaining.subtract(otherWins);
    }
    result.addPiece({remaining, thisPiece.output});
  }

  // Kind 3: where only `func` is defined.
  PresburgerSet thisDomain = getDomain();
  for (const Piece &otherPiece : func.pieces)
    result.addPiece(
        {otherPiece.domain.subtract(thisDomain), otherPiece.output});

  return result;
}

/// A tiebreak function which breaks ties by comparing the outputs
/// lexicographically. If `lexMin` is true, then the ties are broken by
/// taking the lexicographically smaller output and otherwise, by taking the
/// lexicographically larger output.
///
/// Returns the subset of the common domain of `pieceA` and `pieceB` on which
/// the output of `pieceA` is strictly lexicographically smaller (`lexMin`) or
/// strictly larger (otherwise) than the output of `pieceB`. Points where both
/// outputs are equal are not included.
template <bool lexMin>
static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
                                 const PWMAFunction::Piece &pieceB) {
  // TODO: Support local variables here.
  assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) &&
         "Pieces should be compatible");
  assert(pieceA.domain.getSpace().getNumLocalVars() == 0 &&
         "Local variables are not supported yet.");
  // The output expressions are turned directly into constraints over the
  // domain space below, so they must not refer to division (local) vars.
  assert(pieceA.output.getSpace().getNumLocalVars() == 0 &&
         pieceB.output.getSpace().getNumLocalVars() == 0 &&
         "Local variables in outputs are not supported yet.");

  const PresburgerSpace &space = pieceA.domain.getSpace();

  // We first create the set `result`, corresponding to the set where output
  // of pieceA is lexicographically larger/smaller than pieceB. This is done by
  // creating a PresburgerSet with the following constraints:
  //
  //    (outA[0] > outB[0]) U
  //    (outA[0] = outB[0], outA[1] > outB[1]) U
  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
  //    ...
  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
  //
  // where `n` is the number of outputs.
  // If `lexMin` is set, the complement inequality is used:
  //
  //    (outA[0] < outB[0]) U
  //    (outA[0] = outB[0], outA[1] < outB[1]) U
  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
  //    ...
  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
  PresburgerSet result = PresburgerSet::getEmpty(space);
  IntegerPolyhedron levelSet(
      /*numReservedInequalities=*/1,
      /*numReservedEqualities=*/pieceA.output.getNumOutputs(),
      /*numReservedCols=*/space.getNumVars() + 1, space);
  for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
    // Create the expression `outA - outB` for this level.
    SmallVector<MPInt, 8> subExpr = subtractExprs(
        pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));

    if (lexMin) {
      // For lexMin, we add an upper bound of -1:
      //        outA - outB <= -1
      //        outA <= outB - 1
      //        outA < outB
      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
    } else {
      // For lexMax, we add a lower bound of 1:
      //        outA - outB >= 1
      //        outA >= outB + 1
      //        outA > outB
      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
    }

    // Union the set with the result.
    result.unionInPlace(levelSet);
    // There is only 1 inequality in `levelSet`, so the index is always 0.
    levelSet.removeInequality(0);
    // Add equality `outA - outB == 0` for this level for next iteration.
    levelSet.addEquality(subExpr);
  }

  // We then intersect `result` with the domain of pieceA and pieceB, to only
  // tiebreak on the domain where both are defined.
  result = result.intersect(pieceA.domain).intersect(pieceB.domain);

  return result;
}

PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
  // Where both functions are defined, keep the lexicographically smaller
  // output.
  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
}

PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
  // Where both functions are defined, keep the lexicographically larger
  // output.
  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
}

void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
  assert(space.isCompatible(other.space) &&
         "Spaces should be compatible for subtraction.");

  MultiAffineFunction copyOther = other;
  mergeDivs(copyOther);
  for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
    output.addToRow(i, copyOther.getOutputExpr(i), MPInt(-1));

  // Check consistency.
  assertIsConsistent();
}

/// Constrains every local variable of `rel` to equal its division, as given
/// by the corresponding representation in `divs`. For each division, one
/// upper-bound and one lower-bound inequality are added. `rel` and `divs` must
/// agree on the total number of variables and on the number of locals.
static void addDivisionConstraints(IntegerRelation &rel,
                                   const DivisionRepr &divs) {
  assert(divs.hasAllReprs() &&
         "All divisions in divs should have a representation");
  assert(rel.getNumVars() == divs.getNumVars() &&
         "Relation and divs should have the same number of vars");
  assert(rel.getNumLocalVars() == divs.getNumDivs() &&
         "Relation and divs should have the same number of local vars");

  const unsigned firstDivPos = divs.getDivOffset();
  const unsigned numDivs = divs.getNumDivs();
  for (unsigned d = 0; d < numDivs; ++d) {
    const unsigned localPos = firstDivPos + d;
    rel.addInequality(
        getDivUpperBound(divs.getDividend(d), divs.getDenom(d), localPos));
    rel.addInequality(
        getDivLowerBound(divs.getDividend(d), divs.getDenom(d), localPos));
  }
}

IntegerRelation MultiAffineFunction::getAsRelation() const {
  const unsigned numDomain = getNumDomainVars();
  const unsigned numOutputs = getNumOutputs();

  // Start from a relation over the input space (domain and symbol vars) that
  // also has one local variable per division used by the outputs. Each local
  // is constrained to its division expression.
  IntegerRelation rel(PresburgerSpace::getRelationSpace(
      numDomain, /*numRange=*/0, space.getNumSymbolVars(),
      space.getNumLocalVars()));
  addDivisionConstraints(rel, divs);

  // Every output becomes a range variable of the relation.
  rel.insertVar(VarKind::Range, 0, numOutputs);

  const unsigned rangeBegin = rel.getVarKindOffset(VarKind::Range);
  const unsigned rangeEnd = rel.getVarKindEnd(VarKind::Range);

  // Column layout of an output expression: [domain | remaining columns].
  // Column layout of `eq`:                 [domain | range | remaining].
  // The range columns are zero except for the output being defined.
  SmallVector<MPInt, 8> eq(rel.getNumCols());
  for (unsigned outIdx = 0; outIdx < numOutputs; ++outIdx) {
    ArrayRef<MPInt> expr = getOutputExpr(outIdx);
    auto split = expr.begin() + numDomain;
    std::copy(expr.begin(), split, eq.begin());
    std::fill(eq.begin() + rangeBegin, eq.begin() + rangeEnd, 0);
    std::copy(split, expr.end(), eq.begin() + rangeEnd);

    // Encode `rangeVar_outIdx = output[outIdx]` as
    // `output[outIdx] - rangeVar_outIdx == 0`.
    eq[rangeBegin + outIdx] = -1;
    rel.addEquality(eq);
  }

  return rel;
}

void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
  space.removeVarRange(VarKind::Range, start, end);
  for (Piece &piece : pieces)
    piece.output.removeOutputs(start, end);
}

Optional<SmallVector<MPInt, 8>>
PWMAFunction::valueAt(ArrayRef<MPInt> point) const {
  assert(point.size() == getNumDomainVars() + getNumSymbolVars());

  // Evaluate the first piece whose domain contains `point`. If no piece
  // contains it, the function is undefined there.
  for (const Piece &candidate : pieces) {
    if (!candidate.domain.containsPoint(point))
      continue;
    return candidate.output.valueAt(point);
  }
  return None;
}
