[SCF][Transform] Add support for scf.for in LoopFuseSibling op (#81495)

Adds support for fusing two scf.for loops occurring in the same block.
Reuses the rudimentary checks already in place for scf.forall (e.g. that
the target loop's operands are dominated by the source loop).
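
For reference, the scf.for case is driven the same way as the existing
scf.forall case. A minimal transform script, mirroring the tests added in
this patch, looks as follows:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      // Match both scf.for loops in the payload and split the handle in two.
      %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      // Fuse the first loop into the second; returns a handle to the fused loop.
      %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
      transform.yield
    }
  }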

- Fixes a bug in the dominance check whereby the check was performed on the
ops inside the target loop that use values defined above, rather than on the
ops that define those values.
- Renames the LoopFuseSibling op to LoopFuseSiblingOp.
- Updates LoopFuseSiblingOp's description.
- Adds tests for using LoopFuseSiblingOp on scf.for loops, including one
which fails without the fix for the dominance check.
- Adds tests checking the different failure modes of the dominance
checker.
- Adds a test for the case where the scf.yield terminator is automatically
generated when there are no loop-carried variables.

GitOrigin-RevId: eacda36c7dd842cb15c0c954eda74b67d0c73814
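
Besides the transform op, the patch exposes the rewrite as a standalone C++
utility, fuseIndependentSiblingForLoops, declared in
mlir/Dialect/SCF/Utils/Utils.h. The following is a minimal sketch (not part
of this patch) of how a hypothetical C++ caller might use it, assuming the
caller has already established that the two loops are independent siblings
with matching bounds and step:

  #include "mlir/Dialect/SCF/IR/SCF.h"
  #include "mlir/Dialect/SCF/Utils/Utils.h"
  #include "mlir/IR/PatternMatch.h"

  // Hypothetical wrapper: fuses two sibling scf.for loops. No legality
  // checks are performed here; the caller guarantees independence.
  static mlir::scf::ForOp fuseSiblingForLoops(mlir::scf::ForOp target,
                                              mlir::scf::ForOp source,
                                              mlir::RewriterBase &rewriter) {
    // The utility clones target's body ahead of source's body into a fresh
    // loop inserted after source, then replaces both original loops with
    // the corresponding results of the fused loop.
    return mlir::fuseIndependentSiblingForLoops(target, source, rewriter);
  }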
diff --git a/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 6f94cee..5eefe26 100644
--- a/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -333,23 +333,24 @@
   }];
 }
 
-def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
   [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
    DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
 
   let description = [{
     Fuses the `target` loop into the `source` loop assuming they are
-    independent of each other. It is the responsibility of the user to ensure
-    that the given two loops are independent of each other, this operation will
-    not performa any legality checks and will simply fuse the two given loops.
+    independent of each other. In the fused loop, the arguments, body and
+    results of `target` are placed _before_ those of `source`.
 
-    Currently, the only fusion supported is when both `target` and `source`
-    are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
-    mapping must match, otherwise a silencable failure is produced.
+    For fusion of two `scf.for` loops, the bounds and step size must match. For
+    fusion of two `scf.forall` loops, the bounds and the mapping must match.
+    Otherwise a silenceable failure is produced.
 
-    The input handles `target` and `source` must map to exactly one operation,
-    a definite failure is produced otherwise.
+    The `target` and `source` handles must refer to exactly one operation,
+    otherwise a definite failure is produced. It is the responsibility of the
+    user to ensure that the `target` and `source` loops are independent of each
+    other -- this op will only perform rudimentary legality checks.
 
     #### Return modes
 
@@ -362,10 +363,6 @@
   let results = (outs TransformHandleTypeInterface:$fused_loop);
   let assemblyFormat = "$target `into` $source attr-dict "
                        " `:` functional-type(operands, results)";
-
-  let builders = [
-    OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
-  ];
 }
 
 #endif // SCF_TRANSFORM_OPS
diff --git a/include/mlir/Dialect/SCF/Utils/Utils.h b/include/mlir/Dialect/SCF/Utils/Utils.h
index 9bdd6eb..883d11b 100644
--- a/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -162,6 +162,16 @@
                                                 scf::ForallOp source,
                                                 RewriterBase &rewriter);
 
+/// Given two scf.for loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
+                                          RewriterBase &rewriter);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 4d8d93f..c091841 100644
--- a/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -384,7 +384,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// LoopFuseSibling
+// LoopFuseSiblingOp
 //===----------------------------------------------------------------------===//
 
 /// Check if `target` and `source` are siblings, in the context that `target`
@@ -408,7 +408,7 @@
   // Check if fusion will violate dominance.
   DominanceInfo domInfo(source);
   if (target->isBeforeInBlock(source)) {
-    // Since, `target` is before `source`, all users of results of `target`
+    // Since `target` is before `source`, all users of results of `target`
     // need to be dominated by `source`.
     for (Operation *user : target->getUsers()) {
       if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
@@ -424,9 +424,8 @@
     // Check if operands of `target` are dominated by `source`.
     for (Value operand : target->getOperands()) {
       Operation *operandOp = operand.getDefiningOp();
-      // If operand does not have a defining operation, it is a block arguement,
-      // which will always dominate `source`, since `target` and `source` are in
-      // the same block and the operand dominated `source` before.
+      // Operands without defining operations are block arguments. When `target`
+      // and `source` occur in the same block, these operands dominate `source`.
       if (!operandOp)
         continue;
 
@@ -441,8 +440,11 @@
     bool failed = false;
     OpOperand *failedValue = nullptr;
     visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
-      if (!domInfo.properlyDominates(operand->getOwner(), source,
-                                     /*enclosingOpOk=*/false)) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not an argument of an enclosing block and the defining
+        // op of `operand` is outside `target` but does not dominate `source`.
         failed = true;
         failedValue = operand;
       }
@@ -457,12 +459,11 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Check if `target` can be fused into `source`.
+/// Check if `target` scf.forall can be fused into `source` scf.forall.
 ///
-/// This is a simple check that just checks if both loops have same
-/// bounds, steps and mapping. This check does not ensure that the side effects
-/// of `target` are independent of `source` or vice-versa. It is the
-/// responsibility of the caller to ensure that.
+/// This simply checks if both loops have the same bounds, steps and mapping.
+/// No attempt is made at checking that the side effects of `target` and
+/// `source` are independent of each other.
 static bool isForallWithIdenticalConfiguration(Operation *target,
                                                Operation *source) {
   auto targetOp = dyn_cast<scf::ForallOp>(target);
@@ -476,21 +477,27 @@
          targetOp.getMapping() == sourceOp.getMapping();
 }
 
-/// Fuse `target` into `source` assuming they are siblings and indepndent.
-/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
-static Operation *fuseSiblings(Operation *target, Operation *source,
-                               RewriterBase &rewriter) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
+/// Check if `target` scf.for can be fused into `source` scf.for.
+///
+/// This simply checks if both loops have the same bounds and steps. No attempt
+/// is made at checking that the side effects of `target` and `source` are
+/// independent of each other.
+static bool isForWithIdenticalConfiguration(Operation *target,
+                                            Operation *source) {
+  auto targetOp = dyn_cast<scf::ForOp>(target);
+  auto sourceOp = dyn_cast<scf::ForOp>(source);
   if (!targetOp || !sourceOp)
-    return nullptr;
-  return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+    return false;
+
+  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+         targetOp.getStep() == sourceOp.getStep();
 }
 
 DiagnosedSilenceableFailure
-transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
-                                  transform::TransformResults &results,
-                                  transform::TransformState &state) {
+transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
   auto targetOps = state.getPayloadOps(getTarget());
   auto sourceOps = state.getPayloadOps(getSource());
 
@@ -510,13 +517,18 @@
   if (!diag.succeeded())
     return diag;
 
-  // Check if the target can be fused into source.
-  if (!isForallWithIdenticalConfiguration(target, source)) {
+  Operation *fusedLoop;
+  // TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
+  if (isForWithIdenticalConfiguration(target, source)) {
+    fusedLoop = fuseIndependentSiblingForLoops(
+        cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
+  } else if (isForallWithIdenticalConfiguration(target, source)) {
+    fusedLoop = fuseIndependentSiblingForallLoops(
+        cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+  } else
     return emitSilenceableFailure(target->getLoc())
            << "operations cannot be fused";
-  }
 
-  Operation *fusedLoop = fuseSiblings(target, source, rewriter);
   assert(fusedLoop && "failed to fuse operations");
 
   results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
diff --git a/lib/Dialect/SCF/Utils/Utils.cpp b/lib/Dialect/SCF/Utils/Utils.cpp
index 502d7e1..914aeb4 100644
--- a/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/lib/Dialect/SCF/Utils/Utils.cpp
@@ -910,61 +910,98 @@
   unsigned numTargetOuts = target.getNumResults();
   unsigned numSourceOuts = source.getNumResults();
 
-  OperandRange targetOuts = target.getOutputs();
-  OperandRange sourceOuts = source.getOutputs();
-
   // Create fused shared_outs.
   SmallVector<Value> fusedOuts;
-  fusedOuts.reserve(numTargetOuts + numSourceOuts);
-  fusedOuts.append(targetOuts.begin(), targetOuts.end());
-  fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+  llvm::append_range(fusedOuts, target.getOutputs());
+  llvm::append_range(fusedOuts, source.getOutputs());
 
-  // Create a new scf::forall op after the source loop.
+  // Create a new scf.forall op after the source loop.
   rewriter.setInsertionPointAfter(source);
   scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
       source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
       source.getMixedStep(), fusedOuts, source.getMapping());
 
   // Map control operands.
-  IRMapping fusedMapping;
-  fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+  IRMapping mapping;
+  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
 
   // Map shared outs.
-  fusedMapping.map(target.getRegionIterArgs(),
-                   fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
-  fusedMapping.map(
-      source.getRegionIterArgs(),
-      fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
   for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
 
   // Fuse the old terminator in_parallel ops into the new one.
   scf::InParallelOp targetTerm = target.getTerminator();
   scf::InParallelOp sourceTerm = source.getTerminator();
   scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-
   rewriter.setInsertionPointToStart(fusedTerm.getBody());
   for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
 
-  // Replace all uses of the old loops with the fused loop.
-  rewriter.replaceAllUsesWith(target.getResults(),
-                              fusedLoop.getResults().slice(0, numTargetOuts));
-  rewriter.replaceAllUsesWith(
-      source.getResults(),
-      fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
+  // Replace the old loops by substituting their uses with results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
 
-  // Erase the old loops.
-  rewriter.eraseOp(target);
-  rewriter.eraseOp(source);
+  return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  unsigned numTargetOuts = target.getNumResults();
+  unsigned numSourceOuts = source.getNumResults();
+
+  // Create fused init_args, with target's init_args before source's init_args.
+  SmallVector<Value> fusedInitArgs;
+  llvm::append_range(fusedInitArgs, target.getInitArgs());
+  llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+  // Create a new scf.for op after the source loop. If init_args is empty, the
+  // builder inserts an scf.yield terminator (without arguments) automatically.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  // Map original induction variables and operands to those of the fused loop.
+  IRMapping mapping;
+  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+  // Merge target's body into the new (fused) for loop and then source's body.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : target.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op : source.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+
+  // Build fused yield results by appropriately mapping original yield operands.
+  SmallVector<Value> yieldResults;
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  if (!yieldResults.empty())
+    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+  // Replace the old loops by substituting their uses with results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
 
   return fusedLoop;
 }
diff --git a/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index faaa2db..0f51b1c 100644
--- a/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -1,14 +1,113 @@
 // RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s --check-prefix CHECK-NOCLEANUP
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_1st_for_into_2nd(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IA:%.*]] = [[A]], [[IB:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IA]][[[IV]]]
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB]][[[IV]]]
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB:%.*]] = [[B]], [[IA:%.*]] = [[A]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB]][[[IV]]]
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IA]][[[IV]]]
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+  // NB: the dominance check used to fail on the following line. However, the
+  // defining op for the value of %arg3 occurs above the source loop, and %arg4
+  // is a block argument of the scope of the loops, so both are safe.
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_1st_forall_into_2nd([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_1st_forall_into_2nd(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   [[OUT2:%.*]] = linalg.matmul
   // CHECK:   scf.forall.in_parallel {
@@ -16,68 +115,11 @@
   // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   }
   // CHECK: }
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
-    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
-
-    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
-  %zero = arith.constant 0.0 : f32
-  %out_alloc = tensor.empty() : tensor<128x128xf32>
-  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
-
-  // expected-error @below {{user of results of target should be properly dominated by source}}
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-
-  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
-    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
-
-    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-
-    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
-  %zero = arith.constant 0.0 : f32
-  %out_alloc = tensor.empty() : tensor<128x128xf32>
-  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
-
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  // expected-error @below {{values used inside regions of target should be properly dominated by source}}
-  %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-
-  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
-}
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -94,18 +136,185 @@
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  // expected-error @below {{operands of target should be properly dominated by source}}
-  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT2:%.*]] = linalg.matmul
+  // CHECK:   scf.forall.in_parallel {
+  // CHECK:     tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   }
+  // CHECK: }
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
 
+    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-NOCLEANUP: func.func @fuse_no_iter_args([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_no_iter_args(%A: tensor<128xf32>, %B: tensor<128xf32>) {
+  // CHECK-NOCLEANUP: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-NOCLEANUP: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-NOCLEANUP: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-NOCLEANUP: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK-NOCLEANUP: scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] {{.*}}
+  scf.for %arg0 = %c0 to %c128 step %c16 {
+  // CHECK-NOCLEANUP:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+    %2 = vector.transfer_read %A[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    scf.yield
+  }
+  scf.for %arg0 = %c0 to %c128 step %c16 {
+  // CHECK-NOCLEANUP:   [[BSLICE:%.*]] = vector.transfer_read [[B]][[[IV]]], [[ZERO]]
+    %dup2 = vector.transfer_read %B[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    scf.yield
+  }
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // expected-error @below {{user of results of target should be properly dominated by source}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %1) -> (tensor<128xf32>) {
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @source_forall_uses_result_of_target_forall_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // expected-error @below {{user of results of target should be properly dominated by source}}
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+    %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @target_forall_depends_on_value_not_dominated_by_source_forall_err(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %buf1_alloc = tensor.empty() : tensor<128x128xf32>
+  %buf1 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out_alloc2 = tensor.empty() : tensor<128x128xf32>
+  %buf2 = linalg.fill ins(%zero : f32) outs(%out_alloc2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  // expected-error @below {{operands of target should be properly dominated by source}}
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)