[MLIR][Vector] Add warp distribution for `scf.if` (#157119)

This PR adds `scf.if` op distribution to the existing `VectorDistribute`
patterns. The logic mostly follows that of `scf.for`: move op outside, wrap each
branch with `gpu.warp_execute_on_lane_0`. A notable difference to `scf.for` is
that each branch has its own set of escaping values, and `scf.if` itself does not
have block arguments.

GitOrigin-RevId: 7f007b572d63fe802a1f5587c74aae437f9f50e6
diff --git a/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c84eb2c..995a259 100644
--- a/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -371,6 +371,38 @@
   return targetType;
 }
 
+/// Given a warpOp that contains ops with regions, the corresponding op's
+/// "inner" region and the distributionMapFn, get all values used by the op's
+/// region that are defined within the warpOp, but outside the inner region.
+/// Return the set of values, their types and their distributed types.
+std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
+           SmallVector<Type>>
+getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
+                             DistributionMapFn distributionMapFn) {
+  llvm::SmallSetVector<Value, 32> escapingValues;
+  SmallVector<Type> escapingValueTypes;
+  SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
+  if (innerRegion.empty())
+    return {std::move(escapingValues), std::move(escapingValueTypes),
+            std::move(escapingValueDistTypes)};
+  mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
+    Operation *parent = operand->get().getParentRegion()->getParentOp();
+    if (warpOp->isAncestor(parent)) {
+      if (!escapingValues.insert(operand->get()))
+        return;
+      Type distType = operand->get().getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(operand->get());
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      escapingValueTypes.push_back(operand->get().getType());
+      escapingValueDistTypes.push_back(distType);
+    }
+  });
+  return {std::move(escapingValues), std::move(escapingValueTypes),
+          std::move(escapingValueDistTypes)};
+}
+
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
 /// will not be distributed (it should be less than the warp size).
@@ -1713,6 +1745,215 @@
   }
 };
 
