[NFC][mlir][shard] Unify MoveLastSplitAxisPattern/MoveLastSplitAxisPattern (#192295)
Made MoveLastSplitAxisPattern more general to also cover MoveLastSplitAxisPattern.
Less code, same functionality.
Assisted by claude.
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 57502fb..05b864d 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -271,106 +271,15 @@
return srcShape.cloneWith(tgtShape, srcShape.getElementType());
}
-/// Move a split axis between tensor dimensions:
-/// e.g. [[0], []] -> [[], [0]].
-class MoveSplitAxisPattern : public ReshardingPattern {
- // Detect if the resharding moves a single split axis from one tensor
- // dimension to another tensor dimension. If detected, returns the
- // corresponding (tgt_tensor_dim, grid_axis) pair.
- static std::optional<std::tuple<int64_t, GridAxis>>
- detect(const Sharding &srcSharding, const Sharding &tgtSharding,
- int64_t srcTensorDim) {
- if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
- return std::nullopt;
- auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
- if (srcAxes.size() != 1)
- return std::nullopt;
- for (size_t tgtTensorDim = 0;
- tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
- if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
- continue;
- auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
- if (tgtAxes.size() != 1 || srcAxes.front() != tgtAxes.front())
- continue;
- return std::make_tuple(static_cast<int64_t>(tgtTensorDim),
- srcAxes.front());
- }
- return std::nullopt;
- }
-
- static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
- int64_t srcTensorDim, int64_t tgtTensorDim) {
- SmallVector<GridAxesAttr> tgtShardingSplitAxes =
- llvm::to_vector(srcSharding.getSplitAxes());
- while (static_cast<int64_t>(tgtShardingSplitAxes.size()) <= tgtTensorDim) {
- tgtShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
- }
-
- auto srcSplitAxes =
- llvm::to_vector(tgtShardingSplitAxes[srcTensorDim].asArrayRef());
- assert(srcSplitAxes.size() == 1);
- auto gridAxis = srcSplitAxes.back();
- srcSplitAxes.pop_back();
- tgtShardingSplitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
-
- auto tgtSplitAxes =
- llvm::to_vector(tgtShardingSplitAxes[tgtTensorDim].asArrayRef());
- tgtSplitAxes.push_back(gridAxis);
- tgtShardingSplitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
-
- return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
- }
-
- static std::tuple<TypedValue<ShapedType>, Sharding>
- apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding,
- ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
- int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis gridAxis) {
- MLIRContext *ctx = builder.getContext();
- builder.setInsertionPointAfterValue(srcShard);
-
- Sharding resultSharding =
- tgtSharding(ctx, std::move(srcSharding), srcTensorDim, tgtTensorDim);
- ShapedType a2aResultShape =
- allToAllResultShape(srcShard.getType(), grid.getShape()[gridAxis],
- srcTensorDim, tgtTensorDim);
- Value allToAllResult = AllToAllOp::create(
- builder,
- RankedTensorType::get(a2aResultShape.getShape(),
- a2aResultShape.getElementType()),
- grid.getSymName(), SmallVector<GridAxis>({gridAxis}), srcShard,
- APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
- ShapedType tgtShape =
- shardShapedType(srcUnshardedType, grid, resultSharding);
- TypedValue<ShapedType> tgtShard =
- tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
- return {tgtShard, resultSharding};
- }
-
-public:
- std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
- tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
- const Sharding &srcSharding, const Sharding &tgtSharding,
- ShapedType srcUnshardedType,
- TypedValue<ShapedType> srcShard) override {
- if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
- return std::nullopt;
- if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
- auto [tgtTensorDim, gridAxis] = detectRes.value();
- return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
- tensorDim, tgtTensorDim, gridAxis);
- }
- return std::nullopt;
- }
-};
-
/// Move the last split axis of one tensor dimension to the front of another
-/// tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
+/// tensor dimension's split axes, e.g. [[0], []] -> [[], [0]] or
+/// [[0, 1], [2]] -> [[0], [1, 2]].
class MoveLastSplitAxisPattern : public ReshardingPattern {
// Detect if the resharding moves the last grid axis of srcTensorDim to the
// front of another tensor dimension's split axes. If detected, returns
// (tgtTensorDim, movedGridAxis).
//
- // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an] (n >= 2)
+ // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an] (n >= 1)
// tgt[srcTensorDim] = [a1,...,a(n-1)]
// src[tgtTensorDim] = [b1,...,bm] (m >= 0)
// tgt[tgtTensorDim] = [an, b1,...,bm]
@@ -380,8 +289,8 @@
if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
return std::nullopt;
auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
- // Need at least 2 axes to move the last one.
- if (srcAxes.size() < 2)
+ // Need at least 1 axis to move.
+ if (srcAxes.empty())
return std::nullopt;
// After the move the source tensor dim should lose its last axis.
@@ -586,12 +495,11 @@
// Each pattern's tryApply checks its own applicability preconditions.
static UpdateHaloPattern updateHaloPattern;
static MoveLastSplitAxisPattern moveLastSplitAxisPattern;
- static MoveSplitAxisPattern moveSplitAxisPattern;
static SplitLastAxisPattern splitLastAxisPattern;
static UnsplitLastAxesPattern unsplitLastAxesPattern;
static ReshardingPattern *patterns[] = {
- &updateHaloPattern, &moveLastSplitAxisPattern, &moveSplitAxisPattern,
- &splitLastAxisPattern, &unsplitLastAxesPattern};
+ &updateHaloPattern, &moveLastSplitAxisPattern, &splitLastAxisPattern,
+ &unsplitLastAxesPattern};
TypedValue<ShapedType> currentShard = shardedSrc;
Sharding currentSharding = srcSharding;
for (int64_t dim = 0;