[mlir] Update dstNode after DenseMap insertion in loop fusion pass.

Reviewed By: vinayaka-polymage

Differential Revision: https://reviews.llvm.org/D101794
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 50adc97..aea8d98 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1645,6 +1645,10 @@
               // Add edge from 'newMemRef' node to dstNode.
               mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
             }
+            // One or more entries for 'newMemRef' alloc op are inserted into
+            // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
+            // reallocate, update dstNode.
+            dstNode = mdg->getNode(dstId);
           }
 
           // Collect dst loop stats after memref privatization transformation.
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index f3fcae2..2cad8613 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3115,3 +3115,186 @@
 // CHECK-NEXT:      affine.load
 // CHECK-NEXT:      mulf
 // CHECK-NEXT:      affine.store
+
+// -----
+
+// CHECK-LABEL: func @fuse_large_number_of_loops
+func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x10xf32, 1>, %arg2: memref<20x10xf32, 1>, %arg3: memref<20x10xf32, 1>, %arg4: memref<20x10xf32, 1>, %arg5: memref<f32, 1>, %arg6: memref<f32, 1>, %arg7: memref<f32, 1>, %arg8: memref<f32, 1>, %arg9: memref<20x10xf32, 1>, %arg10: memref<20x10xf32, 1>, %arg11: memref<20x10xf32, 1>, %arg12: memref<20x10xf32, 1>) {
+  %cst = constant 1.000000e+00 : f32
+  %0 = memref.alloc() : memref<f32, 1>
+  affine.store %cst, %0[] : memref<f32, 1>
+  %1 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg6[] : memref<f32, 1>
+      affine.store %21, %1[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %2 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %arg3[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = mulf %22, %21 : f32
+      affine.store %23, %2[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %3 = memref.alloc() : memref<f32, 1>
+  %4 = affine.load %arg6[] : memref<f32, 1>
+  %5 = affine.load %0[] : memref<f32, 1>
+  %6 = subf %5, %4 : f32
+  affine.store %6, %3[] : memref<f32, 1>
+  %7 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %3[] : memref<f32, 1>
+      affine.store %21, %7[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %8 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %7[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = mulf %22, %21 : f32
+      affine.store %23, %8[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %9 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = mulf %22, %21 : f32
+      affine.store %23, %9[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %9[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %2[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = addf %22, %21 : f32
+      affine.store %23, %arg11[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %10 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %arg2[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = mulf %22, %21 : f32
+      affine.store %23, %10[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %10[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = addf %22, %21 : f32
+      affine.store %23, %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %11 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = mulf %22, %21 : f32
+      affine.store %23, %11[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %12 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %11[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %arg11[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = subf %22, %21 : f32
+      affine.store %23, %12[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %13 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg7[] : memref<f32, 1>
+      affine.store %21, %13[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %14 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg4[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %13[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = mulf %22, %21 : f32
+      affine.store %23, %14[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %15 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg8[] : memref<f32, 1>
+      affine.store %21, %15[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %16 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %15[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %12[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = addf %22, %21 : f32
+      affine.store %23, %16[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %17 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %16[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = math.sqrt %21 : f32
+      affine.store %22, %17[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %18 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg5[] : memref<f32, 1>
+      affine.store %21, %18[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %19 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %18[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = mulf %22, %21 : f32
+      affine.store %23, %19[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  %20 = memref.alloc() : memref<20x10xf32, 1>
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %17[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %19[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = divf %22, %21 : f32
+      affine.store %23, %20[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %20[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %14[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = addf %22, %21 : f32
+      affine.store %23, %arg12[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  affine.for %arg13 = 0 to 20 {
+    affine.for %arg14 = 0 to 10 {
+      %21 = affine.load %arg12[%arg13, %arg14] : memref<20x10xf32, 1>
+      %22 = affine.load %arg0[%arg13, %arg14] : memref<20x10xf32, 1>
+      %23 = subf %22, %21 : f32
+      affine.store %23, %arg9[%arg13, %arg14] : memref<20x10xf32, 1>
+    }
+  }
+  return
+}
+// CHECK:         affine.for
+// CHECK:         affine.for
+// CHECK-NOT:     affine.for