+/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
+/// the scf.if is the last operation in the region so that it doesn't
+/// change the order of execution. This creates a new scf.if after the
+/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
+/// the "inner" WarpExecuteOnLane0Op. Example:
+/// ```
+/// gpu.warp_execute_on_lane_0(%laneid)[32] {
+///   %payload = ... : vector<32xindex>
+///   scf.if %pred {
+///     vector.store %payload, %buffer[%idx] : memref<128xindex>,
+///     vector<32xindex>
+///   }
+///   gpu.yield
+/// }
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
+///   %payload = ... : vector<32xindex>
+///   gpu.yield %payload : vector<32xindex>
+/// }
+/// scf.if %pred {
+///   gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
+///     ^bb0(%arg1: vector<32xindex>):
+///     vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
+///   }
+/// }
+/// ```
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
+    // Only pick up `IfOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+    if (!ifOp)
+      return failure();
+
+    // The current `WarpOp` can yield two types of values:
+    // 1. Not results of `IfOp`:
+    //     Preserve them in the new `WarpOp`.
+    //     Collect their yield index to remap the usages.
+    // 2. Results of `IfOp`:
+    //     They are not part of the new `WarpOp` results.
+    //     Map current warp's yield operand index to `IfOp` result idx.
+    SmallVector<Value> nonIfYieldValues;
+    SmallVector<unsigned> nonIfYieldIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+        nonIfYieldValues.push_back(yieldOperand.get());
+        nonIfYieldIndices.push_back(yieldOperandIdx);
+        continue;
+      }
+      OpResult ifResult = cast<OpResult>(yieldOperand.get());
+      const unsigned ifResultIdx = ifResult.getResultNumber();
+      ifResultMapping[yieldOperandIdx] = ifResultIdx;
+      // If this `ifOp` result is vector type and it is yielded by the
+      // `WarpOp`, we keep track the distributed type for this result.
+      if (!isa<VectorType>(ifResult.getType()))
+        continue;
+      VectorType distType =
+          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+      ifResultDistTypes[ifResultIdx] = distType;
+    }
+
+    // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
+    // them
+    auto [escapingValuesThen, escapingValueInputTypesThen,
+          escapingValueDistTypesThen] =
+        getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
+                                     distributionMapFn);
+    auto [escapingValuesElse, escapingValueInputTypesElse,
+          escapingValueDistTypesElse] =
+        getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
+                                     distributionMapFn);
+    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+        llvm::is_contained(escapingValueDistTypesElse, Type{}))
+      return failure();
+
+    // The new `WarpOp` groups yields values in following order:
+    // 1. Branch condition
+    // 2. Escaping values then branch
+    // 3. Escaping values else branch
+    // 4. All non-`ifOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
+    newWarpOpYieldValues.append(escapingValuesThen.begin(),
+                                escapingValuesThen.end());
+    newWarpOpYieldValues.append(escapingValuesElse.begin(),
+                                escapingValuesElse.end());
+    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
+    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
+                              escapingValueDistTypesThen.end());
+    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+                              escapingValueDistTypesElse.end());
+
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+    for (auto [idx, val] :
+         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(val);
+      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+    }
+    // Create the new `WarpOp` with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    // `ifOp` returns the result of the inner warp op.
+    SmallVector<Type> newIfOpDistResTypes;
+    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+      Type distType = cast<Value>(res).getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(cast<Value>(res));
+        // Fallback to affine map if the dist result was not previously recorded
+        distType = ifResultDistTypes.count(i)
+                       ? ifResultDistTypes[i]
+                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newIfOpDistResTypes.push_back(distType);
+    }
+    // Create a new `IfOp` outside the new `WarpOp` region.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newIfOp = scf::IfOp::create(
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
+        static_cast<bool>(ifOp.thenBlock()),
+        static_cast<bool>(ifOp.elseBlock()));
+    auto encloseRegionInWarpOp =
+        [&](Block *oldIfBranch, Block *newIfBranch,
+            llvm::SmallSetVector<Value, 32> &escapingValues,
+            SmallVector<Type> &escapingValueInputTypes,
+            size_t warpResRangeStart) {
+          OpBuilder::InsertionGuard g(rewriter);
+          if (!newIfBranch)
+            return;
+          rewriter.setInsertionPointToStart(newIfBranch);
+          llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
+          SmallVector<Value> innerWarpInputVals;
+          SmallVector<Type> innerWarpInputTypes;
+          for (size_t i = 0; i < escapingValues.size();
+               ++i, ++warpResRangeStart) {
+            innerWarpInputVals.push_back(
+                newWarpOp.getResult(warpResRangeStart));
+            escapeValToBlockArgIndex[escapingValues[i]] =
+                innerWarpInputTypes.size();
+            innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
+          }
+          auto innerWarp = WarpExecuteOnLane0Op::create(
+              rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
+              newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
+              innerWarpInputVals, innerWarpInputTypes);
+
+          innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
+          innerWarp.getWarpRegion().addArguments(
+              innerWarpInputTypes,
+              SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
+
+          SmallVector<Value> yieldOperands;
+          for (Value operand : oldIfBranch->getTerminator()->getOperands())
+            yieldOperands.push_back(operand);
+          rewriter.eraseOp(oldIfBranch->getTerminator());
+
+          rewriter.setInsertionPointToEnd(innerWarp.getBody());
+          gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
+          rewriter.setInsertionPointAfter(innerWarp);
+          scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
+
+          // Update any users of escaping values that were forwarded to the
+          // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
+          innerWarp.walk([&](Operation *op) {
+            for (OpOperand &operand : op->getOpOperands()) {
+              auto it = escapeValToBlockArgIndex.find(operand.get());
+              if (it == escapeValToBlockArgIndex.end())
+                continue;
+              operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+            }
+          });
+          mlir::vector::moveScalarUniformCode(innerWarp);
+        };
+    encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
+                          &newIfOp.getThenRegion().front(), escapingValuesThen,
+                          escapingValueInputTypesThen, 1);
+    if (!ifOp.getElseRegion().empty())
+      encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
+                            &newIfOp.getElseRegion().front(),
+                            escapingValuesElse, escapingValueInputTypesElse,
+                            1 + escapingValuesThen.size());
+    // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
+    // result.
+    for (auto [origIdx, newIdx] : ifResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newIfOp.getResult(newIdx), newIfOp);
+    // Similarly, update any users of the `WarpOp` results that were not
+    // results of the `IfOp`.
+    for (auto [origIdx, newIdx] : origToNewYieldIdx)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original `WarpOp` and `IfOp`, they should not have any uses
+    // at this point.
+    rewriter.eraseOp(ifOp);
+    rewriter.eraseOp(warpOp);
+    return success();
+  }
+
+private:
+  DistributionMapFn distributionMapFn;
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't
 /// change the order of execution. This creates a new scf.for region after the
@@ -1759,25 +2000,9 @@
       return failure();
     // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
     // Those Values need to be returned by the new warp op.
-    llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> escapingValueInputTypes;
-    SmallVector<Type> escapingValueDistTypes;
-    mlir::visitUsedValuesDefinedAbove(
-        forOp.getBodyRegion(), [&](OpOperand *operand) {
-          Operation *parent = operand->get().getParentRegion()->getParentOp();
-          if (warpOp->isAncestor(parent)) {
-            if (!escapingValues.insert(operand->get()))
-              return;
-            Type distType = operand->get().getType();
-            if (auto vecType = dyn_cast<VectorType>(distType)) {
-              AffineMap map = distributionMapFn(operand->get());
-              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-            }
-            escapingValueInputTypes.push_back(operand->get().getType());
-            escapingValueDistTypes.push_back(distType);
-          }
-        });
-
+    auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
+        getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
+                                     distributionMapFn);
     if (llvm::is_contained(escapingValueDistTypes, Type{}))
       return failure();
     // `WarpOp` can yield two types of values:
