[mlir] Update dstNode after DenseMap insertion in loop fusion pass.
Reviewed By: vinayaka-polymage
Differential Revision: https://reviews.llvm.org/D101794
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 50adc97..aea8d98 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1645,6 +1645,10 @@
// Add edge from 'newMemRef' node to dstNode.
mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
}
+ // One or more entries for 'newMemRef' alloc op are inserted into
+ // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
+ // reallocate, update dstNode.
+ dstNode = mdg->getNode(dstId);
}
// Collect dst loop stats after memref privatization transformation.
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index f3fcae2..2cad8613 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3115,3 +3115,186 @@
// CHECK-NEXT: affine.load
// CHECK-NEXT: mulf
// CHECK-NEXT: affine.store
+
+// -----
+
+// CHECK-LABEL: func @fuse_large_number_of_loops
+func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x10xf32, 1>, %arg2: memref<20x10xf32, 1>, %arg3: memref<20x10xf32, 1>, %arg4: memref<20x10xf32, 1>, %arg5: memref<f32, 1>, %arg6: memref<f32, 1>, %arg7: memref<f32, 1>, %arg8: memref<f32, 1>, %arg9: memref<20x10xf32, 1>, %arg10: memref<20x10xf32, 1>, %arg11: memref<20x10xf32, 1>, %arg12: memref<20x10xf32, 1>) {
+ %cst = constant 1.000000e+00 : f32
+ %0 = memref.alloc() : memref<f32, 1>
+ affine.store %cst, %0[] : memref<f32, 1>
+ %1 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg6[] : memref<f32, 1>
+ affine.store %21, %1[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %2 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg3[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %2[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %3 = memref.alloc() : memref<f32, 1>
+ %4 = affine.load %arg6[] : memref<f32, 1>
+ %5 = affine.load %0[] : memref<f32, 1>
+ %6 = subf %5, %4 : f32
+ affine.store %6, %3[] : memref<f32, 1>
+ %7 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %3[] : memref<f32, 1>
+ affine.store %21, %7[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %8 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %7[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %8[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %9 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %9[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %9[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %2[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %arg11[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %10 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg2[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %10[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %10[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %11 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %11[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %12 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %11[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg11[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = subf %22, %21 : f32
+ affine.store %23, %12[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %13 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg7[] : memref<f32, 1>
+ affine.store %21, %13[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %14 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg4[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %13[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %14[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %15 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg8[] : memref<f32, 1>
+ affine.store %21, %15[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %16 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %15[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %12[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %16[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %17 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %16[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = math.sqrt %21 : f32
+ affine.store %22, %17[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %18 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg5[] : memref<f32, 1>
+ affine.store %21, %18[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %19 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %18[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = mulf %22, %21 : f32
+ affine.store %23, %19[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ %20 = memref.alloc() : memref<20x10xf32, 1>
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %17[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %19[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = divf %22, %21 : f32
+ affine.store %23, %20[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %20[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %14[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = addf %22, %21 : f32
+ affine.store %23, %arg12[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ affine.for %arg13 = 0 to 20 {
+ affine.for %arg14 = 0 to 10 {
+ %21 = affine.load %arg12[%arg13, %arg14] : memref<20x10xf32, 1>
+ %22 = affine.load %arg0[%arg13, %arg14] : memref<20x10xf32, 1>
+ %23 = subf %22, %21 : f32
+ affine.store %23, %arg9[%arg13, %arg14] : memref<20x10xf32, 1>
+ }
+ }
+ return
+}
+// CHECK: affine.for
+// CHECK: affine.for
+// CHECK-NOT: affine.for