[MLIR][FlatAffineConstraints] Remove duplicate divisions while merging local ids

This patch implements detecting duplicate local identifiers by extracting their
division representation while merging local identifiers.

For example, given the FACs A, B:

```
A: (x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4]: d0 <= s0, d1 <= s0, x + y >= 2)
B: (x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4]: d0 <= s0, d1 <= s0, x + y >= 5)
```

The intersection of A and B without this patch would lead to the following FAC:

```
(x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4], d2 = [x / 4], d3 = [x / 4]: d0 <= s0, d1 <= s0, d2 <= s0, d3 <= s0, x + y >= 2, x + y >= 5)
```

after this patch, merging of local ids will detect that `d0 = d2` and `d1 = d3`,
and the intersection of these two FACs will be (after removing duplicate constraints):

```
(x, y)[s0] : (exists d0 = [x / 4], d1 = [y / 4] : d0 <= s0, d1 <= s0, x + y >= 2, x + y >= 5)
```

This reduces the number of constraints by 2 (constraints) + 4 (2 constraints for each extra division) for this case.

This is used to reduce the output size representation of operations like
PresburgerSet::subtract, PresburgerSet::intersect which require merging local
variables.

Reviewed By: arjunp, bondhugula

Differential Revision: https://reviews.llvm.org/D112867

GitOrigin-RevId: d257f7c1bff3521053d925c616f6f60d0a5c22ec
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index 53b3d04..0f763f8 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -441,10 +441,16 @@
   /// variables.
   void convertDimToLocal(unsigned dimStart, unsigned dimLimit);
 
-  /// Merge local ids of `this` and `other`. This is done by appending local ids
-  /// of `other` to `this` and inserting local ids of `this` to `other` at start
-  /// of its local ids. Number of dimension and symbol ids should match in
-  /// `this` and `other`.
+  /// Adds additional local ids to the sets such that they both have the union
+  /// of the local ids in each set, without changing the set of points that
+  /// lie in `this` and `other`. The ordering of the local ids in the
+  /// sets may also be changed. After merging, if the `i^th` local variable in
+  /// one set has a known division representation, then the `i^th` local
+  /// variable in the other set either has the same division representation or
+  /// no known division representation.
+  ///
+  /// The number of dimensions and symbol ids in `this` and `other` should
+  /// match.
   void mergeLocalIds(FlatAffineConstraints &other);
 
   /// Removes all equalities and inequalities.
@@ -819,8 +825,8 @@
   /// constraint systems are updated so that they have the union of all
   /// identifiers, with `this`'s original identifiers appearing first followed
   /// by any of `other`'s identifiers that didn't appear in `this`. Local
-  /// identifiers of each system are by design separate/local and are placed
-  /// one after other (`this`'s followed by `other`'s).
+  /// identifiers in `other` that have the same division representation as local
+  /// identifiers in `this` are merged into one.
   //  E.g.: Input: `this`  has (%i, %j) [%M, %N]
   //               `other` has (%k, %j) [%P, %N, %M]
   //        Output: both `this`, `other` have (%i, %j, %k) [%M, %N, %P]
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index c0465a2..ea600ee 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -493,8 +493,8 @@
 /// dimension-wise and symbol-wise unique; both constraint systems are updated
 /// so that they have the union of all identifiers, with A's original
 /// identifiers appearing first followed by any of B's identifiers that didn't
-/// appear in A. Local identifiers of each system are by design separate/local
-/// and are placed one after other (A's followed by B's).
+/// appear in A. Local identifiers in B that have the same division
+/// representation as local identifiers in A are merged into one.
 //  E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
 //        Output: both A, B have (%i, %j, %k) [%M, %N, %P]
 static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
@@ -1918,18 +1918,108 @@
   equalities.resizeVertically(pos);
 }
 
