//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>;
GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
PatternRewriter &rewriter_)
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
indexType(IndexType::get(reduceOp.getContext())),
int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
/// Creates an all_reduce across the workgroup.
///
/// First reduce the elements within a subgroup. The first invocation of each
/// subgroup writes the intermediate result to workgroup memory. After
/// synchronizing the workgroup, the first subgroup reduces the values from
/// workgroup memory. The result is broadcast to all invocations through
/// workgroup memory.
///
/// %subgroup_reduce = `createSubgroupReduce(%operand)`
/// cond_br %is_first_lane, ^then1, ^continue1
/// ^then1:
/// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
/// br ^continue1
/// ^continue1:
/// gpu.barrier
/// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
/// cond_br %is_valid_subgroup, ^then2, ^continue2
/// ^then2:
/// %partial_reduce = load %workgroup_buffer[%invocation_idx]
/// %all_reduce = `createSubgroupReduce(%partial_reduce)`
/// store %all_reduce, %workgroup_buffer[%zero]
///     br ^continue2
/// ^continue2:
/// gpu.barrier
/// %result = load %workgroup_buffer[%zero]
/// return %result
///
void rewrite() {
rewriter.setInsertionPoint(reduceOp);
// Compute linear invocation index and workgroup size.
Value dimX = getDimOp<gpu::BlockDimOp>("x");
Value dimY = getDimOp<gpu::BlockDimOp>("y");
Value dimZ = getDimOp<gpu::BlockDimOp>("z");
Value tidX = getDimOp<gpu::ThreadIdOp>("x");
Value tidY = getDimOp<gpu::ThreadIdOp>("y");
Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
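    // In other words, using the usual row-major linearization:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ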
    // Compute lane id (invocation id within the subgroup).
Value subgroupMask =
create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
Value isFirstLane =
create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
create<arith::ConstantIntOp>(0, int32Type));
Value numThreadsWithSmallerSubgroupId =
create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations, counting from the start of the
    // current subgroup. Consumers do not require this value to be clamped to
    // the subgroup size.
Value activeWidth =
create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
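    // For example, invocationIdx = 70 in a workgroup of size 100 gives
    // laneId = 70 & 31 = 6, so 64 invocations precede this subgroup and
    // activeWidth = 100 - 64 = 36.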
    // Create a factory for the op which accumulates two values.
AccumulatorFactory accumFactory = getFactory();
assert(accumFactory && "failed to create accumulator factory");
// Reduce elements within each subgroup to produce the intermediate results.
Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
reduceOp.value(), accumFactory);
    // Add a workgroup buffer to the parent function for the intermediate
    // results.
Value buffer = createWorkgroupBuffer();
// Write the intermediate results to workgroup memory, using the first lane
// of each subgroup.
createPredicatedBlock(isFirstLane, [&] {
Value subgroupId = getDivideBySubgroupSize(invocationIdx);
Value index = create<arith::IndexCastOp>(indexType, subgroupId);
create<memref::StoreOp>(subgroupReduce, buffer, index);
});
create<gpu::BarrierOp>();
// Compute number of active subgroups.
Value biasedBlockSize =
create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
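    // This computes numSubgroups = ceil(workgroupSize / kSubgroupSize); e.g.
    // a workgroup of 100 invocations yields (100 + 31) / 32 = 4 subgroups.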
Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
invocationIdx, numSubgroups);
// Use the first numSubgroups invocations to reduce the intermediate results
// from workgroup memory. The final result is written to workgroup memory
// again.
Value zero = create<arith::ConstantIndexOp>(0);
createPredicatedBlock(isValidSubgroup, [&] {
Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
Value value = create<memref::LoadOp>(valueType, buffer, index);
Value result =
createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
create<memref::StoreOp>(result, buffer, zero);
});
// Synchronize workgroup and load result from workgroup memory.
create<gpu::BarrierOp>();
Value result = create<memref::LoadOp>(valueType, buffer, zero);
rewriter.replaceOp(reduceOp, result);
}
private:
// Shortcut to create an op from rewriter using loc as the first argument.
template <typename T, typename... Args>
T create(Args... args) {
return rewriter.create<T>(loc, std::forward<Args>(args)...);
}
  // Creates a dimension op of type T, with the result cast to int32.
template <typename T>
Value getDimOp(StringRef dimension) {
Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
return create<arith::IndexCastOp>(int32Type, dim);
}
  /// Adds a workgroup memory attribution of the buffer type to funcOp and
  /// returns it.
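  /// The buffer provides one slot per subgroup, so its kSubgroupSize elements
  /// suffice as long as the workgroup contains at most kSubgroupSize subgroups
  /// (i.e. at most kSubgroupSize * kSubgroupSize invocations).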
Value createWorkgroupBuffer() {
int workgroupMemoryAddressSpace =
gpu::GPUDialect::getWorkgroupAddressSpace();
auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
workgroupMemoryAddressSpace);
return funcOp.addWorkgroupAttribution(bufferType);
}
/// Returns an accumulator factory using either the op attribute or the body
/// region.
AccumulatorFactory getFactory() {
auto &body = reduceOp.body();
if (!body.empty())
return getFactory(body);
auto opAttr = reduceOp.op();
if (opAttr)
return getFactory(*opAttr);
return AccumulatorFactory();
}
  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
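  ///
  /// For example, a schematic region-based reduction such as
  ///
  ///     %sum = gpu.all_reduce %operand {
  ///     ^bb(%lhs : f32, %rhs : f32):
  ///       %res = arith.addf %lhs, %rhs : f32
  ///       "gpu.yield"(%res) : (f32) -> ()
  ///     } : (f32) -> (f32)
  ///
  /// has its body cloned between the split blocks, with each gpu.yield
  /// rewritten into a branch that carries the accumulated value.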
AccumulatorFactory getFactory(Region &body) {
return AccumulatorFactory([&](Value lhs, Value rhs) {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
      // Insert the accumulator body between the split blocks.
BlockAndValueMapping mapping;
mapping.map(body.getArgument(0), lhs);
mapping.map(body.getArgument(1), rhs);
rewriter.cloneRegionBefore(body, *split->getParent(),
split->getIterator(), mapping);
// Add branch before inserted body, into body.
block = block->getNextNode();
create<BranchOp>(block, ValueRange());
// Replace all gpu.yield ops with branch out of body.
for (; block != split; block = block->getNextNode()) {
Operation *terminator = block->getTerminator();
if (!isa<gpu::YieldOp>(terminator))
continue;
rewriter.setInsertionPointToEnd(block);
rewriter.replaceOpWithNewOp<BranchOp>(
terminator, split, ValueRange(terminator->getOperand(0)));
}
// Return accumulator result.
rewriter.setInsertionPointToStart(split);
return split->addArgument(lhs.getType());
});
}
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(StringRef opName) {
bool isFloatingPoint = valueType.isa<FloatType>();
if (opName == "add")
return isFloatingPoint ? getFactory<arith::AddFOp>()
: getFactory<arith::AddIOp>();
if (opName == "mul")
return isFloatingPoint ? getFactory<arith::MulFOp>()
: getFactory<arith::MulIOp>();
if (opName == "and") {
return getFactory<arith::AndIOp>();
}
if (opName == "or") {
return getFactory<arith::OrIOp>();
}
if (opName == "xor") {
return getFactory<arith::XOrIOp>();
}
if (opName == "max") {
return isFloatingPoint
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
arith::CmpFPredicate::UGT>()
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
arith::CmpIPredicate::ugt>();
}
if (opName == "min") {
return isFloatingPoint
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
arith::CmpFPredicate::ULT>()
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
arith::CmpIPredicate::ult>();
}
return AccumulatorFactory();
}
/// Returns an accumulator factory that creates an op of type T.
template <typename T>
AccumulatorFactory getFactory() {
return [&](Value lhs, Value rhs) {
return create<T>(lhs.getType(), lhs, rhs);
};
}
  /// Returns an accumulator factory for comparison ops such as min and max.
  /// T is the type of the compare op.
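  /// For example, a floating-point "max" is emitted as (schematically):
  ///
  ///     %cmp = arith.cmpf "ugt", %lhs, %rhs
  ///     %result = select %cmp, %lhs, %rhs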
template <typename T, typename PredicateEnum, PredicateEnum predicate>
AccumulatorFactory getCmpFactory() const {
return [&](Value lhs, Value rhs) {
Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
};
}
/// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
/// ^continue(%block_operands):
///
template <typename ThenOpsFactory, typename ElseOpsFactory>
void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
ElseOpsFactory &&elseOpsFactory) {
Block *currentBlock = rewriter.getInsertionBlock();
auto currentPoint = rewriter.getInsertionPoint();
Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
rewriter.setInsertionPointToEnd(currentBlock);
create<CondBranchOp>(condition, thenBlock,
/*trueOperands=*/ArrayRef<Value>(), elseBlock,
/*falseOperands=*/ArrayRef<Value>());
rewriter.setInsertionPointToStart(thenBlock);
auto thenOperands = thenOpsFactory();
create<BranchOp>(continueBlock, thenOperands);
rewriter.setInsertionPointToStart(elseBlock);
auto elseOperands = elseOpsFactory();
create<BranchOp>(continueBlock, elseOperands);
assert(thenOperands.size() == elseOperands.size());
rewriter.setInsertionPointToStart(continueBlock);
for (auto operand : thenOperands)
continueBlock->addArgument(operand.getType());
}
/// Shortcut for createIf with empty else block and no block operands.
template <typename Factory>
void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
"predicatedOpsFactory should not return any value");
createIf(
condition,
[&] {
predicatedOpsFactory();
return ArrayRef<Value>();
},
[&] { return ArrayRef<Value>(); });
}
/// Creates a reduction across the first activeWidth lanes of a subgroup, or
/// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
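  ///
  /// For a full subgroup of width 32, the generated sequence is, schematically:
  ///
  ///     %v1 = `accumFactory`(%operand, shuffle(%operand, xor 1))
  ///     %v2 = `accumFactory`(%v1, shuffle(%v1, xor 2))
  ///     %v4 = `accumFactory`(%v2, shuffle(%v2, xor 4))
  ///     %v8 = `accumFactory`(%v4, shuffle(%v4, xor 8))
  ///     %result = `accumFactory`(%v8, shuffle(%v8, xor 16))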
Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
AccumulatorFactory &accumFactory) {
Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
activeWidth, subgroupSize);
std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
auto xorAttr = rewriter.getStringAttr("xor");
createIf(
isPartialSubgroup,
// Generate reduction over a (potentially) partial subgroup.
[&] {
Value value = operand;
// Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
// lane is within the active range. The accumulated value is available
// in the first lane.
for (int i = 1; i < kSubgroupSize; i <<= 1) {
Value offset = create<arith::ConstantIntOp>(i, int32Type);
auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
activeWidth, xorAttr);
// Skip the accumulation if the shuffle op read from a lane outside
// of the active range.
createIf(
shuffleOp.getResult(1),
[&] {
return SmallVector<Value, 1>{
accumFactory(value, shuffleOp.getResult(0))};
},
[&] { return llvm::makeArrayRef(value); });
value = rewriter.getInsertionBlock()->getArgument(0);
}
return SmallVector<Value, 1>{value};
},
// Generate a reduction over the entire subgroup. This is a
// specialization of the above reduction with unconditional
// accumulation.
[&] {
Value value = operand;
for (int i = 1; i < kSubgroupSize; i <<= 1) {
Value offset = create<arith::ConstantIntOp>(i, int32Type);
auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
subgroupSize, xorAttr);
value = accumFactory(value, shuffleOp.getResult(0));
}
return SmallVector<Value, 1>{value};
});
return rewriter.getInsertionBlock()->getArgument(0);
}
/// Returns value divided by the subgroup size (i.e. 32).
Value getDivideBySubgroupSize(Value value) {
Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
return create<arith::DivSIOp>(int32Type, value, subgroupSize);
}
gpu::GPUFuncOp funcOp;
gpu::AllReduceOp reduceOp;
PatternRewriter &rewriter;
Location loc;
Type valueType;
Type indexType;
IntegerType int32Type;
static constexpr int kSubgroupSize = 32;
};
struct GpuAllReduceConversion : public RewritePattern {
explicit GpuAllReduceConversion(MLIRContext *context)
: RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto funcOp = cast<gpu::GPUFuncOp>(op);
auto callback = [&](gpu::AllReduceOp reduceOp) {
GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
// Performing a rewrite invalidates the walk iterator. Report interrupt
// so that we can start a new walk until all all_reduce ops are replaced.
return WalkResult::interrupt();
};
while (funcOp.walk(callback).wasInterrupted()) {
}
return success();
}
};
} // namespace
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
patterns.add<GpuAllReduceConversion>(patterns.getContext());
}
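
// A pass can drive this lowering with the greedy pattern rewriter; a minimal
// sketch (assuming the standard applyPatternsAndFoldGreedily driver):
//
//   RewritePatternSet patterns(op->getContext());
//   populateGpuAllReducePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));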