| //===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements a translation of Shard communication ops to MPI ops. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/ShardToMPI/ShardToMPI.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/MPI/IR/MPI.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Shard/IR/ShardDialect.h" |
| #include "mlir/Dialect/Shard/IR/ShardOps.h" |
| #include "mlir/Dialect/Shard/Transforms/Simplifications.h" |
| #include "mlir/Dialect/Shard/Transforms/Transforms.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "shard-to-mpi" |
| #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace shard; |
| |
| namespace { |
/// Converts a mixed list of static (int64_t) and dynamic (Value) sizes into a
/// vector of Values of the provided type.
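/// For example (illustrative), statics = {2, kDynamic, 4} with a single
/// dynamic value %d yields {%c2, %d, %c4}, where %c2 and %c4 are constants
/// of the requested type.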
| static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc, |
| llvm::ArrayRef<int64_t> statics, |
| ValueRange dynamics, |
| Type type = Type()) { |
| SmallVector<Value> values; |
| auto dyn = dynamics.begin(); |
| Type i64 = b.getI64Type(); |
| if (!type) |
| type = i64; |
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
| for (auto s : statics) { |
| if (s == ShapedType::kDynamic) { |
| values.emplace_back(*(dyn++)); |
| } else { |
| TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s); |
| values.emplace_back(arith::ConstantOp::create(b, loc, type, val)); |
| } |
| } |
| return values; |
| } |
| |
| /// Create operations converting a linear index to a multi-dimensional index. |
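/// For example, dimensions = [2, 3, 4] and linearIndex = 17 yield [1, 1, 1]:
/// 17 % 4 = 1, (17 / 4) % 3 = 1, (17 / 4 / 3) % 2 = 1.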
| static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, |
| Value linearIndex, |
| ValueRange dimensions) { |
| int n = dimensions.size(); |
| SmallVector<Value> multiIndex(n); |
| |
| for (int i = n - 1; i >= 0; --i) { |
| multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]); |
| if (i > 0) |
| linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]); |
| } |
| |
| return multiIndex; |
| } |
| |
| /// Create operations converting a multi-dimensional index to a linear index. |
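/// For example, multiIndex = [1, 1, 1] with dimensions = [2, 3, 4] yields
/// 1 * (3 * 4) + 1 * 4 + 1 * 1 = 17.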
| Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, |
| ValueRange dimensions) { |
| |
| Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0); |
| Value stride = arith::ConstantIndexOp::create(b, loc, 1); |
| |
| for (int i = multiIndex.size() - 1; i >= 0; --i) { |
| Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride); |
| linearIndex = arith::AddIOp::create(b, loc, linearIndex, off); |
| stride = arith::MulIOp::create(b, loc, stride, dimensions[i]); |
| } |
| |
| return linearIndex; |
| } |
| |
/// Replace GetShardingOp with the ShardingOp defining its source's sharding.
| struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(GetShardingOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>(); |
| if (!shardOp) |
| return failure(); |
| auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>(); |
| if (!shardingOp) |
| return failure(); |
| |
| rewriter.replaceOp(op, shardingOp.getResult()); |
| return success(); |
| } |
| }; |
| |
/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by the type converter.
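/// For example (illustrative), split axes [[0], [1, 2]] become the i16 tensor
/// [[0, -1], [1, 2]]: rows are padded to the longest group with -1.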
| struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(ShardingOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto splitAxes = op.getSplitAxes().getAxes(); |
| int64_t maxNAxes = 0; |
| for (auto axes : splitAxes) |
| maxNAxes = std::max<int64_t>(maxNAxes, axes.size()); |
| |
    // To hold the split axes, create an empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements of smaller split-groups to -1.
| Location loc = op.getLoc(); |
| auto i16 = rewriter.getI16Type(); |
| auto i64 = rewriter.getI64Type(); |
| std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()), |
| maxNAxes}; |
| Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16); |
| auto attr = IntegerAttr::get(i16, -1); |
| Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr); |
| resSplitAxes = |
| linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes) |
| .getResult(0); |
| |
    // Explicitly write values into the tensor row by row.
| std::array<int64_t, 2> strides = {1, 1}; |
| int64_t nSplits = 0; |
| ValueRange empty = {}; |
| for (auto [i, axes] : llvm::enumerate(splitAxes)) { |
| int64_t size = axes.size(); |
| if (size > 0) |
| ++nSplits; |
      std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
| std::array<int64_t, 2> sizes = {1, size}; |
| auto tensorType = RankedTensorType::get({size}, i16); |
| auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef()); |
| auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs); |
| resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals, |
| resSplitAxes, empty, empty, |
| empty, offs, sizes, strides); |
| } |
| |
    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
| SmallVector<Value> haloSizes = |
| getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(), |
| adaptor.getDynamicHaloSizes()); |
| auto type = RankedTensorType::get({nSplits, 2}, i64); |
    Value resHaloSizes =
        haloSizes.empty()
            ? tensor::EmptyOp::create(rewriter, loc,
                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();
| |
    // To hold the sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements of smaller split-groups to kDynamic. Computing the max size of
    // the split groups requires collectiveProcessGroupSize (which in turn
    // needs the GridOp).
| Value resOffsets; |
| if (adaptor.getStaticShardedDimsOffsets().empty()) { |
| resOffsets = tensor::EmptyOp::create(rewriter, loc, |
| std::array<int64_t, 2>{0, 0}, i64); |
| } else { |
| SymbolTableCollection symbolTableCollection; |
| auto gridOp = getGrid(op, symbolTableCollection); |
| int64_t maxSplitSize = 0; |
| for (auto axes : splitAxes) { |
| int64_t splitSize = |
| collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); |
| assert(splitSize != ShapedType::kDynamic); |
| maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize); |
| } |
| assert(maxSplitSize); |
| ++maxSplitSize; // add one for the total size |
| |
| resOffsets = tensor::EmptyOp::create( |
| rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64); |
      Value sentinel = arith::ConstantOp::create(
          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets = linalg::FillOp::create(rewriter, loc, sentinel, resOffsets)
                       .getResult(0);
| SmallVector<Value> offsets = |
| getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(), |
| adaptor.getDynamicShardedDimsOffsets()); |
| int64_t curr = 0; |
| for (auto [i, axes] : llvm::enumerate(splitAxes)) { |
| int64_t splitSize = |
| collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); |
| assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); |
| ++splitSize; // add one for the total size |
| ArrayRef<Value> values(&offsets[curr], splitSize); |
| Value vals = tensor::FromElementsOp::create(rewriter, loc, values); |
| std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0}; |
| std::array<int64_t, 2> sizes = {1, splitSize}; |
| resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals, |
| resOffsets, empty, empty, |
| empty, offs, sizes, strides); |
| curr += splitSize; |
| } |
| } |
| |
| // return a tuple of tensors as defined by type converter |
| SmallVector<Type> resTypes; |
| if (failed(getTypeConverter()->convertType(op.getResult().getType(), |
| resTypes))) |
| return failure(); |
| |
| resSplitAxes = |
| tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes); |
| resHaloSizes = |
| tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes); |
| resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets); |
| |
| rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>( |
| op, TupleType::get(op.getContext(), resTypes), |
| ValueRange{resSplitAxes, resHaloSizes, resOffsets}); |
| |
| return success(); |
| } |
| }; |
| |
| struct ConvertProcessMultiIndexOp |
| : public OpConversionPattern<ProcessMultiIndexOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| // Currently converts its linear index to a multi-dimensional index. |
| |
| SymbolTableCollection symbolTableCollection; |
| Location loc = op.getLoc(); |
| auto gridOp = getGrid(op, symbolTableCollection); |
| // For now we only support static grid shapes |
| if (ShapedType::isDynamicShape(gridOp.getShape())) |
| return failure(); |
| |
| SmallVector<Value> dims; |
| llvm::transform( |
| gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { |
| return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); |
| }); |
    Value linearIdx =
        ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, linearIdx, dims);
| |
| // optionally extract subset of grid axes |
| auto axes = adaptor.getAxes(); |
| if (!axes.empty()) { |
| SmallVector<Value> subIndex; |
| for (auto axis : axes) { |
| subIndex.emplace_back(mIdx[axis]); |
| } |
| mIdx = std::move(subIndex); |
| } |
| |
| rewriter.replaceOp(op, mIdx); |
| return success(); |
| } |
| }; |
| |
| class ConvertProcessLinearIndexOp |
| : public OpConversionPattern<ProcessLinearIndexOp> { |
| |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Create mpi::CommRankOp |
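    // The emitted IR is roughly (illustrative; exact assembly syntax may
    // differ):
    //   %comm = mpi.comm_world : !mpi.comm
    //   %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
    //   %idx = arith.index_cast %rank : i32 to index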
| Location loc = op.getLoc(); |
| auto ctx = op.getContext(); |
| Value commWorld = |
| mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); |
    auto rank = mpi::CommRankOp::create(
                    rewriter, loc,
                    TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
                    commWorld)
                    .getRank();
| rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), |
| rank); |
| return success(); |
| } |
| }; |
| |
| struct ConvertNeighborsLinearIndicesOp |
| : public OpConversionPattern<NeighborsLinearIndicesOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
    // Computes the neighbors' linear indices along a split axis by simply
    // adding/subtracting 1 to/from the current index in that dimension.
    // Assigns -1 if the neighbor is out of bounds.
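    // For example (illustrative), on a grid of shape [2, 3, 4] the process at
    // [1, 1, 1] with split axis 1 has down neighbor [1, 0, 1] (linear index
    // 13) and up neighbor [1, 2, 1] (linear index 21).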
| |
| auto axes = adaptor.getSplitAxes(); |
| // For now only single axis sharding is supported |
| if (axes.size() != 1) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| SymbolTableCollection symbolTableCollection; |
| auto gridOp = getGrid(op, symbolTableCollection); |
| auto mIdx = adaptor.getDevice(); |
| auto orgIdx = mIdx[axes[0]]; |
| SmallVector<Value> dims; |
| llvm::transform( |
| gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { |
| return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); |
| }); |
| Value dimSz = dims[axes[0]]; |
| Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); |
| Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1); |
| Value atBorder = |
| arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx, |
| arith::ConstantIndexOp::create(rewriter, loc, 0)); |
| auto down = scf::IfOp::create( |
| rewriter, loc, atBorder, |
| [&](OpBuilder &builder, Location loc) { |
| scf::YieldOp::create(builder, loc, minus1); |
| }, |
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::SubIOp::create(builder, loc, orgIdx, one).getResult();
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, builder, tmp, dims));
        });
| atBorder = arith::CmpIOp::create( |
| rewriter, loc, arith::CmpIPredicate::sge, orgIdx, |
| arith::SubIOp::create(rewriter, loc, dimSz, one).getResult()); |
| auto up = scf::IfOp::create( |
| rewriter, loc, atBorder, |
| [&](OpBuilder &builder, Location loc) { |
| scf::YieldOp::create(builder, loc, minus1); |
| }, |
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] = arith::AddIOp::create(builder, loc, orgIdx, one);
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, builder, tmp, dims));
        });
| rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); |
| return success(); |
| } |
| }; |
| |
| struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto sharding = op.getSharding().getDefiningOp<ShardingOp>(); |
| if (!sharding) { |
| return op->emitError() |
| << "Expected ShardingOp as defining op for sharding" |
| << " but found " << adaptor.getSharding()[0].getDefiningOp(); |
| } |
| |
    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the grid axes provided in
    // split-axes); leftover elements are distributed to the trailing shards.
    // If shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.
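    // For example (illustrative), a dimension of size 10 split across 4
    // shards yields shard sizes [2, 2, 3, 3]: each shard gets 10 / 4 = 2
    // elements and the 10 % 4 = 2 leftover elements go to the trailing
    // shards.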
| |
| Location loc = op.getLoc(); |
| Type index = rewriter.getIndexType(); |
| |
    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
| // For simpler access fill a vector with the dynamic dims. |
| SmallVector<Value> dynDims, dynDevice; |
| for (auto dim : adaptor.getDimsDynamic()) { |
| // type conversion should be 1:1 for ints |
| dynDims.emplace_back(llvm::getSingleElement(dim)); |
| } |
| // same for device |
| for (auto device : adaptor.getDeviceDynamic()) { |
| dynDevice.emplace_back(llvm::getSingleElement(device)); |
| } |
| |
| // To keep the code simple, convert dims/device to values when they are |
| // attributes. Count on canonicalization to fold static values. |
| SmallVector<Value> shape = |
| getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index); |
| SmallVector<Value> multiIdx = |
| getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); |
| |
    // Get the GridOp; the grid shape is needed to compute the sharded shape.
| SymbolTableCollection symbolTableCollection; |
| auto gridOp = getGrid(sharding, symbolTableCollection); |
| // For now we only support static grid shapes |
| if (ShapedType::isDynamicShape(gridOp.getShape())) |
| return failure(); |
| |
| auto splitAxes = sharding.getSplitAxes().getAxes(); |
| // shardedDimsOffsets are optional and might be Values (not attributes). |
| // Also, the shardId might be dynamic which means the position in the |
| // shardedDimsOffsets is not statically known. Create a tensor of the |
| // shardedDimsOffsets and later extract the offsets for computing the |
| // local shard-size. |
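    // For example (illustrative), offsets [0, 3, 5, 10] for a dimension split
    // across 3 shards yield local sizes 3, 2 and 5 via
    // offs[idx + 1] - offs[idx].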
| Value shardedDimsOffs; |
| { |
| SmallVector<Value> tmp = getMixedAsValues( |
| rewriter, loc, sharding.getStaticShardedDimsOffsets(), |
| sharding.getDynamicShardedDimsOffsets(), index); |
| if (!tmp.empty()) |
| shardedDimsOffs = tensor::FromElementsOp::create( |
| rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index), |
| tmp); |
| } |
| |
    // With a static grid shape the sizes of the split axes are known.
    // Hence the start position for each split axis in shardedDimsOffsets can
    // be computed statically.
| int64_t pos = 0; |
| SmallVector<Value> shardShape; |
| Value zero = |
| arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index)); |
| Value one = |
| arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index)); |
| |
| // Iterate over the dimensions of the tensor shape, get their split Axes, |
| // and compute the sharded shape. |
| for (auto [i, dim] : llvm::enumerate(shape)) { |
| // Trailing dimensions might not be annotated. |
| if (i < splitAxes.size() && !splitAxes[i].empty()) { |
| auto axes = splitAxes[i]; |
        // Create a value from the static position in shardedDimsOffsets.
| Value posVal = arith::ConstantOp::create(rewriter, loc, |
| rewriter.getIndexAttr(pos)); |
| // Get the index of the local shard in the grid axis. |
| Value idx = multiIdx[axes[0]]; |
| auto numShards = |
| collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); |
| if (shardedDimsOffs) { |
| // If sharded dims offsets are provided, use them to compute the |
| // sharded shape. |
| if (axes.size() > 1) { |
| return op->emitError() << "Only single axis sharding is " |
| << "supported for each dimension."; |
| } |
| idx = arith::AddIOp::create(rewriter, loc, posVal, idx); |
| // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx]. |
| Value off = |
| tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); |
| idx = arith::AddIOp::create(rewriter, loc, idx, one); |
| Value nextOff = |
| tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); |
| Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off); |
| shardShape.emplace_back(sz); |
| } else { |
| Value numShardsVal = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getIndexAttr(numShards)); |
| // Compute shard dim size by distributing odd elements to trailing |
| // shards: |
| // sz = dim / numShards |
| // + (idx >= (numShards - (dim % numShards)) ? 1 : 0) |
| Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal); |
| Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal); |
| sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1); |
| auto cond = arith::CmpIOp::create( |
| rewriter, loc, arith::CmpIPredicate::sge, idx, sz1); |
| Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero); |
| sz = arith::AddIOp::create(rewriter, loc, sz, odd); |
| shardShape.emplace_back(sz); |
| } |
| pos += numShards + 1; // add one for the total size. |
      } // else: no sharding if the split axes are empty or absent
      // If no size was added, this dimension is not sharded.
| if (shardShape.size() <= i) |
| shardShape.emplace_back(dim); |
| } |
| assert(shardShape.size() == shape.size()); |
| rewriter.replaceOp(op, shardShape); |
| return success(); |
| } |
| }; |
| |
| static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) { |
| auto ctx = kind.getContext(); |
| auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) { |
| return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp); |
| }; |
| |
| switch (kind.getValue()) { |
| case ReductionKind::Sum: |
| return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM); |
| case ReductionKind::Product: |
| return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD); |
| case ReductionKind::Min: |
| return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN); |
| case ReductionKind::Max: |
| return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX); |
| case ReductionKind::BitwiseAnd: |
| return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND); |
| case ReductionKind::BitwiseOr: |
| return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR); |
| case ReductionKind::BitwiseXor: |
| return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR); |
| default: |
| llvm_unreachable("Unknown/unsupported reduction kind"); |
| } |
| } |
| |
| struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| SymbolTableCollection symbolTableCollection; |
| auto grid = adaptor.getGrid(); |
| mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection); |
| if (!gridOp) |
| return op->emitError() << "No grid found for AllReduceOp"; |
| if (ShapedType::isDynamicShape(gridOp.getShape())) |
| return op->emitError() |
| << "Dynamic grid shape not supported in AllReduceOp"; |
| |
| ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); |
| Value input = adaptor.getInput(); |
| auto inputShape = cast<ShapedType>(input.getType()).getShape(); |
| |
    // If the input is a tensor, convert it to a memref.
    if (isa<RankedTensorType>(input.getType())) {
| auto memrefType = MemRefType::get( |
| inputShape, cast<ShapedType>(input.getType()).getElementType()); |
| input = bufferization::ToBufferOp::create(iBuilder, memrefType, input); |
| } |
| MemRefType inType = cast<MemRefType>(input.getType()); |
| |
| // Get the actual shape to allocate the buffer. |
| SmallVector<OpFoldResult> shape(inType.getRank()); |
| for (auto i = 0; i < inType.getRank(); ++i) { |
| auto s = inputShape[i]; |
| if (ShapedType::isDynamic(s)) |
        shape[i] = memref::DimOp::create(iBuilder, input, i).getResult();
| else |
| shape[i] = iBuilder.getIndexAttr(s); |
| } |
| |
| // Allocate buffer and copy input to buffer. |
| Value buffer = memref::AllocOp::create( |
| iBuilder, shape, cast<ShapedType>(op.getType()).getElementType()); |
| linalg::CopyOp::create(iBuilder, input, buffer); |
| |
    // Create a new communicator via MPI_Comm_split for the AllReduce.
| // The color is the linear index of the process in the grid along the |
| // non-reduced axes. The key is the linear index of the process in the grid |
| // along the reduced axes. |
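    // For example (illustrative), when reducing over axis 1 of a [2, 3, 4]
    // grid, the process at [1, 2, 3] derives its color from [1, 0, 3] and its
    // key from [0, 2, 0]; processes that differ only along the reduced axes
    // thus end up in the same communicator.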
| SmallVector<Type> indexResultTypes(gridOp.getShape().size(), |
| iBuilder.getIndexType()); |
| SmallVector<Value> myMultiIndex = |
| ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid) |
| .getResult(); |
| Value zero = arith::ConstantIndexOp::create(iBuilder, 0); |
| SmallVector<Value> multiKey(myMultiIndex.size(), zero); |
| |
| auto redAxes = adaptor.getGridAxes(); |
| for (auto axis : redAxes) { |
| multiKey[axis] = myMultiIndex[axis]; |
| myMultiIndex[axis] = zero; |
| } |
| |
| Value color = |
| createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder); |
| color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); |
| Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder); |
| key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); |
| |
| // Finally split the communicator |
| auto commType = mpi::CommType::get(op->getContext()); |
| Value commWorld = mpi::CommWorldOp::create(iBuilder, commType); |
| auto comm = |
| mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key) |
| .getNewcomm(); |
| |
| Value buffer1d = buffer; |
| // Collapse shape to 1d if needed |
| if (inType.getRank() > 1) { |
| ReassociationIndices reassociation(inType.getRank()); |
| std::iota(reassociation.begin(), reassociation.end(), 0); |
| buffer1d = memref::CollapseShapeOp::create( |
| iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation)); |
| } |
| |
| // Create the MPI AllReduce operation. |
| mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d, |
| getMPIReductionOp(adaptor.getReductionAttr()), |
| comm); |
| |
    // If the result is a tensor, convert the buffer back to a tensor.
    if (isa<RankedTensorType>(op.getType()))
      buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
                                                 /*restrict=*/true);
| |
| rewriter.replaceOp(op, buffer); |
| return success(); |
| } |
| }; |
| |
| struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| // The input/output memref is assumed to be in C memory order. |
| // Halos are exchanged as 2 blocks per dimension (one for each side: down |
| // and up). For each haloed dimension `d`, the exchanged blocks are |
| // expressed as multi-dimensional subviews. The subviews include potential |
| // halos of higher dimensions `dh > d`, no halos for the lower dimensions |
| // `dl < d` and for dimension `d` the currently exchanged halo only. |
    // By iterating from higher to lower dimensions this also updates the
    // halos in the 'corners'.
| // memref.subview is used to read and write the halo data from and to the |
| // local data. Because subviews and halos can have mixed dynamic and static |
| // shapes, OpFoldResults are used whenever possible. |
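    // For example (illustrative), a 2d array of 4x4 owned elements with halo
    // size 1 on all sides lives in a 6x6 memref. Dim 1 is exchanged first via
    // subviews of shape [4, 1] (excluding dim 0's halos); dim 0 then uses
    // subviews of shape [1, 6] (including dim 1's halos), which also fills
    // the corners.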
| |
| auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(), |
| adaptor.getHaloSizes(), rewriter); |
| if (haloSizes.empty()) { |
| // no halos -> nothing to do |
| rewriter.replaceOp(op, adaptor.getDestination()); |
| return success(); |
| } |
| |
| SymbolTableCollection symbolTableCollection; |
| Location loc = op.getLoc(); |
| |
    // Convert an OpFoldResult into a Value.
| auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value { |
| if (auto value = dyn_cast<Value>(v)) |
| return value; |
| return arith::ConstantOp::create( |
| rewriter, loc, |
| rewriter.getIndexAttr( |
| cast<IntegerAttr>(cast<Attribute>(v)).getInt())); |
| }; |
| |
| auto dest = adaptor.getDestination(); |
| auto dstShape = cast<ShapedType>(dest.getType()).getShape(); |
| Value array = dest; |
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, convert it to a memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          bufferization::ToBufferOp::create(rewriter, loc, memrefType, array);
    }
| auto rank = cast<ShapedType>(array.getType()).getRank(); |
| auto opSplitAxes = adaptor.getSplitAxes().getAxes(); |
| auto grid = adaptor.getGrid(); |
| auto gridOp = getGrid(op, symbolTableCollection); |
    // Subviews need index values.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
                                        value)
                 .getResult();
    }
| |
| // most of the offset/size/stride data is the same for all dims |
| SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); |
| SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); |
| SmallVector<OpFoldResult> shape(rank), dimSizes(rank); |
| auto currHaloDim = -1; // halo sizes are provided for split dimensions only |
| // we need the actual shape to compute offsets and sizes |
| for (auto i = 0; i < rank; ++i) { |
| auto s = dstShape[i]; |
| if (ShapedType::isDynamic(s)) |
        shape[i] = memref::DimOp::create(rewriter, loc, array, i).getResult();
| else |
| shape[i] = rewriter.getIndexAttr(s); |
| |
| if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) { |
| ++currHaloDim; |
        // the offsets for lower dims start after their down halo
| offsets[i] = haloSizes[currHaloDim * 2]; |
| |
| // prepare shape and offsets of highest dim's halo exchange |
| Value _haloSz = arith::AddIOp::create( |
| rewriter, loc, toValue(haloSizes[currHaloDim * 2]), |
| toValue(haloSizes[currHaloDim * 2 + 1])); |
        // the halo shape of lower dims excludes the halos
| dimSizes[i] = |
| arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz) |
| .getResult(); |
| } else { |
| dimSizes[i] = shape[i]; |
| } |
| } |
| |
    auto tagAttr = rewriter.getI32IntegerAttr(91); // arbitrary tag value
| auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr); |
| auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 |
| auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); |
| |
| SmallVector<Type> indexResultTypes(gridOp.getShape().size(), |
| rewriter.getIndexType()); |
| auto myMultiIndex = |
| ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid) |
| .getResult(); |
| // traverse all split axes from high to low dim |
| for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { |
| auto splitAxes = opSplitAxes[dim]; |
| if (splitAxes.empty()) |
| continue; |
| assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); |
| // Get the linearized ids of the neighbors (down and up) for the |
| // given split |
      auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
                                                  myMultiIndex, splitAxes)
                     .getResults();
| // MPI operates on i32... |
| Value neighbourIDs[2] = { |
| arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), |
| tmp[0]), |
| arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), |
| tmp[1])}; |
| |
| auto lowerRecvOffset = rewriter.getIndexAttr(0); |
| auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]); |
| auto upperRecvOffset = |
| arith::SubIOp::create(rewriter, loc, toValue(shape[dim]), |
| toValue(haloSizes[currHaloDim * 2 + 1])); |
| auto upperSendOffset = arith::SubIOp::create( |
| rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); |
| |
| Value commWorld = mpi::CommWorldOp::create( |
| rewriter, loc, mpi::CommType::get(op->getContext())); |
| |
      // Make sure we send/recv in a way that does not lead to a deadlock.
      // The current approach is far from optimal: it should at least use a
      // red-black pattern or MPI_Sendrecv, and buffers should be reused.
      // For now, temporary contiguous buffers are allocated per exchange,
      // yielding a "serialized" communication pattern.
| auto genSendRecv = [&](bool upperHalo) { |
| auto orgOffset = offsets[dim]; |
| dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] |
| : haloSizes[currHaloDim * 2]; |
| // Check if we need to send and/or receive |
| // Processes on the grid borders have only one neighbor |
| auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; |
| auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; |
| auto hasFrom = arith::CmpIOp::create( |
| rewriter, loc, arith::CmpIPredicate::sge, from, zero); |
| auto hasTo = arith::CmpIOp::create(rewriter, loc, |
| arith::CmpIPredicate::sge, to, zero); |
| auto buffer = memref::AllocOp::create( |
| rewriter, loc, dimSizes, |
| cast<ShapedType>(array.getType()).getElementType()); |
| // if has neighbor: copy halo data from array to buffer and send |
| scf::IfOp::create( |
| rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) { |
| offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) |
| : OpFoldResult(upperSendOffset); |
| auto subview = memref::SubViewOp::create( |
| builder, loc, array, offsets, dimSizes, strides); |
| memref::CopyOp::create(builder, loc, subview, buffer); |
| mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to, |
| commWorld); |
| scf::YieldOp::create(builder, loc); |
| }); |
| // if has neighbor: receive halo data into buffer and copy to array |
| scf::IfOp::create( |
| rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) { |
| offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset) |
| : OpFoldResult(lowerRecvOffset); |
| mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from, |
| commWorld); |
| auto subview = memref::SubViewOp::create( |
| builder, loc, array, offsets, dimSizes, strides); |
| memref::CopyOp::create(builder, loc, buffer, subview); |
| scf::YieldOp::create(builder, loc); |
| }); |
| memref::DeallocOp::create(rewriter, loc, buffer); |
| offsets[dim] = orgOffset; |
| }; |
| |
      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = arith::ConstantOp::create(
              rewriter, loc,
              rewriter.getI32IntegerAttr(
                  cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        else
          // Dynamic halo sizes were cast to index above; cmpi requires
          // matching operand types, so cast back to i32 for the comparison.
          haloSz = arith::IndexCastOp::create(rewriter, loc,
                                              rewriter.getI32Type(), haloSz);
        auto hasSize = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
| scf::IfOp::create(rewriter, loc, hasSize, |
| [&](OpBuilder &builder, Location loc) { |
| genSendRecv(upOrDown > 0); |
| scf::YieldOp::create(builder, loc); |
| }); |
| }; |
| |
| doSendRecv(0); |
| doSendRecv(1); |
| |
      // the shape for lower dims includes higher dims' halos
| dimSizes[dim] = shape[dim]; |
| // -> the offset for higher dims is always 0 |
| offsets[dim] = rewriter.getIndexAttr(0); |
| // on to next halo |
| --currHaloDim; |
| } |
| |
| if (isa<MemRefType>(op.getResult().getType())) { |
| rewriter.replaceOp(op, array); |
| } else { |
| assert(isa<RankedTensorType>(op.getResult().getType())); |
| rewriter.replaceOp(op, bufferization::ToTensorOp::create( |
| rewriter, loc, op.getResult().getType(), array, |
| /*restrict=*/true, /*writable=*/true)); |
| } |
| return success(); |
| } |
| }; |
| |
| struct ConvertShardToMPIPass |
| : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> { |
| using Base::Base; |
| |
| /// Run the dialect converter on the module. |
| void runOnOperation() override { |
| auto *ctxt = &getContext(); |
| RewritePatternSet patterns(ctxt); |
| ConversionTarget target(getContext()); |
| |
| // Define a type converter to convert shard::ShardingType, |
| // mostly for use in return operations. |
| TypeConverter typeConverter; |
| typeConverter.addConversion([](Type type) { return type; }); |
| |
| // convert shard::ShardingType to a tuple of RankedTensorTypes |
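    // I.e. (illustrative): !shard.sharding becomes three tensors:
    //   tensor<?x?xi16>  split axes
    //   tensor<?x?xi64>  halo sizes (logically ?x2)
    //   tensor<?x?xi64>  sharded dims offsets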
| typeConverter.addConversion( |
| [](ShardingType type, |
| SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { |
| auto i16 = IntegerType::get(type.getContext(), 16); |
| auto i64 = IntegerType::get(type.getContext(), 64); |
| std::array<int64_t, 2> shp = {ShapedType::kDynamic, |
| ShapedType::kDynamic}; |
| results.emplace_back(RankedTensorType::get(shp, i16)); |
| results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2 |
| results.emplace_back(RankedTensorType::get(shp, i64)); |
| return success(); |
| }); |
| |
    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
| typeConverter.addTargetMaterialization( |
| [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, |
| Location loc) { |
| // Expecting a single input. |
| if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType())) |
| return SmallVector<Value>(); |
| auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>(); |
| // Expecting an UnrealizedConversionCastOp. |
| if (!castOp) |
| return SmallVector<Value>(); |
| // Fill a vector with elements of the tuple/castOp. |
| SmallVector<Value> results; |
| for (auto oprnd : castOp.getInputs()) { |
| if (!isa<RankedTensorType>(oprnd.getType())) |
| return SmallVector<Value>(); |
| results.emplace_back(oprnd); |
| } |
| return results; |
| }); |
| |
    // No shard dialect ops should be left after conversion...
| target.addIllegalDialect<shard::ShardDialect>(); |
    // ...except the global GridOp and GridShapeOp, which get folded later.
| target.addLegalOp<shard::GridOp, shard::GridShapeOp>(); |
    // Allow all dialects that our patterns may generate.
| target.addLegalDialect< |
| BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect, |
| tensor::TensorDialect, bufferization::BufferizationDialect, |
| linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>(); |
| // Make sure the function signature, calls etc. are legal |
| target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
| return typeConverter.isSignatureLegal(op.getFunctionType()); |
| }); |
| target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>( |
| [&](Operation *op) { return typeConverter.isLegal(op); }); |
| |
| patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp, |
| ConvertProcessMultiIndexOp, ConvertGetShardingOp, |
| ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp, |
| ConvertProcessLinearIndexOp>(typeConverter, ctxt); |
| |
| populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( |
| patterns, typeConverter); |
| populateCallOpTypeConversionPattern(patterns, typeConverter); |
| populateReturnOpTypeConversionPattern(patterns, typeConverter); |
| |
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
| |
| // Folding patterns cannot be mixed with conversion patterns -> extra pass. |
| patterns.clear(); |
| SymbolTableCollection symbolTableCollection; |
| mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection); |
| (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
| } |
| }; |
| |
| } // namespace |