Implement division merging

GitOrigin-RevId: 3bc5353fc6f291d6ac12c256600b5b05d7de8f74
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index 53b3d04..4db0358 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -530,6 +530,12 @@
   /// Normalized each constraints by the GCD of its coefficients.
   void normalizeConstraintsByGCD();
 
+  /// Get division representation for each local identifier. If no local
+  /// representation exists for the `i^th` local identifier, denominator[i] is
+  /// set to 0.
+  void getLocalIdsReprs(std::vector<SmallVector<int64_t, 8>> &reprs,
+                        SmallVector<unsigned, 8> &denominator);
+
   /// Removes identifiers in the column range [idStart, idLimit), and copies any
   /// remaining valid data into place, updates member variables, and resizes
   /// arrays as needed.
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index c0465a2..30c8e09 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -1918,6 +1918,48 @@
   equalities.resizeVertically(pos);
 }
 
+void FlatAffineConstraints::getLocalIdsReprs(
+    std::vector<SmallVector<int64_t, 8>> &reprs,
+    SmallVector<unsigned, 8> &denominators) {
+
+  assert(reprs.size() == getNumLocalIds() &&
+         "Size of reprs must be equal to number of local ids");
+  assert(denominators.size() == getNumLocalIds() &&
+         "Size of denominators must be equal to number of local ids");
+
+  // Get upper-lower bound inequality pairs for division representation.
+  std::vector<Optional<std::pair<unsigned, unsigned>>> divIneqPairs(
+      getNumLocalIds());
+  getLocalReprLbUbPairs(divIneqPairs);
+
+  for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
+    if (!divIneqPairs[i].hasValue()) {
+      denominators[i] = 0;
+      continue;
+    }
+
+    std::pair<unsigned, unsigned> divPair = divIneqPairs[i].getValue();
+    LogicalResult divExtracted =
+        getDivRepr(*this, i + getIdKindOffset(IdKind::Local), divPair.first,
+                   divPair.second, reprs[i], denominators[i]);
+    assert(succeeded(divExtracted) &&
+           "Div should have been found since ub-lb pair exists");
+  }
+}
+
+/// Merge local identifer at `pos2` into local identifer at `pos1` in `fac`.
+static void mergeDivision(FlatAffineConstraints &fac, unsigned pos1,
+                          unsigned pos2) {
+  unsigned localOffset = fac.getNumDimAndSymbolIds();
+  pos1 += localOffset;
+  pos2 += localOffset;
+  for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
+    fac.atIneq(i, pos1) += fac.atIneq(i, pos2);
+  for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
+    fac.atEq(i, pos1) += fac.atEq(i, pos2);
+  fac.removeId(pos2);
+}
+
 /// Merge local ids of `this` and `other`. This is done by appending local ids
 /// of `other` to `this` and inserting local ids of `this` to `other` at start
 /// of its local ids. Number of dimension and symbol ids should match in
@@ -1927,9 +1969,67 @@
          "Number of dimension ids should match");
   assert(getNumSymbolIds() == other.getNumSymbolIds() &&
          "Number of symbol ids should match");
