[NFC][mlir][shard] Unify MoveLastSplitAxisPattern/MoveLastSplitAxisPattern (#192295)

Made MoveLastSplitAxisPattern more general to also cover MoveLastSplitAxisPattern.
Less code, same functionality.
Assisted by claude.
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 57502fb..05b864d 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -271,106 +271,15 @@
   return srcShape.cloneWith(tgtShape, srcShape.getElementType());
 }
 
-/// Move a split axis between tensor dimensions:
-/// e.g. [[0], []] -> [[], [0]].
-class MoveSplitAxisPattern : public ReshardingPattern {
-  // Detect if the resharding moves a single split axis from one tensor
-  // dimension to another tensor dimension. If detected, returns the
-  // corresponding (tgt_tensor_dim, grid_axis) pair.
-  static std::optional<std::tuple<int64_t, GridAxis>>
-  detect(const Sharding &srcSharding, const Sharding &tgtSharding,
-         int64_t srcTensorDim) {
-    if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
-      return std::nullopt;
-    auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
-    if (srcAxes.size() != 1)
-      return std::nullopt;
-    for (size_t tgtTensorDim = 0;
-         tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
-      if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
-        continue;
-      auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
-      if (tgtAxes.size() != 1 || srcAxes.front() != tgtAxes.front())
-        continue;
-      return std::make_tuple(static_cast<int64_t>(tgtTensorDim),
-                             srcAxes.front());
-    }
-    return std::nullopt;
-  }
-
-  static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
-                              int64_t srcTensorDim, int64_t tgtTensorDim) {
-    SmallVector<GridAxesAttr> tgtShardingSplitAxes =
-        llvm::to_vector(srcSharding.getSplitAxes());
-    while (static_cast<int64_t>(tgtShardingSplitAxes.size()) <= tgtTensorDim) {
-      tgtShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
-    }
-
-    auto srcSplitAxes =
-        llvm::to_vector(tgtShardingSplitAxes[srcTensorDim].asArrayRef());
-    assert(srcSplitAxes.size() == 1);
-    auto gridAxis = srcSplitAxes.back();
-    srcSplitAxes.pop_back();
-    tgtShardingSplitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
-
-    auto tgtSplitAxes =
-        llvm::to_vector(tgtShardingSplitAxes[tgtTensorDim].asArrayRef());
-    tgtSplitAxes.push_back(gridAxis);
-    tgtShardingSplitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
-
-    return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
-  }
-
-  static std::tuple<TypedValue<ShapedType>, Sharding>
-  apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding,
-        ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
-        int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis gridAxis) {
-    MLIRContext *ctx = builder.getContext();
-    builder.setInsertionPointAfterValue(srcShard);
-
-    Sharding resultSharding =
-        tgtSharding(ctx, std::move(srcSharding), srcTensorDim, tgtTensorDim);
-    ShapedType a2aResultShape =
-        allToAllResultShape(srcShard.getType(), grid.getShape()[gridAxis],
-                            srcTensorDim, tgtTensorDim);
-    Value allToAllResult = AllToAllOp::create(
-        builder,
-        RankedTensorType::get(a2aResultShape.getShape(),
-                              a2aResultShape.getElementType()),
-        grid.getSymName(), SmallVector<GridAxis>({gridAxis}), srcShard,
-        APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
-    ShapedType tgtShape =
-        shardShapedType(srcUnshardedType, grid, resultSharding);
-    TypedValue<ShapedType> tgtShard =
-        tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
-    return {tgtShard, resultSharding};
-  }
-
-public:
-  std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
-  tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
-           const Sharding &srcSharding, const Sharding &tgtSharding,
-           ShapedType srcUnshardedType,
-           TypedValue<ShapedType> srcShard) override {
-    if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
-      return std::nullopt;
-    if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
-      auto [tgtTensorDim, gridAxis] = detectRes.value();
-      return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
-                   tensorDim, tgtTensorDim, gridAxis);
-    }
-    return std::nullopt;
-  }
-};
-
 /// Move the last split axis of one tensor dimension to the front of another
-/// tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
+/// tensor dimension's split axes, e.g. [[0], []] -> [[], [0]] or
+/// [[0, 1], [2]] -> [[0], [1, 2]].
 class MoveLastSplitAxisPattern : public ReshardingPattern {
   // Detect if the resharding moves the last grid axis of srcTensorDim to the
   // front of another tensor dimension's split axes. If detected, returns
   // (tgtTensorDim, movedGridAxis).
   //
-  // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an]  (n >= 2)
+  // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an]  (n >= 1)
   //          tgt[srcTensorDim] = [a1,...,a(n-1)]
   //          src[tgtTensorDim] = [b1,...,bm]          (m >= 0)
   //          tgt[tgtTensorDim] = [an, b1,...,bm]
@@ -380,8 +289,8 @@
     if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
       return std::nullopt;
     auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
-    // Need at least 2 axes to move the last one.
-    if (srcAxes.size() < 2)
+    // Need at least 1 axis to move.
+    if (srcAxes.empty())
       return std::nullopt;
 
     // After the move the source tensor dim should lose its last axis.
@@ -586,12 +495,11 @@
   // Each pattern's tryApply checks its own applicability preconditions.
   static UpdateHaloPattern updateHaloPattern;
   static MoveLastSplitAxisPattern moveLastSplitAxisPattern;
-  static MoveSplitAxisPattern moveSplitAxisPattern;
   static SplitLastAxisPattern splitLastAxisPattern;
   static UnsplitLastAxesPattern unsplitLastAxesPattern;
   static ReshardingPattern *patterns[] = {
-      &updateHaloPattern, &moveLastSplitAxisPattern, &moveSplitAxisPattern,
-      &splitLastAxisPattern, &unsplitLastAxesPattern};
+      &updateHaloPattern, &moveLastSplitAxisPattern, &splitLastAxisPattern,
+      &unsplitLastAxesPattern};
   TypedValue<ShapedType> currentShard = shardedSrc;
   Sharding currentSharding = srcSharding;
   for (int64_t dim = 0;