blob: f34e1041a6709e0f3bc1b5e5e441c629467f910b [file]
//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
// where:
// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
namespace {
/// Expands `gpu.subgroup_id` into explicit index arithmetic on the thread ids,
/// block dimensions and subgroup size:
///
///   subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
  using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;

  /// Replaces `op` with the quotient of the linearized thread index and the
  /// subgroup size. Always succeeds.
  LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
                                PatternRewriter &rewriter) const override {
    // How the subgroup id is derived.
    //
    // 1. Linearize the 3-D thread id (tid.x, tid.y, tid.z) within the block,
    //    with x as the fastest-varying (innermost) dimension:
    //
    //      L = tid.x + dim.x * (tid.y + dim.y * tid.z)
    //
    // 2. Split the linear range [0, N_threads - 1] into consecutive,
    //    non-overlapping chunks of `subgroup_size` threads (N below):
    //
    //      | Subgroup 0     | Subgroup 1      | Subgroup 2       | ... |
    //      | L in 0..N-1    | L in N..2N-1    | L in 2N..3N-1    | ... |
    //
    // 3. The chunk a thread falls in is its subgroup id:
    //
    //      subgroup_id = floor(L / subgroup_size)
    Location loc = op->getLoc();
    Type idxTy = rewriter.getIndexType();

    // Returns the block size along `dim` as an index attribute when
    // getKnownDimensionSizeAround reports one for `op`, or a null attribute
    // otherwise.
    auto knownBlockSize = [&](gpu::Dimension dim) -> IntegerAttr {
      std::optional<uint32_t> size = gpu::getKnownDimensionSizeAround(
          op, gpu::DimensionKind::Block, dim);
      if (size.has_value())
        return IntegerAttr::get(
            idxTy, static_cast<int64_t>(static_cast<uint64_t>(*size)));
      return IntegerAttr();
    };
    IntegerAttr knownSizeX = knownBlockSize(gpu::Dimension::x);
    IntegerAttr knownSizeY = knownBlockSize(gpu::Dimension::y);
    IntegerAttr knownSizeZ = knownBlockSize(gpu::Dimension::z);

    // Materializes the block dimension along `dim`: a constant when the size
    // is known, a `gpu.block_dim` query otherwise.
    auto materializeBlockDim = [&](gpu::Dimension dim,
                                   IntegerAttr knownSize) -> Value {
      if (knownSize)
        return arith::ConstantOp::create(rewriter, loc, knownSize);
      return gpu::BlockDimOp::create(rewriter, loc, dim);
    };
    // Only dim.x and dim.y appear in the linearization; dim.z is used solely
    // as the extra attribute on the z thread id below.
    Value blockDimX = materializeBlockDim(gpu::Dimension::x, knownSizeX);
    Value blockDimY = materializeBlockDim(gpu::Dimension::y, knownSizeY);

    // The known sizes (possibly null) are forwarded to the thread-id builder,
    // presumably as each id's upper bound.
    Value threadX =
        gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x, knownSizeX);
    Value threadY =
        gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::y, knownSizeY);
    Value threadZ =
        gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::z, knownSizeZ);

    // Block dimensions don't exceed a signed int32_t maximum, and neither does
    // their product, on any realistic hardware, nor would any targets compile
    // with index < 32 bits, so we can assert no overflow.
    auto noWrap =
        arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw;

    // dim.y * tid.z
    Value zScaled = arith::MulIOp::create(rewriter, loc, idxTy, blockDimY,
                                          threadZ, noWrap);
    // tid.y + dim.y * tid.z
    Value yzLinear =
        arith::AddIOp::create(rewriter, loc, idxTy, zScaled, threadY, noWrap);
    // dim.x * (tid.y + dim.y * tid.z)
    Value yzScaled = arith::MulIOp::create(rewriter, loc, idxTy, blockDimX,
                                           yzLinear, noWrap);
    // L = tid.x + dim.x * (tid.y + dim.y * tid.z)
    Value linearThreadId =
        arith::AddIOp::create(rewriter, loc, idxTy, threadX, yzScaled, noWrap);

    Value subgroupSize = gpu::SubgroupSizeOp::create(
        rewriter, loc, idxTy, /*upper_bound = */ nullptr);
    Value subgroupId = arith::DivUIOp::create(rewriter, loc, idxTy,
                                              linearThreadId, subgroupSize);
    rewriter.replaceOp(op, {subgroupId});
    return success();
  }
};
} // namespace
/// Adds the pattern that expands `gpu.subgroup_id` into thread-id arithmetic
/// to `patterns`, constructing it with the context owned by `patterns`.
void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<GpuSubgroupIdRewriter>(context);
}