[mlir][shard,mpi] Allowing 2D grids and simplifying the lowering of shard.all_gather (#180243)

- fixing an incorrect assertion and renaming the related function
- removing the Pure trait from mpi.comm_split, since MPI_Comm_split is not pure
- simplifying/standardizing the permutation in the all_gather lowering
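
The new lowering gathers into a contiguous {nRanks, dim0, ...} buffer via mpi.allgather, asserts at runtime that the communicator size matches the output/input ratio along the gather axis, and then, for gather_axis != 0, reorders the data with linalg.transpose plus tensor.collapse_shape instead of an affine copy loop nest. A condensed sketch of the generated IR for a 3x4 shard gathered along axis 1 on a 5-rank communicator (mirroring the updated convert-shard-to-mpi.mlir test; SSA names are illustrative and constants such as %c5 and the comm_split color/key are elided):

  %in   = bufferization.to_buffer %arg0 : tensor<3x4xf32> to memref<3x4xf32>
  %comm = mpi.comm_split(%world, %color, %key) : !mpi.comm
  %sz   = mpi.comm_size(%comm) : i32
  %szi  = arith.index_cast %sz : i32 to index
  %ok   = arith.cmpi eq, %szi, %c5 : index
  cf.assert %ok, "communicator size must match output/input ratio on the gather axis"
  %buf  = memref.alloc() : memref<5x3x4xf32>
  mpi.allgather(%in, %buf, %comm) : memref<3x4xf32>, memref<5x3x4xf32>
  %gath = bufferization.to_tensor %buf restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
  %init = tensor.empty() : tensor<3x5x4xf32>
  %perm = linalg.transpose ins(%gath : tensor<5x3x4xf32>) outs(%init : tensor<3x5x4xf32>) permutation = [1, 0, 2]
  %res  = tensor.collapse_shape %perm [[0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>

For gather_axis == 0 the transpose is unnecessary and the buffer is collapsed directly from {nRanks, dim0, ...} to {nRanks*dim0, ...}.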

---------

Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d9e47ea..7e68b15 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -103,7 +103,7 @@
 // CommSplitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
+def MPI_CommSplitOp : MPI_Op<"comm_split"> {
   let summary = "Partition the group associated with the given communicator into "
                 "disjoint subgroups";
   let description = [{
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index c765ad5..1db14e6 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -16,7 +16,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -624,9 +624,8 @@
   // shard.allgather concatenates along a specified gather-axis.
   // mpi.allgather always concatenates along the first dimension and
   // there is no MPI operation that allows gathering along an arbitrary axis.
-  // Hence, if gather-axis!=0, we need to create a temporary buffer
-  // where we gather along the first dimension and then copy from that
-  // buffer to the final output along the specified gather-axis.
+  // Hence, if gather-axis != 0, we need to permute the output buffer
+  // accordingly.
 
   LogicalResult
   matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
@@ -635,104 +634,124 @@
     FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
     if (failed(gridOp))
       return failure();
-    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
-    Value input = getAsMemref(adaptor.getInput(), iBuilder);
+
+    ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+    Value input = getAsMemref(adaptor.getInput(), ib);
     MemRefType inType = cast<MemRefType>(input.getType());
-    if (!memref::isStaticShapeAndContiguousRowMajor(inType))
-      return op.emitError(
-          "Expected static shaped memref in contiguous row-major layout.");
     MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
-    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
-      return op.emitError(
-          "Expected static shaped memref in contiguous row-major layout.");
+    auto inputShape = inType.getShape();
+    auto outputShape = outType.getShape();
     int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
-    auto ctx = op->getContext();
+    int64_t inputDimOnAxis = inputShape[gatherAxis];
+    int64_t outputDimOnAxis = outputShape[gatherAxis];
 
-    // Get the right communicator
-    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
-
-    Value nRanks =
-        mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
-            .getSize();
-    nRanks =
-        arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
-
-    Value tmpOutput, gatherDimSz;
-    if (gatherAxis == 0) {
-      tmpOutput = memref::AllocOp::create(iBuilder, outType);
-    } else {
-      // MPI's allgather always concatenates along the first dimension.
-      // Create a memref type for the output buffer with adjusted (expanded)
-      // shape.
-      SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
-      llvm::append_range(gatherShape, outType.getShape());
-      gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
-      MemRefType gatherType =
-          MemRefType::get(gatherShape, outType.getElementType());
-      gatherDimSz = arith::ConstantIndexOp::create(
-          iBuilder, outType.getDimSize(gatherAxis));
-      gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
-                                           gatherDimSz, nRanks);
-      // Allocate output buffer
-      tmpOutput =
-          memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
+    for (size_t i = 0; i < outputShape.size(); ++i)
+      if (outputShape[i] != inputShape[i] && i != (size_t)gatherAxis)
+        return op.emitError(
+            "Result and input shapes must match along non-gather axes.");
+    if (inputDimOnAxis == 0)
+      return op.emitError("Input size along the gather axis must be non-zero.");
+    if (inputDimOnAxis == 1) {
+      assert(outputDimOnAxis == inputDimOnAxis);
+      rewriter.replaceOp(op, adaptor.getInput());
+      return success();
     }
+    if (outputDimOnAxis % inputDimOnAxis != 0)
+      return op.emitError("Result size along the gather axis must be an exact "
+                          "multiple of the input size along the gather axis.");
+
+    if (!memref::isStaticShapeAndContiguousRowMajor(inType) ||
+        !memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError("Input/result must be statically shaped memrefs in "
+                          "contiguous row-major layout.");
+
+    // Get the right communicator.
+    Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
+    Value nRanksV =
+        mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
+    nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
+    int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
+    Value nRanksC = arith::ConstantIndexOp::create(ib, nRanks);
+    Value notError =
+        arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
+    cf::AssertOp::create(ib, notError,
+                         "Expected number of ranks in the communicator to "
+                         "match the output size along the gather axis divided "
+                         "by the input size along the gather axis.");
+
+    // mpi.allgather always concatenates along the first dimension, so
+    // allocate an output buffer of shape {nRanks, dim0, ...}.
+    SmallVector<int64_t> gatherShape;
+    gatherShape.emplace_back(nRanks);
+    gatherShape.append(inputShape.begin(), inputShape.end());
+    auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
+    Value finalOutput = memref::AllocOp::create(ib, gatherType);
     // Create the MPI AllGather operation.
-    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
+    mpi::AllGatherOp::create(ib, TypeRange(), input, finalOutput, comm);
 
-    // If gather-axis!=0, copy from gathered buffer to output with the right
-    // layout.
-    Value finalOutput = tmpOutput;
-    if (gatherAxis != 0) {
-      int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
-      assert(nSrcDims == outType.getRank() + 1 &&
-             "Expected gathered type to have rank one more than output type.");
+    if (gatherAxis == 0) {
+      // If gather axis == 0, simply collapse the first 2 dims from {nRanks,
+      // dim0, ...} to {nRanks*dim0, ...}.
+      SmallVector<ReassociationIndices> reassociation;
+      reassociation.push_back({0, 1});
+      int64_t numGatherDims = gatherShape.size();
+      for (int64_t i = 2; i < numGatherDims; ++i)
+        reassociation.push_back({i});
+      finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
+                                                    reassociation);
 
-      // Create affine map for copying from gathered buffer to output.
-      SmallVector<AffineExpr> dims;
-      dims.reserve(nSrcDims);
-      for (unsigned i = 0; i < nSrcDims; ++i)
-        dims.emplace_back(getAffineDimExpr(i, ctx));
-      AffineExpr s = getAffineSymbolExpr(0, ctx);
-      SmallVector<AffineExpr> results;
-      results.reserve(nSrcDims);
-      for (unsigned i = 0; i < nSrcDims - 1; ++i) {
-        if (i == gatherAxis)
-          results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
-        else
-          results.emplace_back(dims[i + 1]);
+      // If the op's result is a tensor, convert the collapsed memref to a tensor.
+      if (isa<RankedTensorType>(op.getType()))
+        finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
+                                                        finalOutput, true);
+    } else {
+      // 1. Enter tensor-land.
+      auto inType =
+          RankedTensorType::get(gatherShape, outType.getElementType());
+      finalOutput =
+          bufferization::ToTensorOp::create(ib, inType, finalOutput, true);
+
+      // 2. Permute the gathered tensor from {nRanks, dim0, ..., dimGatherAxis, ...}
+      // to {dim0, ..., nRanks, dimGatherAxis, ...}, moving nRanks next to the gather axis.
+      SmallVector<int64_t> outShapePermuted, permutation;
+      for (int i = 1; i <= gatherAxis; ++i) {
+        outShapePermuted.emplace_back(gatherShape[i]);
+        permutation.emplace_back(i);
       }
-      auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
+      outShapePermuted.emplace_back(gatherShape[0]);
+      permutation.emplace_back(0);
+      for (size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
+        outShapePermuted.emplace_back(gatherShape[i]);
+        permutation.emplace_back(i);
+      }
+      Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
+                                                 outType.getElementType());
+      finalOutput =
+          linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
+              ->getResult(0);
 
-      finalOutput = memref::AllocOp::create(iBuilder, outType);
+      // 3. Collapse the buffer from {dim0, ..., nRanks, dimGatherAxis, ...}
+      // to {dim0, ..., nRanks*dimGatherAxis, ...}.
+      SmallVector<ReassociationIndices> reassociation;
+      for (int64_t i = 0; i < gatherAxis; ++i) {
+        reassociation.push_back({i});
+      }
+      reassociation.push_back({gatherAxis, gatherAxis + 1});
+      for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
+           ++i) {
+        reassociation.push_back({i});
+      }
+      auto outTType =
+          RankedTensorType::get(outputShape, outType.getElementType());
+      finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
+                                                    reassociation);
 
-      // Now build a loop nest to copy from gathered buffer to finalOutput
-      // It would be nicer to just use a memref.transpose/collapse_shape op but
-      // these currently only support simpler cases.
-      Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
-      SmallVector<Value> lbs(nSrcDims, zero);
-      SmallVector<Value> ubs;
-      for (int64_t d = 0; d < nSrcDims; ++d)
-        ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
-      SmallVector<int64_t> steps(nSrcDims, 1);
-      auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
-        Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs);
-        // set symbol value
-        SmallVector<Value> ivss(ivs.begin(), ivs.end());
-        ivss.emplace_back(gatherDimSz);
-        affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap,
-                                      ivss);
-      };
-      affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
-                                  emitCopy);
-
-      memref::DeallocOp::create(iBuilder, tmpOutput);
+      // 4. Cast back to memref if needed.
+      if (isa<MemRefType>(op.getType()))
+        finalOutput =
+            bufferization::ToBufferOp::create(ib, outType, finalOutput);
     }
 
-    // If the destination is a tensor, cast it to a tensor
-    if (isa<RankedTensorType>(op.getType()))
-      finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
-                                                      finalOutput, true);
     rewriter.replaceOp(op, finalOutput);
     return success();
   }
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 62dc8f5..e619c70 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -436,24 +436,32 @@
                          targetSharding);
 }
 
-// Handles only resharding on a 1D shard.
-// Currently the sharded tensor axes must be exactly divisible by the single
-// grid axis size.
+// In most cases a sharded tensor axis must be exactly divisible by the size of
+// the grid axis it is split on. Only halo-size changes handle non-divisible cases.
 static TypedValue<ShapedType>
-reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
-                const Sharding &sourceSharding, const Sharding &targetSharding,
-                TypedValue<ShapedType> sourceUnshardedValue,
-                TypedValue<ShapedType> sourceShard) {
+reshard(ImplicitLocOpBuilder &builder, GridOp grid,
+        const Sharding &sourceSharding, const Sharding &targetSharding,
+        TypedValue<ShapedType> sourceUnshardedValue,
+        TypedValue<ShapedType> sourceShard) {
+  // If source and destination sharding are the same, no need to do anything.
+  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
+                                           isFullReplication(targetSharding))) {
+    return sourceShard;
+  }
+
+  // Tries to handle the case where the resharding is needed because the halo
+  // sizes are different. Supports arbitrary grid dimensionality.
+  if (auto tryRes = tryUpdateHaloInResharding(
+          builder, grid, sourceSharding, targetSharding,
+          sourceUnshardedValue.getType(), sourceShard)) {
+    return std::get<0>(tryRes.value()); // targetShard
+  }
+
   assert(sourceShard.getType() ==
          shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
   [[maybe_unused]] ShapedType targetShardType =
       shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
   assert(sourceShard.getType().getRank() == targetShardType.getRank());
-  assert(grid.getRank() == 1 && "Only 1D grides are currently supported.");
-
-  if (sourceSharding == targetSharding) {
-    return sourceShard;
-  }
 
   TypedValue<ShapedType> targetShard;
   Sharding actualTargetSharding;
@@ -475,38 +483,13 @@
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
     }
   }
+
   assert(targetShard && "Did not find any pattern to apply.");
   assert(actualTargetSharding == targetSharding);
   assert(targetShard.getType() == targetShardType);
   return targetShard;
 }
 
-static TypedValue<ShapedType>
-reshard(ImplicitLocOpBuilder &builder, GridOp grid,
-        const Sharding &sourceSharding, const Sharding &targetSharding,
-        TypedValue<ShapedType> sourceUnshardedValue,
-        TypedValue<ShapedType> sourceShard) {
-  // If source and destination sharding are the same, no need to do anything.
-  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
-                                           isFullReplication(targetSharding))) {
-    return sourceShard;
-  }
-
-  // Tries to handle the case where the resharding is needed because the halo
-  // sizes are different. Supports arbitrary grid dimensionality.
-  if (auto tryRes = tryUpdateHaloInResharding(
-          builder, grid, sourceSharding, targetSharding,
-          sourceUnshardedValue.getType(), sourceShard)) {
-    return std::get<0>(tryRes.value()); // targetShard
-  }
-
-  // Resort to handling only 1D grids since the general case is complicated if
-  // it needs to be communication efficient in terms of minimizing the data
-  // transfered between devices.
-  return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
-                         sourceUnshardedValue, sourceShard);
-}
-
 TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue) {
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 4ac4a69..6161c13 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -148,6 +148,28 @@
     return %0 : memref<3x4xf64>
   }
 
+  // CHECK-LABEL: func @allgather_tensor_0
+  // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+  func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {
+    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+    // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+    // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
+    // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
+    // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
+    // CHECK: cf.assert [[v3]]
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x3x4xf32>
+    // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<4x3x4xf32>
+    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1], [2]] : memref<4x3x4xf32> into memref<12x4xf32>
+    // CHECK: [[v4:%.*]] = bufferization.to_tensor [[vcollapse_shape]] restrict : memref<12x4xf32> to tensor<12x4xf32>
+    %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 0 : tensor<3x4xf32> -> tensor<12x4xf32>
+    // CHECK: return [[v4]] : tensor<12x4xf32>
+    return %0 : tensor<12x4xf32>
+  }
+
   // CHECK-LABEL: func @allgather_tensor
   func.func @allgather_tensor(
       // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
@@ -155,28 +177,22 @@
       %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
-    // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
+    // CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
     // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
     // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
     // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
     // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
-    // CHECK: [[v3:%.*]] = arith.divsi [[vc20]], [[v2]] : index
-    // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x3x?xf32>
-    // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
-    // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
-    // CHECK: affine.for [[varg1:%.*]] = 0 to [[v2]] {
-      // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
-        // CHECK: affine.for [[varg3:%.*]] = 0 to [[v3]] {
-          // CHECK: [[v5:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
-          // CHECK: affine.store [[v5]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v3]]) + [[varg3]]] : memref<3x20xf32>
-        // CHECK: }
-      // CHECK: }
-    // CHECK: }
-    // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
-    // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<3x20xf32> to tensor<3x20xf32>
+    // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc5]] : index
+    // CHECK: cf.assert [[v3]]
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
+    // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
+    // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
+    // CHECK: [[v5:%.*]] = tensor.empty() : tensor<3x5x4xf32>
+    // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<5x3x4xf32>) outs([[v5]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
+    // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
     %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32>
-    // CHECK: return [[v4]] : tensor<3x20xf32>
+    // CHECK: return [[vcollapsed]] : tensor<3x20xf32>
     return %0 : tensor<3x20xf32>
   }
 
@@ -185,28 +201,24 @@
       // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
       // CHECK-SAME: -> memref<3x20xf32>
       %arg0 : memref<3x4xf32>) -> memref<3x20xf32> {
-    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
-    // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
+    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+    // CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
     // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
     // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
     // CHECK: [[v1:%.*]] = arith.index_cast [[vsize]] : i32 to index
-    // CHECK: [[v2:%.*]] = arith.divsi [[vc20]], [[v1]] : index
-    // CHECK: [[valloc:%.*]] = memref.alloc([[v1]], [[v2]]) : memref<?x3x?xf32>
-    // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
-    // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
-    // CHECK: affine.for [[varg1:%.*]] = 0 to [[v1]] {
-      // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
-        // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
-          // CHECK: [[v3:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
-          // CHECK: affine.store [[v3]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v2]]) + [[varg3]]] : memref<3x20xf32>
-        // CHECK: }
-      // CHECK: }
-    // CHECK: }
-    // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
+    // CHECK: [[v2:%.*]] = arith.cmpi eq, [[v1]], [[vc5]] : index
+    // CHECK: cf.assert [[v2]]
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
+    // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
+    // CHECK: [[v3:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
+    // CHECK: [[v4:%.*]] = tensor.empty() : tensor<3x5x4xf32>
+    // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v3]] : tensor<5x3x4xf32>) outs([[v4]] : tensor<3x5x4xf32>) permutation = [1, 0, 2] 
+    // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
+    // CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] : tensor<3x20xf32> to memref<3x20xf32>
     %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
-    // CHECK: return [[valloc_0]] : memref<3x20xf32>
+    // CHECK: return [[v5]] : memref<3x20xf32>
     return %0 : memref<3x20xf32>
   }
 }
@@ -377,9 +389,9 @@
 // CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
 func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) {
   %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding
-  // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
-  // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
-  // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
   // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
   // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
   // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@@ -397,10 +409,10 @@
 // CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
 func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) {
   %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding
-  // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
-  // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
-  // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
-  // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
+  // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
   // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
   // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
   // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@@ -417,12 +429,12 @@
 // CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
 func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) {
   %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding
-  // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
-  // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
-  // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
-  // CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
-  // CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
-  // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+  // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+  // CHECK-DAG: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+  // CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-DAG: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
   // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
   // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
   // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@@ -444,39 +456,38 @@
 // CHECK-LABEL: func.func @mlp_1dgrid(
 // CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
 func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %arg2: tensor<256x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
-  // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
   %cst = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: [[vc0:%.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
+  // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
   // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
   // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
-  // CHECK: [[vsize:%.*]] = mpi.comm_size
+  // CHECK: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
   // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
-  // CHECK: [[v3:%.*]] = arith.divsi
-  // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x512x?xf32>
-  // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<?x512x?xf32>
-  // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
-  // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
-    // CHECK: affine.for [[varg4:%.*]] = 0 to 512 {
-      // CHECK: affine.for [[varg5:%.*]] = 0 to [[v3]] {
-        // CHECK: [[v19:%.*]] = memref.load [[valloc]][[[varg3]], [[varg4]], [[varg5]]] : memref<?x512x?xf32>
-        // CHECK: affine.store [[v19]], [[valloc_0]][[[varg4]], [[varg3]] * symbol([[v3]]) + [[varg5]]] : memref<512x2048xf32>
-  // CHECK: memref.dealloc [[valloc]] : memref<?x512x?xf32>
-  // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+  // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
+  // CHECK: cf.assert [[v3]]
+  // CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x512x512xf32>
+  // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<4x512x512xf32>
+  // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<4x512x512xf32> to tensor<4x512x512xf32>
+  // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x4x512xf32>
+  // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<4x512x512xf32>) outs([[v5]] : tensor<512x4x512xf32>) permutation = [1, 0, 2] 
+  // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<512x4x512xf32> into tensor<512x2048xf32>
   %all_gather = shard.all_gather %arg0 on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
-  // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x256xf32>
+  // CHECK: [[v6:%.*]] = tensor.empty() : tensor<512x256xf32>
   %0 = tensor.empty() : tensor<512x256xf32>
-  // CHECK: [[v6:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v5]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+  // CHECK: [[v7:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x256xf32>) -> tensor<512x256xf32>
-  // CHECK: [[v7:%.*]] = linalg.matmul ins([[v4]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+  // CHECK: [[v8:%.*]] = linalg.matmul ins([[vcollapsed]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v7]] : tensor<512x256xf32>) -> tensor<512x256xf32>
   %2 = linalg.matmul ins(%all_gather, %arg1 : tensor<512x2048xf32>, tensor<2048x256xf32>) outs(%1 : tensor<512x256xf32>) -> tensor<512x256xf32>
-  // CHECK: [[v8:%.*]] = tosa.sigmoid [[v7]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
+  // CHECK: [[v9:%.*]] = tosa.sigmoid [[v8]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
   %3 = tosa.sigmoid %2 : (tensor<512x256xf32>) -> tensor<512x256xf32>
   %4 = tensor.empty() : tensor<512x2048xf32>
   %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
   %proc_linear_idx = shard.process_multi_index on @grid_1d_4 axes = [0] : index
   %grid_shape = shard.grid_shape @grid_1d_4 axes = [0] : index
   %6 = arith.cmpi eq, %proc_linear_idx, %c0 : index
-  // CHECK: [[v14:%.*]] = scf.if
+  // CHECK: [[v15:%.*]] = scf.if
   %7 = scf.if %6 -> (tensor<512x2048xf32>) {
     scf.yield %5 : tensor<512x2048xf32>
   } else {
@@ -484,15 +495,15 @@
     %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
     scf.yield %10 : tensor<512x2048xf32>
   }
-  // CHECK: [[v15:%.*]] = linalg.matmul ins([[v8]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v14]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  // CHECK: [[v16:%.*]] = linalg.matmul ins([[v9]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v15]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
   %8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
-  // CHECK: [[v16:%.*]] = bufferization.to_buffer
-  // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<512x2048xf32>
-  // CHECK: linalg.copy ins([[v16]] : memref<512x2048xf32>) outs([[valloc_1]] : memref<512x2048xf32>)
-  // CHECK: [[v17:%.*]] = mpi.comm_world : !mpi.comm
-  // CHECK: mpi.allreduce([[valloc_1]], [[valloc_1]], MPI_SUM, [[v17]]) : memref<512x2048xf32>, memref<512x2048xf32>
-  // CHECK: [[v18:%.*]] = bufferization.to_tensor [[valloc_1]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+  // CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] : tensor<512x2048xf32> to memref<512x2048xf32>
+  // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
+  // CHECK: linalg.copy ins([[v17]] : memref<512x2048xf32>) outs([[valloc_0]] : memref<512x2048xf32>)
+  // CHECK: [[v18:%.*]] = mpi.comm_world : !mpi.comm
+  // CHECK: mpi.allreduce([[valloc_0]], [[valloc_0]], MPI_SUM, [[v18]]) : memref<512x2048xf32>, memref<512x2048xf32>
+  // CHECK: [[v19:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
   %all_reduce = shard.all_reduce %8 on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
-  // CHECK: return [[v18]] : tensor<512x2048xf32>
+  // CHECK: return [[v19]] : tensor<512x2048xf32>
   return %all_reduce : tensor<512x2048xf32>
 }
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index cd9fa22..4c8271a 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -4,6 +4,7 @@
 
 shard.grid @grid_1d(shape = 2)
 shard.grid @grid_1d_4(shape = 4)
+shard.grid @grid_2d_16(shape = 4x4)
 
 // CHECK-LABEL: func @return_sharding
 func.func @return_sharding(
@@ -318,9 +319,9 @@
   return %sharded_ret : tensor<6xi32>
 }
 
-// CHECK-LABEL: func.func @mlp_1dgrid
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
 // CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
-func.func @mlp_1dgrid(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
+func.func @mlp_1d_weight_stationary(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
   // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
   %sharding = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
   %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding