//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
namespace xegpu {
void XeGPUDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
>();
addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
>();
}
//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
xegpu::MemorySpace memory_space,
int array_length,
bool boundary_check) {
auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
auto lengthAttr =
IntegerAttr::get(IntegerType::get(context, 64), array_length);
auto boundaryAttr = BoolAttr::get(context, boundary_check);
return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}
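
// Example (illustrative sketch, not taken from this file): building the
// attribute for global memory with a single block and boundary checks on.
// xegpu::MemorySpace::Global is assumed here to be the global-memory
// enumerator.
//
//   auto attr = BlockTensorDescAttr::get(ctx, xegpu::MemorySpace::Global,
//                                        /*array_length=*/1,
//                                        /*boundary_check=*/true);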
//===----------------------------------------------------------------------===//
// XeGPU_ScatterTensorDescAttr
//===----------------------------------------------------------------------===//
ScatterTensorDescAttr
ScatterTensorDescAttr::get(mlir::MLIRContext *context,
xegpu::MemorySpace memory_space, int chunk_size) {
auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
auto chunkSizeAttr =
IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
return Base::get(context, scopeAttr, chunkSizeAttr);
}
LogicalResult ScatterTensorDescAttr::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
int64_t chunkSize = chunk_size.getInt();
SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
16, 32, 64, 128, 256};
if (!llvm::is_contained(supportedChunkSizes, chunkSize))
return emitError() << "invalid chunk size";
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_SGMapAttr
//===----------------------------------------------------------------------===//
namespace {
template <typename T, unsigned N>
LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
llvm::SmallVector<T, N> &result,
llvm::StringRef fieldName) {
if (failed(parser.parseKeyword(fieldName))) {
parser.emitError(parser.getCurrentLocation(),
"unexpected field name. Expected " + fieldName + ".");
return failure();
}
if (failed(parser.parseEqual())) {
parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
return failure();
}
auto elemParser = [&]() -> llvm::ParseResult {
uint32_t elem = 0;
auto res = parser.parseInteger(elem);
result.push_back(elem);
return res;
};
return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
elemParser, fieldName);
}
} // namespace
mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
::mlir::Type attrType) {
if (failed(parser.parseLess()))
return {};
llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
return {};
if (failed(parser.parseComma()))
return {};
if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
return {};
return SGMapAttr::getChecked(
[&]() { return parser.emitError(parser.getNameLoc()); },
parser.getContext(), wi_layout, wi_data);
}
void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
printer << "<";
printer.printKeywordOrString("wi_layout");
printer << " = [" << getWiLayout() << "], ";
printer.printKeywordOrString("wi_data");
printer << " = [" << getWiData() << "]";
printer << ">";
}
LogicalResult
SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::ArrayRef<uint32_t> wi_layout,
llvm::ArrayRef<uint32_t> wi_data) {
if (wi_layout.size() != 2)
return emitError() << "expected wi_layout of size 2";
if (wi_data.size() != 2)
return emitError() << "expected wi_data of size 2";
return success();
}
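
// For reference, a valid attribute in the custom syntax accepted by parse()
// and produced by print() above looks like the following (the #xegpu.sg_map
// mnemonic is assumed from the attribute definition):
//
//   #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>
//
// Both arrays must have exactly two elements to pass verify().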
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
mlir::FailureOr<mlir::Attribute> sg_map;
// Parse literal '<'
if (parser.parseLess())
return {};
auto shapeLoc = parser.getCurrentLocation();
if (mlir::failed(parser.parseDimensionList(shape))) {
parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
return {};
}
auto elemTypeLoc = parser.getCurrentLocation();
if (mlir::failed(parser.parseType(elementType))) {
parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
return {};
}
// parse optional attributes
while (mlir::succeeded(parser.parseOptionalComma())) {
mlir::Attribute attr;
ParseResult res = parser.parseAttribute(attr);
if (mlir::succeeded(res)) {
if (mlir::isa<SGMapAttr>(attr)) {
sg_map = attr;
continue;
}
if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
encoding = attr;
continue;
}
}
return {};
}
// Parse literal '>'
if (parser.parseGreater())
return {};
return TensorDescType::getChecked(
[&]() { return parser.emitError(parser.getNameLoc()); },
parser.getContext(), shape, elementType,
encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
}
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
printer << "<";
auto shape = getShape();
for (int64_t dim : shape) {
if (mlir::ShapedType::isDynamic(dim))
printer << '?';
else
printer << dim;
printer << 'x';
}
printer << getElementType();
if (auto encoding = getEncoding())
printer << ", " << encoding;
if (auto sg_map = getSgMap())
printer << ", " << sg_map;
printer << ">";
}
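
// For reference, tensor descriptor types in the custom syntax handled by
// parse()/print() above look like the following (the !xegpu.tensor_desc
// mnemonic is assumed from the type definition):
//
//   !xegpu.tensor_desc<8x16xf16>
//   !xegpu.tensor_desc<24x32xf32,
//                      #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>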
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
mlir::Type elementType, int array_length,
bool boundary_check,
MemorySpace memory_space,
mlir::Attribute sg_map) {
auto context = elementType.getContext();
auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
boundary_check);
return Base::get(context, shape, elementType, attr, sg_map);
}
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
mlir::Type elementType, int chunk_size,
MemorySpace memory_space,
mlir::Attribute sg_map) {
auto context = elementType.getContext();
auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
return Base::get(context, shape, elementType, attr, sg_map);
}
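
// Sketch of how the two builders above might be used (example values only;
// `builder` is assumed to be an mlir::OpBuilder):
//
//   // 8x16 block of f16 in global memory, one block, boundary-checked.
//   auto blockTy = TensorDescType::get({8, 16}, builder.getF16Type(),
//                                      /*array_length=*/1,
//                                      /*boundary_check=*/true,
//                                      xegpu::MemorySpace::Global,
//                                      /*sg_map=*/mlir::Attribute());
//   // 16 scattered rows of 8 contiguous f32 elements each (chunk_size = 8).
//   auto scatterTy = TensorDescType::get({16, 8}, builder.getF32Type(),
//                                        /*chunk_size=*/8,
//                                        xegpu::MemorySpace::Global,
//                                        /*sg_map=*/mlir::Attribute());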
LogicalResult TensorDescType::verify(
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
mlir::Attribute encoding, mlir::Attribute sg_map) {
size_t rank = shape.size();
  // Low-precision types are packed in 32-bit units.
unsigned packingFactor = 32 / elementType.getIntOrFloatBitWidth();
if (rank != 1 && rank != 2)
return emitError() << "expected 1D or 2D tensor";
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
// Expected tensor ranks for scattered data:
// - 1D tensor for fully non-contiguous elements (chunk size == 1)
// - 2D tensor for scattered blocks (chunk size > 1)
unsigned chunkSize = scatterAttr.getChunkSize().getInt();
if (rank == 1 && chunkSize != 1)
return emitError() << "expected non-contiguous elements for 1D tensor";
if (rank == 2 && chunkSize < 2)
return emitError() << "expected chunk blocks for 2D tensor";
    // If the chunk size is greater than 1, the second dimension of the tensor
    // shape must equal the chunk size and be a multiple of the packing factor.
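    // For example, with f16 elements the packing factor is 32 / 16 = 2, so a
    // 2D scattered descriptor with chunk_size = 8 needs shape[1] == 8, and
    // 8 % 2 == 0 holds.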
if (chunkSize > 1) {
if (shape.back() != chunkSize)
return emitError() << "expected tensor shape[1] to match chunk size";
if (shape.back() % packingFactor != 0)
return emitError()
<< "expected tensor shape[1] to be a multiple of packing factor "
<< packingFactor;
}
}
if (auto blockAttr =
mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
if (rank == 2 && memorySpaceAttr &&
memorySpaceAttr.getValue() == MemorySpace::SLM)
return emitError() << "SLM is not supported for 2D block tensor";
}
if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
if (rank == 1) {
if (wiLayout[0] != 1 || wiData[0] != 1)
return emitError()
<< "outer layout distribution and data mapping must be 1 "
"for 1D tensor";
}
if (scatterAttr) {
    // Validate subgroup mapping rules for scattered tensors.
    // A work item's slice of a tensor with shape [sg_size] or
    // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width],
    // respectively, and the mapping must reflect that, because each work
    // item accesses data at 32-bit granularity.
if (wiData[0] != 1)
return emitError()
<< "cannot map over non-contiguous scattered row elements";
if (wiData[1] != packingFactor)
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
}
    // For a 1D tensor, pad the shape with an outer unit dimension so the
    // common validation logic below applies.
SmallVector<int64_t> tensorShape(shape);
if (rank == 1)
tensorShape = {1, tensorShape.back()};
size_t dims = tensorShape.size();
for (size_t i = 0; i < dims; ++i) {
uint32_t numElemPerWi = wiLayout[i] * wiData[i];
if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
return emitError() << "cannot distribute " << tensorShape[i] << " over "
<< wiLayout[i] << " work items with " << wiData[i]
<< " elements each";
}
}
return success();
}
// If a tensor descriptor has an sg_map attribute, it is used in SIMT mode.
// In this mode, the distributed vector shape is determined as follows:
// Definitions:
// wi_data_size = wi_data[0] × wi_data[1]
// subgroup_size = wi_layout[0] × wi_layout[1]
// distribution_unit_size = subgroup_size × wi_data_size
// ---------------------------------------------------------------------
// Case 1: Regular loads/stores.
// ---------------------------------------------------------------------
// Distributed vector shape must be:
// [chunk_size / wi_data_size, wi_data_size]
// If the tensor descriptor shape is 1D, the first dimension is ignored (set to 1):
// [wi_data_size]
// ---------------------------------------------------------------------
// Case 2: Block loads/stores
// ---------------------------------------------------------------------
// Additional definitions:
// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
// n_distribution_units = tensor_size / distribution_unit_size
// Given above definitions, the following conditions must be met:
// * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
// * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
// Distributed vector shape must be:
// [n_distribution_units, wi_data_size]
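//
// Worked example (illustrative numbers): an 8x16xf16 block descriptor with
// wi_layout = [1, 16] and wi_data = [1, 1] gives
//   wi_data_size           = 1 × 1 = 1
//   subgroup_size          = 1 × 16 = 16
//   distribution_unit_size = 16 × 1 = 16
//   tensor_size            = 8 × 16 × 1 (array_length) = 128
//   n_distribution_units   = 128 / 16 = 8
// so the distributed vector type is vector<8x1xf16>. For Case 1, a scattered
// 16x8xf32 descriptor (chunk_size = 8) with wi_layout = [16, 1] and
// wi_data = [1, 1] distributes to vector<8x1xf32>.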
FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
  // If no sg_map is provided, the tensor descriptor is not used in SIMT mode.
if (!sgMap)
return failure();
SmallVector<int64_t> wiData(sgMap.getWiData());
SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
auto tdescShape = getShape();
auto wiDataSize = 1, sgSize = 1;
for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
wiDataSize *= wiDataDim;
sgSize *= wiDim;
}
// Case 1: regular loads/stores
auto scatterAttr = getEncodingAsScatterTensorDescAttr();
if (scatterAttr) {
auto chunkSize = scatterAttr.getChunkSize().getInt();
    // Verify that the first dimension of the tensor descriptor shape is
    // distributable.
assert(tdescShape[0] % (wiLayout[0]) == 0 &&
"tensor descriptor shape is not distributable");
if (chunkSize > 1)
return VectorType::get({chunkSize / wiDataSize, wiDataSize},
getElementType());
return VectorType::get({wiDataSize}, getElementType());
}
// Case 2: block loads/stores
// Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
// and wiLayout must be 1.
if (tdescShape.size() == 1) {
assert((wiData[0] == 1 && wiLayout[0] == 1) &&
"wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor");
wiData = {wiData[1]};
wiLayout = {wiLayout[1]};
}
// Check if the tensor descriptor shape is distributable.
int64_t tensorSize = 1;
for (auto [tdescDim, wiDim, wiDataDim] :
llvm::zip_equal(tdescShape, wiLayout, wiData)) {
assert((tdescDim % (wiDim * wiDataDim) == 0) &&
"tensor descriptor shape is not distributable");
tensorSize *= tdescDim;
}
// tensorSize must be adjusted for array_length.
tensorSize *= getArrayLength();
return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
getElementType());
}
} // namespace xegpu
} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>