@@ -2068,6 +2293,8 @@
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
+  patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
+                              benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/test/Dialect/Vector/vector-warp-distribute.mlir b/test/Dialect/Vector/vector-warp-distribute.mlir
index 8750582..bb76392 100644
--- a/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1856,3 +1856,72 @@
 // CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
 // CHECK-PROP-NOT: vector.broadcast
 // CHECK-PROP: vector.step : vector<64xindex>
+
+// -----
+
+func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1)  {
+  %laneid = gpu.lane_id
+  %c0 = arith.constant 0 : index
+
+  gpu.warp_execute_on_lane_0(%laneid)[32] {
+    %seq = vector.step : vector<32xindex>
+    scf.if %pred {
+      vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
+    }
+    gpu.yield
+  }
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
+//       CHECK-PROP:   scf.if %[[ARG1]] {
+//       CHECK-PROP:   gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
+//       CHECK-PROP:   ^bb0(%[[ARG2:.+]]: vector<32xindex>):
+//       CHECK-PROP:   vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>
+
+// -----
+
+func.func @warp_scf_if_distribute(%pred : i1)  {
+  %laneid = gpu.lane_id
+  %c0 = arith.constant 0 : index
+
+  %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
+    %seq1 = vector.step : vector<32xindex>
+    %seq2 = arith.constant dense<2> : vector<32xindex>
+    %0 = scf.if %pred -> (vector<32xf32>) {
+      %1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
+      scf.yield %1 : vector<32xf32>
+    } else {
+      %2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
+      scf.yield %2 : vector<32xf32>
+    }
+    gpu.yield %0 : vector<32xf32>
+  }
+  "some_use"(%0) : (vector<1xf32>) -> ()
+
+  return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
+//  CHECK-PROP-SAME:    %[[ARG0:.+]]: i1
+//       CHECK-PROP:    %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
+//       CHECK-PROP:    %[[LANE_ID:.+]] = gpu.lane_id
+//       CHECK-PROP:    %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
+//       CHECK-PROP:    %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
+//       CHECK-PROP:    %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
+//       CHECK-PROP:        ^bb0(%[[ARG1:.+]]: vector<32xindex>):
+//       CHECK-PROP:        %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
+//       CHECK-PROP:        gpu.yield %{{.*}} : vector<32xf32>
+//       CHECK-PROP:      }
+//       CHECK-PROP:      scf.yield %[[THEN_DIST]] : vector<1xf32>
+//       CHECK-PROP:    } else {
+//       CHECK-PROP:      %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
+//       CHECK-PROP:        %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
+//       CHECK-PROP:        gpu.yield %{{.*}} : vector<32xf32>
+//       CHECK-PROP:      }
+//       CHECK-PROP:      scf.yield %[[ELSE_DIST]] : vector<1xf32>
+//       CHECK-PROP:    }
+//       CHECK-PROP:    "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
+//       CHECK-PROP:    return
+//       CHECK-PROP:  }
diff --git a/test/Dialect/XeGPU/subgroup-distribute.mlir b/test/Dialect/XeGPU/subgroup-distribute.mlir
index a39aa90..60acea0 100644
--- a/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -339,6 +339,63 @@
 }
 
 // -----
+// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}},
+// CHECK-SAME: %[[PREDICATE:.*]]: i1) {
+// CHECK: %[[DEFAULT:.*]] = arith.constant dense<1.200000e+01> : vector<8xf16>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[PREDICATED_LOAD:.*]] = scf.if %[[PREDICATE]] -> (vector<8xf16>) {
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: scf.yield %[[LOADED]] : vector<8xf16>
+// CHECK-NEXT: } else {
+// CHECK-NEXT:   scf.yield %[[DEFAULT]] : vector<8xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: xegpu.store %[[PREDICATED_LOAD]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>, %pred : i1) {
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+    %loaded = scf.if %pred -> (vector<16x8xf16>) {
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      scf.yield %3 : vector<16x8xf16>
+    } else {
+      %3 = arith.constant {
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      } dense<12.> : vector<16x8xf16>
+      scf.yield %3 : vector<16x8xf16>
+    } { layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> }
+    xegpu.store %loaded, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_scf_non_yield({{.*}}) {
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
+// CHECK: scf.if %[[PREDICATE]] {
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+// CHECK-NEXT: }
+gpu.module @test {
+  gpu.func @scatter_ops_scf_non_yield(%src: memref<256xf16>) {
+    %pred = llvm.mlir.poison : i1
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+    scf.if %pred  {
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    }
+    gpu.return
+  }
+}
+
+// -----
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
 // CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>