-/// Merge local ids of `this` and `other`. This is done by appending local ids
-/// of `other` to `this` and inserting local ids of `this` to `other` at start
-/// of its local ids. Number of dimension and symbol ids should match in
-/// `this` and `other`.
+/// Eliminate `pos2^th` local identifier, replacing its every instance with
+/// `pos1^th` local identifier. This function is intended to be used to remove
+/// redundancy when local variables at position `pos1` and `pos2` are restricted
+/// to have the same value.
+static void eliminateRedundantLocalId(FlatAffineConstraints &fac, unsigned pos1,
+                                      unsigned pos2) {
+
+  assert(pos1 < fac.getNumLocalIds() && "Invalid local id position");
+  assert(pos2 < fac.getNumLocalIds() && "Invalid local id position");
+
+  unsigned localOffset = fac.getNumDimAndSymbolIds();
+  pos1 += localOffset;
+  pos2 += localOffset;
+  for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
+    fac.atIneq(i, pos1) += fac.atIneq(i, pos2);
+  for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
+    fac.atEq(i, pos1) += fac.atEq(i, pos2);
+  fac.removeId(pos2);
+}
+
+/// Adds additional local ids to the sets such that they both have the union
+/// of the local ids in each set, without changing the set of points that
+/// lie in `this` and `other`.
+///
+/// To detect local ids that always take the same in both sets, each local id is
+/// represented as a floordiv with constant denominator in terms of other ids.
+/// After extracting these divisions, local ids with the same division
+/// representation are considered duplicate and are merged. It is possible that
+/// division representation for some local id cannot be obtained, and thus these
+/// local ids are not considered for detecting duplicates.
 void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
   assert(getNumDimIds() == other.getNumDimIds() &&
          "Number of dimension ids should match");
   assert(getNumSymbolIds() == other.getNumSymbolIds() &&
          "Number of symbol ids should match");
-  unsigned initLocals = getNumLocalIds();
-  insertLocalId(getNumLocalIds(), other.getNumLocalIds());
-  other.insertLocalId(0, initLocals);
+
+  FlatAffineConstraints &facA = *this;
+  FlatAffineConstraints &facB = other;
+
+  // Merge local ids of facA and facB without using division information,
+  // i.e. append local ids of `facB` to `facA` and insert local ids of `facA`
+  // to `facB` at start of its local ids.
+  unsigned initLocals = facA.getNumLocalIds();
+  insertLocalId(facA.getNumLocalIds(), facB.getNumLocalIds());
+  facB.insertLocalId(0, initLocals);
+
+  // Get division representations from each FAC.
+  std::vector<SmallVector<int64_t, 8>> divsA, divsB;
+  SmallVector<unsigned, 4> denomsA, denomsB;
+  facA.getLocalReprs(divsA, denomsA);
+  facB.getLocalReprs(divsB, denomsB);
+
+  // Copy division information for facB into `divsA` and `denomsA`, so that
+  // these have the combined division information of both FACs. Since newly
+  // added local variables in facA and facB have no constraints, they will not
+  // have any division representation.
+  std::copy(divsB.begin() + initLocals, divsB.end(),
+            divsA.begin() + initLocals);
+  std::copy(denomsB.begin() + initLocals, denomsB.end(),
+            denomsA.begin() + initLocals);
+
+  // Find and merge duplicate divisions.
+  // TODO: Add division normalization to support divisions that differ by
+  // a constant.
+  // TODO: Add division ordering such that a division representation for local
+  // identifier at position `i` only depends on local identifiers at position <
+  // `i`. This would make sure that all divisions depending on other local
+  // variables that can be merged, are merged.
+  unsigned localOffset = getIdKindOffset(IdKind::Local);
+  for (unsigned i = 0; i < divsA.size(); ++i) {
+    // Check if a division representation exists for the `i^th` local id.
+    if (denomsA[i] == 0)
+      continue;
+    // Check if a division exists which is a duplicate of the division at `i`.
+    for (unsigned j = i + 1; j < divsA.size(); ++j) {
+      // Check if a division representation exists for the `j^th` local id.
+      if (denomsA[j] == 0)
+        continue;
+      // Check if the denominators match.
+      if (denomsA[i] != denomsA[j])
+        continue;
+      // Check if the representations are equal.
+      if (divsA[i] != divsA[j])
+        continue;
+
+      // Merge divisions at position `j` into division at position `i`.
+      eliminateRedundantLocalId(facA, i, j);
+      eliminateRedundantLocalId(facB, i, j);
+      for (unsigned k = 0, g = divsA.size(); k < g; ++k) {
+        SmallVector<int64_t, 8> &div = divsA[k];
+        if (denomsA[k] != 0) {
+          div[localOffset + i] += div[localOffset + j];
+          div.erase(div.begin() + localOffset + j);
+        }
+      }
+
+      divsA.erase(divsA.begin() + j);
+      denomsA.erase(denomsA.begin() + j);
+      // Since `j` can never be zero, we do not need to worry about overflows.
+      --j;
+    }
+  }
 }
 
 /// Removes local variables using equalities. Each equality is checked if it
