Reapply [mlir][memref]: Allow collapse of strided unit dim even if strides are dynamic (#171039)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c2e71669..fe93b3e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2636,6 +2636,11 @@
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
       stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
+      // Dimensions of size 1 should be skipped, because their strides are
+      // meaningless and could have any arbitrary value.
+      if (srcShape[idx - 1] == 1)
+        continue;
+
       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
       // are contiguous).
@@ -2648,11 +2653,6 @@
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
-      // Dimensions of size 1 should be skipped, because their strides are
-      // meaningless and could have any arbitrary value.
-      if (srcShape[idx - 1] == 1)
-        continue;
-
       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index a90c950..cddc79f 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -440,7 +440,10 @@
          %arg4: index,
          %arg5: index,
          %arg6: index,
-         %arg7: memref<4x?x4xf32>) {
+         %arg7: memref<4x?x4xf32>,
+         %arg8: memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>>,
+         %arg9: memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>) {
+
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -489,6 +492,16 @@
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
   %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
         : memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
+//  CHECK-SAME:     memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>>
+  %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
+//  CHECK-SAME:     memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into memref<3x288xf32, strided<[288, 1], offset: 864>>
+  %6 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
+    memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into
+    memref<3x288xf32, strided<[288, 1], offset: 864>>
   return
 }