-  unsigned initLocals = getNumLocalIds();
-  insertLocalId(getNumLocalIds(), other.getNumLocalIds());
-  other.insertLocalId(0, initLocals);
+
+  FlatAffineConstraints &fac1 = *this;
+  FlatAffineConstraints &fac2 = other;
+
+  // Get divisions inequality pairs from each FAC.
+  std::vector<SmallVector<int64_t, 8>> divs1(fac1.getNumLocalIds()),
+      divs2(fac2.getNumLocalIds());
+  SmallVector<unsigned, 8> denoms1(fac1.getNumLocalIds()),
+      denoms2(fac2.getNumLocalIds());
+  fac1.getLocalIdsReprs(divs1, denoms1);
+  fac2.getLocalIdsReprs(divs2, denoms2);
+
+  // Merge local ids of fac1 and fac2 without using division information,
+  // i.e. append local ids of `fac2` to `fac1` and insert local ids of `fac1`
+  // to `fac2` at start of its local ids.
+  unsigned initLocals = fac1.getNumLocalIds();
+  insertLocalId(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+  fac2.insertLocalId(0, initLocals);
+
+  // Merge division representation extracted from fac1 and fac2.
+  divs1.insert(divs1.end(), divs2.begin(), divs2.end());
+  denoms1.insert(denoms1.end(), denoms2.begin(), denoms2.end());
+
+  auto dependsOnExist = [&](unsigned offset, SmallVector<int64_t, 8> &div) {
+    for (unsigned i = offset, e = div.size(); i < e; ++i)
+      if (div[i] != 0)
+        return true;
+    return false;
+  };
+
+  // Find duplicate divisions and merge them.
+  // TODO: Add division normalization to support divisions that differ by
+  // a constant
+  for (unsigned i = 0; i < divs1.size(); ++i) {
+    // Check if a division exists which is duplicate of division at `i`.
+    for (unsigned j = i + 1; j < divs1.size(); ++j) {
+      // Check if division representation exists for both local ids.
+      if (denoms1[i] == 0 || denoms1[j] == 0)
+        continue;
+      // Check if denominators match.
+      if (denoms1[i] != denoms1[j])
+        continue;
+      // Check if representation is equal.
+      if (!std::equal(divs1[i].begin(), divs1[i].end(), divs1[j].begin()))
+        continue;
+      // If division representation contains a local variable, do not match.
+      // TODO: Support divisions that depend on other local ids. This can
+      // be done by ordering divisions such that a division representation
+      // for local identifier at position `i` only depends on local identifiers
+      // at position < `i`.
+      if (dependsOnExist(fac1.getIdKindOffset(IdKind::Local), divs1[j]))
+        continue;
+
+      // Merge divisions at position `j` into division at position `i`.
+      mergeDivision(fac1, i, j);
+      mergeDivision(fac2, i, j);
+      divs1.erase(divs1.begin() + j);
+      denoms1.erase(denoms1.begin() + j);
+      --j;
+    }
+  }
 }
 
 /// Removes local variables using equalities. Each equality is checked if it
diff --git a/unittests/Analysis/AffineStructuresTest.cpp b/unittests/Analysis/AffineStructuresTest.cpp
index 4cb0471..8f30bea 100644
--- a/unittests/Analysis/AffineStructuresTest.cpp
+++ b/unittests/Analysis/AffineStructuresTest.cpp
@@ -809,4 +809,79 @@
   EXPECT_TRUE(fac3.isEmpty());
 }
 
+TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) {
+  {
+    // (x) : (exists z, y  = [x / 2] : x = 3y and x + z + 1 >= 0).
+    FlatAffineConstraints fac1(1, 0, 1);
+    fac1.addLocalFloorDiv({1, 0, 0}, 2);
+    fac1.addEquality({1, 0, -3, 0});
+    fac1.addInequality({1, 1, 0, 1});
+
+    // (x) : (exists y = [x / 2], z : x = 5y).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2);
+    fac2.addEquality({1, -5, 0});
+    fac2.appendLocalId();
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 1 division matched + 2 unmatched local variables.
+    EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+  }
+
+  {
+    // (x) : (exists z = [x / 5], y  = [x / 2] : x = 3y).
+    FlatAffineConstraints fac1(1);
+    fac1.addLocalFloorDiv({1, 0}, 5);
+    fac1.addLocalFloorDiv({1, 0, 0}, 2);
+    fac1.addEquality({1, 0, -3, 0});
+
+    // (x) : (exists y = [x / 2], z = [x / 5]: x = 5z).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2);
+    fac2.addLocalFloorDiv({1, 0, 0}, 5);
+    fac2.addEquality({1, 0, -5, 0});
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 2 division matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+  }
+}
+
+TEST(FlatAffineConstraintsTest, mergeDivisionsUnsupported) {
+  // Division merging for divisions depending on other local variables
+  // not yet supported.
+
+  // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x).
+  FlatAffineConstraints fac1(1);
+  fac1.addLocalFloorDiv({1, 0}, 2);
+  fac1.addLocalFloorDiv({1, 1, 0}, 3);
+  fac1.addInequality({-1, 1, 1, 0});
+
+  // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x).
+  FlatAffineConstraints fac2(1);
+  fac2.addLocalFloorDiv({1, 0}, 2);
+  fac2.addLocalFloorDiv({1, 1, 0}, 3);
+  fac2.addInequality({1, -1, -1, 0});
+
+  fac1.mergeLocalIds(fac2);
+
+  // Local space should be same.
+  EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+  // 1 division matched + 2 unmerged division due to division depending on
+  // other local variables.
+  EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+  EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+}
+
 } // namespace mlir