diff --git a/unittests/Analysis/AffineStructuresTest.cpp b/unittests/Analysis/AffineStructuresTest.cpp
index 4cb0471..497816a 100644
--- a/unittests/Analysis/AffineStructuresTest.cpp
+++ b/unittests/Analysis/AffineStructuresTest.cpp
@@ -809,4 +809,127 @@
   EXPECT_TRUE(fac3.isEmpty());
 }
 
+TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) {
+  {
+    // (x) : (exists z, y  = [x / 2] : x = 3y and x + z + 1 >= 0).
+    FlatAffineConstraints fac1(1, 0, 1);
+    fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2].
+    fac1.addEquality({1, 0, -3, 0});     // x = 3y.
+    fac1.addInequality({1, 1, 0, 1});    // x + z + 1 >= 0.
+
+    // (x) : (exists y = [x / 2], z : x = 5y).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
+    fac2.addEquality({1, -5, 0});     // x = 5y.
+    fac2.appendLocalId();             // Add local id z.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 1 division should be matched + 2 unmatched local ids.
+    EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+  }
+
+  {
+    // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y).
+    FlatAffineConstraints fac1(1);
+    fac1.addLocalFloorDiv({1, 0}, 5);    // z = [x / 5].
+    fac1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2].
+    fac1.addEquality({1, 0, -3, 0});     // x = 3y.
+
+    // (x) : (exists y = [x / 2], z = [x / 5]: x = 5z).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2);    // y = [x / 2].
+    fac2.addLocalFloorDiv({1, 0, 0}, 5); // z = [x / 5].
+    fac2.addEquality({1, 0, -5, 0});     // x = 5z.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 2 divisions should be matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+  }
+}
+
+TEST(FlatAffineConstraintsTest, mergeDivisionsNestedDivsions) {
+  {
+    // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x).
+    FlatAffineConstraints fac1(1);
+    fac1.addLocalFloorDiv({1, 0}, 2);    // y = [x / 2].
+    fac1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
+    fac1.addInequality({-1, 1, 1, 0});   // y + z >= x.
+
+    // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2);    // y = [x / 2].
+    fac2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
+    fac2.addInequality({1, -1, -1, 0});  // y + z <= x.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 2 divisions should be matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+  }
+
+  {
+    // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z >= x).
+    FlatAffineConstraints fac1(1);
+    fac1.addLocalFloorDiv({1, 0}, 2);       // y = [x / 2].
+    fac1.addLocalFloorDiv({1, 1, 0}, 3);    // z = [x + y / 3].
+    fac1.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5].
+    fac1.addInequality({-1, 1, 1, 0, 0});   // y + z >= x.
+
+    // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z <= x).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 0}, 2);       // y = [x / 2].
+    fac2.addLocalFloorDiv({1, 1, 0}, 3);    // z = [x + y / 3].
+    fac2.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5].
+    fac2.addInequality({1, -1, -1, 0, 0});  // y + z <= x.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 3 divisions should be matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+  }
+}
+
+TEST(FlatAffineConstraintsTest, mergeDivisionsConstants) {
+  {
+    // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z >= x).
+    FlatAffineConstraints fac1(1);
+    fac1.addLocalFloorDiv({1, 1}, 2);    // y = [x + 1 / 2].
+    fac1.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
+    fac1.addInequality({-1, 1, 1, 0});   // y + z >= x.
+
+    // (x) : (exists y = [x + 1 / 3], z = [x + 2 / 3]: y + z <= x).
+    FlatAffineConstraints fac2(1);
+    fac2.addLocalFloorDiv({1, 1}, 2);    // y = [x + 1 / 2].
+    fac2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
+    fac2.addInequality({1, -1, -1, 0});  // y + z <= x.
+
+    fac1.mergeLocalIds(fac2);
+
+    // Local space should be same.
+    EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+    // 2 divisions should be matched.
+    EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+    EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+  }
+}
+
 } // namespace mlir