[mlir][shard,mpi] Allowing 2d-grids and simplifying the lowering of shard.all_gather (#180243)
- fixing an incorrect assertion and renaming the related function
- dropping the Pure trait from mpi.comm_split, since MPI_Comm_split has side effects
- simplifying/standardizing the permutation in the all_gather lowering (see the sketch below)
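
For reference, the gather_axis != 0 path now boils down to three shape manipulations: mpi.allgather writes into a buffer of shape {nRanks, input...}, linalg.transpose moves the nRanks dimension directly in front of the gather axis, and tensor.collapse_shape fuses the two adjacent dimensions. The Python sketch below (illustrative only, not part of the patch; all names are made up) reproduces that bookkeeping:

```python
# Illustrative sketch (not part of this patch): plain-Python version of the
# shape/permutation/reassociation bookkeeping done in the new lowering for
# gather_axis != 0. Function and variable names are hypothetical.
def all_gather_layout(input_shape, gather_axis, n_ranks):
    # mpi.allgather concatenates along dim 0: gathered buffer is {nRanks, input...}.
    gather_shape = [n_ranks, *input_shape]
    # linalg.transpose permutation: move dim 0 (nRanks) right before the gather axis.
    permutation = (list(range(1, gather_axis + 1)) + [0]
                   + list(range(gather_axis + 1, len(gather_shape))))
    # tensor.collapse_shape reassociation: fuse nRanks with the gather-axis dim.
    reassociation = [[i] for i in range(gather_axis)]
    reassociation.append([gather_axis, gather_axis + 1])
    reassociation += [[i] for i in range(gather_axis + 2, len(gather_shape))]
    return gather_shape, permutation, reassociation

# Mirrors the updated test case: tensor<3x4xf32>, gather_axis = 1, 5 ranks
# -> ([5, 3, 4], [1, 0, 2], [[0], [1, 2]])
print(all_gather_layout([3, 4], 1, 5))
```
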
---------
Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d9e47ea..7e68b15 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -103,7 +103,7 @@
// CommSplitOp
//===----------------------------------------------------------------------===//
-def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
+def MPI_CommSplitOp : MPI_Op<"comm_split"> {
let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index c765ad5..1db14e6 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -16,7 +16,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -624,9 +624,8 @@
// shard.allgather concatenates along a specified gather-axis.
// mpi.allgather always concatenates along the first dimension and
// there is no MPI operation that allows gathering along an arbitrary axis.
- // Hence, if gather-axis!=0, we need to create a temporary buffer
- // where we gather along the first dimension and then copy from that
- // buffer to the final output along the specified gather-axis.
+ // Hence, if gather-axis != 0, we need to permute the output buffer
+ // accordingly.
LogicalResult
matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
@@ -635,104 +634,124 @@
FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
if (failed(gridOp))
return failure();
- ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
- Value input = getAsMemref(adaptor.getInput(), iBuilder);
+
+ ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+ Value input = getAsMemref(adaptor.getInput(), ib);
MemRefType inType = cast<MemRefType>(input.getType());
- if (!memref::isStaticShapeAndContiguousRowMajor(inType))
- return op.emitError(
- "Expected static shaped memref in contiguous row-major layout.");
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
- if (!memref::isStaticShapeAndContiguousRowMajor(outType))
- return op.emitError(
- "Expected static shaped memref in contiguous row-major layout.");
+ auto inputShape = inType.getShape();
+ auto outputShape = outType.getShape();
int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
- auto ctx = op->getContext();
+ int64_t inputDimOnAxis = inputShape[gatherAxis];
+ int64_t outputDimOnAxis = outputShape[gatherAxis];
- // Get the right communicator
- Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
-
- Value nRanks =
- mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
- .getSize();
- nRanks =
- arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
-
- Value tmpOutput, gatherDimSz;
- if (gatherAxis == 0) {
- tmpOutput = memref::AllocOp::create(iBuilder, outType);
- } else {
- // MPI's allgather always concatenates along the first dimension.
- // Create a memref type for the output buffer with adjusted (expanded)
- // shape.
- SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
- llvm::append_range(gatherShape, outType.getShape());
- gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
- MemRefType gatherType =
- MemRefType::get(gatherShape, outType.getElementType());
- gatherDimSz = arith::ConstantIndexOp::create(
- iBuilder, outType.getDimSize(gatherAxis));
- gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
- gatherDimSz, nRanks);
- // Allocate output buffer
- tmpOutput =
- memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
+ for (size_t i = 0; i < outputShape.size(); ++i)
+ if (outputShape[i] != inputShape[i] && i != (size_t)gatherAxis)
+ return op.emitError(
+ "Result and input shapes must match along non-gather axes.");
+ if (inputDimOnAxis == 0)
+ return op.emitError("Input size along the gather axis must be non-zero.");
+ if (inputDimOnAxis == 1) {
+ assert(outputDimOnAxis == inputDimOnAxis);
+ rewriter.replaceOp(op, adaptor.getInput());
+ return success();
}
+ if (outputDimOnAxis % inputDimOnAxis != 0)
+ return op.emitError("Result size along the gather axis must be an exact "
+ "multiple of the input size along the gather axis.");
+
+ if (!memref::isStaticShapeAndContiguousRowMajor(inType) ||
+ !memref::isStaticShapeAndContiguousRowMajor(outType))
+ return op.emitError("Input/result must be statically shaped memrefs in "
+ "contiguous row-major layout.");
+
+ // Get the right communicator.
+ Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
+ Value nRanksV =
+ mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
+ nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
+ int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
+ Value nRanksC = arith::ConstantIndexOp::create(ib, nRanks);
+ Value notError =
+ arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
+ cf::AssertOp::create(ib, notError,
+ "Expected number of ranks in the communicator to "
+ "match the output size along the gather axis divided "
+ "by the input size along the gather axis.");
+
+ // mpi.allgather always concatenates along the first dimension, so
+    // get an output buffer of shape {nRanks, dim0, ...}.
+ SmallVector<int64_t> gatherShape;
+ gatherShape.emplace_back(nRanks);
+ gatherShape.append(inputShape.begin(), inputShape.end());
+ auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
+ Value finalOutput = memref::AllocOp::create(ib, gatherType);
// Create the MPI AllGather operation.
- mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
+ mpi::AllGatherOp::create(ib, TypeRange(), input, finalOutput, comm);
- // If gather-axis!=0, copy from gathered buffer to output with the right
- // layout.
- Value finalOutput = tmpOutput;
- if (gatherAxis != 0) {
- int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
- assert(nSrcDims == outType.getRank() + 1 &&
- "Expected gathered type to have rank one more than output type.");
+ if (gatherAxis == 0) {
+ // If gather axis == 0, simply collapse the first 2 dims from {nRanks,
+ // dim0, ...} to {nRanks*dim0, ...}.
+ SmallVector<ReassociationIndices> reassociation;
+ reassociation.push_back({0, 1});
+ int64_t numGatherDims = gatherShape.size();
+ for (int64_t i = 2; i < numGatherDims; ++i)
+ reassociation.push_back({i});
+ finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
+ reassociation);
- // Create affine map for copying from gathered buffer to output.
- SmallVector<AffineExpr> dims;
- dims.reserve(nSrcDims);
- for (unsigned i = 0; i < nSrcDims; ++i)
- dims.emplace_back(getAffineDimExpr(i, ctx));
- AffineExpr s = getAffineSymbolExpr(0, ctx);
- SmallVector<AffineExpr> results;
- results.reserve(nSrcDims);
- for (unsigned i = 0; i < nSrcDims - 1; ++i) {
- if (i == gatherAxis)
- results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
- else
- results.emplace_back(dims[i + 1]);
+      // If the op's result type is a tensor, convert the buffer to a tensor.
+ if (isa<RankedTensorType>(op.getType()))
+ finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
+ finalOutput, true);
+ } else {
+ // 1. Enter tensor-land.
+ auto inType =
+ RankedTensorType::get(gatherShape, outType.getElementType());
+ finalOutput =
+ bufferization::ToTensorOp::create(ib, inType, finalOutput, true);
+
+      // 2. Permute the output buffer from {nRanks, dim0, ..., gatherAxis, ...}
+      // to {dim0, ..., nRanks, gatherAxis, ...}.
+ SmallVector<int64_t> outShapePermuted, permutation;
+ for (int i = 1; i <= gatherAxis; ++i) {
+ outShapePermuted.emplace_back(gatherShape[i]);
+ permutation.emplace_back(i);
}
- auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
+ outShapePermuted.emplace_back(gatherShape[0]);
+ permutation.emplace_back(0);
+ for (size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
+ outShapePermuted.emplace_back(gatherShape[i]);
+ permutation.emplace_back(i);
+ }
+ Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
+ outType.getElementType());
+ finalOutput =
+ linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
+ ->getResult(0);
- finalOutput = memref::AllocOp::create(iBuilder, outType);
+ // 3. Collapse the output buffer from {dim0, ..., nRanks, gatherAxis, ...}
+ // to {dim0, ..., nRanks*gatherAxis, ...}.
+ SmallVector<ReassociationIndices> reassociation;
+ for (int64_t i = 0; i < gatherAxis; ++i) {
+ reassociation.push_back({i});
+ }
+ reassociation.push_back({gatherAxis, gatherAxis + 1});
+ for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
+ ++i) {
+ reassociation.push_back({i});
+ }
+ auto outTType =
+ RankedTensorType::get(outputShape, outType.getElementType());
+ finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
+ reassociation);
- // Now build a loop nest to copy from gathered buffer to finalOutput
- // It would be nicer to just use a memref.transpose/collapse_shape op but
- // these currently only support simpler cases.
- Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
- SmallVector<Value> lbs(nSrcDims, zero);
- SmallVector<Value> ubs;
- for (int64_t d = 0; d < nSrcDims; ++d)
- ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
- SmallVector<int64_t> steps(nSrcDims, 1);
- auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
- Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs);
- // set symbol value
- SmallVector<Value> ivss(ivs.begin(), ivs.end());
- ivss.emplace_back(gatherDimSz);
- affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap,
- ivss);
- };
- affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
- emitCopy);
-
- memref::DeallocOp::create(iBuilder, tmpOutput);
+ // 4. Cast back to memref if needed.
+ if (isa<MemRefType>(op.getType()))
+ finalOutput =
+ bufferization::ToBufferOp::create(ib, outType, finalOutput);
}
- // If the destination is a tensor, cast it to a tensor
- if (isa<RankedTensorType>(op.getType()))
- finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
- finalOutput, true);
rewriter.replaceOp(op, finalOutput);
return success();
}
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 62dc8f5..e619c70 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -436,24 +436,32 @@
targetSharding);
}
-// Handles only resharding on a 1D shard.
-// Currently the sharded tensor axes must be exactly divisible by the single
-// grid axis size.
+// In most cases the sharded tensor axes must be exactly divisible by the single
+// grid axis size. Only halo size changes can deal with non-divisible cases.
static TypedValue<ShapedType>
-reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
- const Sharding &sourceSharding, const Sharding &targetSharding,
- TypedValue<ShapedType> sourceUnshardedValue,
- TypedValue<ShapedType> sourceShard) {
+reshard(ImplicitLocOpBuilder &builder, GridOp grid,
+ const Sharding &sourceSharding, const Sharding &targetSharding,
+ TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
+ // If source and destination sharding are the same, no need to do anything.
+ if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
+ isFullReplication(targetSharding))) {
+ return sourceShard;
+ }
+
+ // Tries to handle the case where the resharding is needed because the halo
+ // sizes are different. Supports arbitrary grid dimensionality.
+ if (auto tryRes = tryUpdateHaloInResharding(
+ builder, grid, sourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), sourceShard)) {
+ return std::get<0>(tryRes.value()); // targetShard
+ }
+
assert(sourceShard.getType() ==
shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
[[maybe_unused]] ShapedType targetShardType =
shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
assert(sourceShard.getType().getRank() == targetShardType.getRank());
- assert(grid.getRank() == 1 && "Only 1D grides are currently supported.");
-
- if (sourceSharding == targetSharding) {
- return sourceShard;
- }
TypedValue<ShapedType> targetShard;
Sharding actualTargetSharding;
@@ -475,38 +483,13 @@
std::tie(targetShard, actualTargetSharding) = tryRes.value();
}
}
+
assert(targetShard && "Did not find any pattern to apply.");
assert(actualTargetSharding == targetSharding);
assert(targetShard.getType() == targetShardType);
return targetShard;
}
-static TypedValue<ShapedType>
-reshard(ImplicitLocOpBuilder &builder, GridOp grid,
- const Sharding &sourceSharding, const Sharding &targetSharding,
- TypedValue<ShapedType> sourceUnshardedValue,
- TypedValue<ShapedType> sourceShard) {
- // If source and destination sharding are the same, no need to do anything.
- if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
- isFullReplication(targetSharding))) {
- return sourceShard;
- }
-
- // Tries to handle the case where the resharding is needed because the halo
- // sizes are different. Supports arbitrary grid dimensionality.
- if (auto tryRes = tryUpdateHaloInResharding(
- builder, grid, sourceSharding, targetSharding,
- sourceUnshardedValue.getType(), sourceShard)) {
- return std::get<0>(tryRes.value()); // targetShard
- }
-
- // Resort to handling only 1D grids since the general case is complicated if
- // it needs to be communication efficient in terms of minimizing the data
- // transfered between devices.
- return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
- sourceUnshardedValue, sourceShard);
-}
-
TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) {
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 4ac4a69..6161c13 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -148,6 +148,28 @@
return %0 : memref<3x4xf64>
}
+ // CHECK-LABEL: func @allgather_tensor_0
+ // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+ func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+ // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
+ // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
+ // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
+ // CHECK: cf.assert [[v3]]
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x3x4xf32>
+ // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<4x3x4xf32>
+ // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1], [2]] : memref<4x3x4xf32> into memref<12x4xf32>
+ // CHECK: [[v4:%.*]] = bufferization.to_tensor [[vcollapse_shape]] restrict : memref<12x4xf32> to tensor<12x4xf32>
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 0 : tensor<3x4xf32> -> tensor<12x4xf32>
+ // CHECK: return [[v4]] : tensor<12x4xf32>
+ return %0 : tensor<12x4xf32>
+ }
+
// CHECK-LABEL: func @allgather_tensor
func.func @allgather_tensor(
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
@@ -155,28 +177,22 @@
%arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
- // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
+ // CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
- // CHECK: [[v3:%.*]] = arith.divsi [[vc20]], [[v2]] : index
- // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x3x?xf32>
- // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
- // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
- // CHECK: affine.for [[varg1:%.*]] = 0 to [[v2]] {
- // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
- // CHECK: affine.for [[varg3:%.*]] = 0 to [[v3]] {
- // CHECK: [[v5:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
- // CHECK: affine.store [[v5]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v3]]) + [[varg3]]] : memref<3x20xf32>
- // CHECK: }
- // CHECK: }
- // CHECK: }
- // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
- // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<3x20xf32> to tensor<3x20xf32>
+ // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc5]] : index
+ // CHECK: cf.assert [[v3]]
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
+ // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
+ // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
+ // CHECK: [[v5:%.*]] = tensor.empty() : tensor<3x5x4xf32>
+ // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<5x3x4xf32>) outs([[v5]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
+ // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32>
- // CHECK: return [[v4]] : tensor<3x20xf32>
+ // CHECK: return [[vcollapsed]] : tensor<3x20xf32>
return %0 : tensor<3x20xf32>
}
@@ -185,28 +201,24 @@
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
// CHECK-SAME: -> memref<3x20xf32>
%arg0 : memref<3x4xf32>) -> memref<3x20xf32> {
- // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
- // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
// CHECK: [[v1:%.*]] = arith.index_cast [[vsize]] : i32 to index
- // CHECK: [[v2:%.*]] = arith.divsi [[vc20]], [[v1]] : index
- // CHECK: [[valloc:%.*]] = memref.alloc([[v1]], [[v2]]) : memref<?x3x?xf32>
- // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
- // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
- // CHECK: affine.for [[varg1:%.*]] = 0 to [[v1]] {
- // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
- // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
- // CHECK: [[v3:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
- // CHECK: affine.store [[v3]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v2]]) + [[varg3]]] : memref<3x20xf32>
- // CHECK: }
- // CHECK: }
- // CHECK: }
- // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
+ // CHECK: [[v2:%.*]] = arith.cmpi eq, [[v1]], [[vc5]] : index
+ // CHECK: cf.assert [[v2]]
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<5x3x4xf32>
+ // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<5x3x4xf32>
+ // CHECK: [[v3:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<5x3x4xf32> to tensor<5x3x4xf32>
+ // CHECK: [[v4:%.*]] = tensor.empty() : tensor<3x5x4xf32>
+ // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v3]] : tensor<5x3x4xf32>) outs([[v4]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
+ // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
+ // CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] : tensor<3x20xf32> to memref<3x20xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
- // CHECK: return [[valloc_0]] : memref<3x20xf32>
+ // CHECK: return [[v5]] : memref<3x20xf32>
return %0 : memref<3x20xf32>
}
}
@@ -377,9 +389,9 @@
// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) {
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding
- // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
- // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
- // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@@ -397,10 +409,10 @@
// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) {
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding
- // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
- // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
- // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
- // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
+ // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@@ -417,12 +429,12 @@
// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) {
%sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding
- // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
- // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
- // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
- // CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
- // CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
- // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-DAG: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+ // CHECK-DAG: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+ // CHECK-DAG: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+ // CHECK-DAG: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-DAG: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-DAG: [[vcm1_i16:%.*]] = arith.constant -1 : i16
// CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
// CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
// CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
@@ -444,39 +456,38 @@
// CHECK-LABEL: func.func @mlp_1dgrid(
// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %arg2: tensor<256x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
- // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
%cst = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[vc0:%.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vsize:%.*]] = mpi.comm_size
+ // CHECK: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
- // CHECK: [[v3:%.*]] = arith.divsi
- // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x512x?xf32>
- // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<?x512x?xf32>
- // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
- // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
- // CHECK: affine.for [[varg4:%.*]] = 0 to 512 {
- // CHECK: affine.for [[varg5:%.*]] = 0 to [[v3]] {
- // CHECK: [[v19:%.*]] = memref.load [[valloc]][[[varg3]], [[varg4]], [[varg5]]] : memref<?x512x?xf32>
- // CHECK: affine.store [[v19]], [[valloc_0]][[[varg4]], [[varg3]] * symbol([[v3]]) + [[varg5]]] : memref<512x2048xf32>
- // CHECK: memref.dealloc [[valloc]] : memref<?x512x?xf32>
- // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+ // CHECK: [[v3:%.*]] = arith.cmpi eq, [[v2]], [[vc4]] : index
+ // CHECK: cf.assert [[v3]]
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<4x512x512xf32>
+ // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<4x512x512xf32>
+ // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<4x512x512xf32> to tensor<4x512x512xf32>
+ // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x4x512xf32>
+ // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v4]] : tensor<4x512x512xf32>) outs([[v5]] : tensor<512x4x512xf32>) permutation = [1, 0, 2]
+ // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<512x4x512xf32> into tensor<512x2048xf32>
%all_gather = shard.all_gather %arg0 on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
- // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x256xf32>
+ // CHECK: [[v6:%.*]] = tensor.empty() : tensor<512x256xf32>
%0 = tensor.empty() : tensor<512x256xf32>
- // CHECK: [[v6:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v5]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+ // CHECK: [[v7:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x256xf32>) -> tensor<512x256xf32>
- // CHECK: [[v7:%.*]] = linalg.matmul ins([[v4]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+ // CHECK: [[v8:%.*]] = linalg.matmul ins([[vcollapsed]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v7]] : tensor<512x256xf32>) -> tensor<512x256xf32>
%2 = linalg.matmul ins(%all_gather, %arg1 : tensor<512x2048xf32>, tensor<2048x256xf32>) outs(%1 : tensor<512x256xf32>) -> tensor<512x256xf32>
- // CHECK: [[v8:%.*]] = tosa.sigmoid [[v7]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
+ // CHECK: [[v9:%.*]] = tosa.sigmoid [[v8]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
%3 = tosa.sigmoid %2 : (tensor<512x256xf32>) -> tensor<512x256xf32>
%4 = tensor.empty() : tensor<512x2048xf32>
%5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
%proc_linear_idx = shard.process_multi_index on @grid_1d_4 axes = [0] : index
%grid_shape = shard.grid_shape @grid_1d_4 axes = [0] : index
%6 = arith.cmpi eq, %proc_linear_idx, %c0 : index
- // CHECK: [[v14:%.*]] = scf.if
+ // CHECK: [[v15:%.*]] = scf.if
%7 = scf.if %6 -> (tensor<512x2048xf32>) {
scf.yield %5 : tensor<512x2048xf32>
} else {
@@ -484,15 +495,15 @@
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
scf.yield %10 : tensor<512x2048xf32>
}
- // CHECK: [[v15:%.*]] = linalg.matmul ins([[v8]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v14]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ // CHECK: [[v16:%.*]] = linalg.matmul ins([[v9]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v15]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
%8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
- // CHECK: [[v16:%.*]] = bufferization.to_buffer
- // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<512x2048xf32>
- // CHECK: linalg.copy ins([[v16]] : memref<512x2048xf32>) outs([[valloc_1]] : memref<512x2048xf32>)
- // CHECK: [[v17:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: mpi.allreduce([[valloc_1]], [[valloc_1]], MPI_SUM, [[v17]]) : memref<512x2048xf32>, memref<512x2048xf32>
- // CHECK: [[v18:%.*]] = bufferization.to_tensor [[valloc_1]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+ // CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] : tensor<512x2048xf32> to memref<512x2048xf32>
+ // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
+ // CHECK: linalg.copy ins([[v17]] : memref<512x2048xf32>) outs([[valloc_0]] : memref<512x2048xf32>)
+ // CHECK: [[v18:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc_0]], [[valloc_0]], MPI_SUM, [[v18]]) : memref<512x2048xf32>, memref<512x2048xf32>
+ // CHECK: [[v19:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
%all_reduce = shard.all_reduce %8 on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
- // CHECK: return [[v18]] : tensor<512x2048xf32>
+ // CHECK: return [[v19]] : tensor<512x2048xf32>
return %all_reduce : tensor<512x2048xf32>
}
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index cd9fa22..4c8271a 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -4,6 +4,7 @@
shard.grid @grid_1d(shape = 2)
shard.grid @grid_1d_4(shape = 4)
+shard.grid @grid_2d_16(shape = 4x4)
// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
@@ -318,9 +319,9 @@
return %sharded_ret : tensor<6xi32>
}
-// CHECK-LABEL: func.func @mlp_1dgrid
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
-func.func @mlp_1dgrid(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
+func.func @mlp_1d_weight_stationary(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
// CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
%sharding = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
%sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding