blob: 64be7fca5d40fb989b9cc3c67425d7d8cd22e1c4 [file] [edit]
//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "xegpu"
using namespace mlir;
using namespace mlir::xegpu;
template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
buf.clear();
llvm::raw_string_ostream os(buf);
os << "[";
for (size_t i = 1; i < array.size(); i++) {
os << array[i - 1] << ", ";
if (breakline)
os << "\n\t\t";
}
os << array.back() << "]";
return buf;
}
static SmallVector<int64_t> getShapeOf(Type type) {
SmallVector<int64_t> shape;
if (auto ty = llvm::dyn_cast<ShapedType>(type))
shape = SmallVector<int64_t>(ty.getShape());
else
shape.push_back(1);
return shape;
}
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
if (!attr)
return true;
auto kind = attr.getValue();
return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}
static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
if (!attr)
return true;
auto kind = attr.getValue();
return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
VectorType valueTy, int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
auto maskVecTy = dyn_cast<VectorType>(maskTy);
auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
if (!valueTy) {
if (chunkSize > 1)
return emitError() << "Expecting chunk size == 1 for scalar result";
if (maskVecTy || offsetsVecTy)
return emitError() << "Expecting scalar mask and offsets.";
else if (maskVecTy && offsetsVecTy)
return emitError() << "Expecting a vector type result.";
return success();
}
auto valueSize = valueTy.getNumElements();
// SIMT mode with scalar mask and offsets.
if (!maskVecTy && !offsetsVecTy) {
if (valueSize != chunkSize)
return emitError() << "value elements must match chunk size "
<< chunkSize;
return success();
}
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
if (!maskVecTy)
return emitError() << "Expecting a vector type mask.";
int64_t maskSize = maskVecTy.getNumElements();
if (chunkSize > 1) {
if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
return emitError() << "value elements must match chunk size "
<< chunkSize;
} else {
if (valueSize != maskSize)
return emitError()
<< "Mask should match value except the chunk size dim.";
}
llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
if (maskSize == 1)
return success();
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
return emitError() << "Mask should match value except the chunk size dim.";
return success();
}
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
function_ref<InFlightDiagnostic()> emitError) {
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
"are only allowed when result is a VectorType.";
else
return success();
}
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
ArrayAttr strideAttr = mdescTy.getStrideAttr();
SmallVector<int64_t> strides;
for (Attribute attr : strideAttr.getValue()) {
strides.push_back(cast<IntegerAttr>(attr).getInt());
}
if (subgroup_block_io && layout) {
auto laneData = layout.getEffectiveLaneDataAsInt();
auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
if (!laneData.empty()) {
bool isLaneDataContiguous =
std::all_of(laneData.begin(), std::prev(laneData.end()),
[](int x) { return x == 1; });
if (!isLaneDataContiguous)
return emitError() << "With subgroup_block_io, accessed data must be "
"contiguous and coalesced.";
for (size_t i = 0; i < laneData.size(); ++i) {
if (laneLayout[i] != blockShape[i])
return emitError() << "With subgroup_block_io, the block shape must "
"match the lane layout.";
if (laneLayout[i] != 1 && strides[i] != 1)
return emitError() << "With subgroup_block_io, the distributed "
"dimensions must be contiguous.";
}
}
}
if (layout && !layout.isDistributable(
SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
return emitError() << "Value shape is not distributable with the layout";
if (dataShape.size() == mdescShape.size()) {
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
}
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
if (subgroup_block_io && !blockShape.size())
return emitError() << "mem_desc must have block attribute when "
"subgroup_block_io is set.";
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<MemRefType> source) {
[[maybe_unused]] auto ty = source.getType();
assert(ty.hasStaticShape() && "expecting a memref with static shape");
build(builder, state, tdesc, source, ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
DenseI64ArrayAttr({}) /* empty const shape*/,
DenseI64ArrayAttr({}) /* empty const strides*/);
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, Value source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
Type srcTy = source.getType();
assert((isa<IntegerType, MemRefType>(srcTy)) &&
"Source has to be either int or memref.");
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;
llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
auto memrefShape = memrefTy.getShape();
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
// to keep the IR print clean (only do so for full-static case, otherwise
// printer would fail trying to print empty array-attr).
if (staticShape == memrefShape && staticStrides == memrefStrides &&
dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
}
build(builder, state, tdesc, source, dynamicShape, dynamicStrides,
staticShapeAttr, staticStridesAttr);
}
LogicalResult CreateNdDescOp::verify() {
size_t rank = getMixedSizes().size();
bool invalidRank = rank != getMixedStrides().size();
bool invalidElemTy = false;
// Memory space of created TensorDesc should match with the source.
// Both source and TensorDesc are considered for global memory by default,
// if the memory scope attr is not specified. If source is an integer,
// it is considered as ptr to global memory.
auto srcMemorySpace = getSourceMemorySpace();
auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
if (srcMemorySpace != tdescMemorySpace)
return emitOpError("Memory space mismatch.")
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
invalidElemTy |= memrefTy.getElementType() != getElementType();
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, and source (if source "
"is a memref) should match with each other.");
// check result TensorDesc rank
if (getType().getRank() > (int64_t)rank)
return emitOpError("Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides or the memref source.");
if (invalidElemTy)
return emitOpError("TensorDesc should have the same element "
"type with the source if it is a memref.\n");
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint,
xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
int64_t tDescRank = tdescTy.getRank();
int64_t offsetSize = getMixedOffsets().size();
if (offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
if (auto layout = getAnchorLayout()) {
if (!layout.isDistributable(getShapeOf(tdescTy)))
return emitOpError(
"TensorDesc shape is not distributable with the layout");
}
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
UnitAttr packed, DenseI64ArrayAttr transpose,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint,
xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
packed, transpose, l1_hint, l2_hint, l3_hint,
/*anchor_layout=*/layout);
}
LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
if (!valueTy)
return emitOpError("Invalid result, it should be a VectorType.\n");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
int valueElems = valueTy.getNumElements();
// If the result vector is 1D and has less elements than the tensor
// descriptor, it is supposed to be a SIMT op. The layout attribute in
// tensor_desc is not needed.
if (valueElems < tdescElems && valueTy.getRank() == 1) {
// SIMT mode doesn't need LayoutAttr.
if (tdescTy.getLayoutAttr())
return emitOpError()
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
// For SIMT code, the load is evenly distributed across all lanes in a
// subgroup. Since subgroup size is arch dependent, we only check even
// distribution here.
if (tdescElems % valueElems)
return emitOpError()
<< "Result shape " << makeString(getShapeOf(valueTy))
<< " is not a valid distribution for tensor descriptor "
<< tdescTy;
return success();
}
// Check SIMD mode.
auto tdescShape = getShapeOf(tdescTy);
auto valueShape = getShapeOf(valueTy);
if (getTranspose()) {
auto trans = getTranspose().value();
// Make sure the transpose value is valid, and apply it
if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
tdescShape = applyPermutation(tdescShape, trans);
else
mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
}
if (getPacked()) {
if (tdescTy.getRank() == 2) {
const int axis = 0;
auto vnni_factor = valueShape.back();
tdescShape[axis] /= vnni_factor;
tdescShape.push_back(vnni_factor);
} else {
mlir::emitWarning(getLoc())
<< "Invalid Packed Attr. It is ignored (available for 2D "
"TensorDesc only).";
}
}
// Handle array_length. Two result shape conventions are accepted:
// * 3D shape: leading array_length dimension prepended, e.g. descriptor
// 16x16 with array_length=2 -> [2, 16, 16].
// * Stacked 2D shape: array blocks stacked along the non-FCD (first)
// dimension, e.g. descriptor 16x16 with array_length=2 -> [32, 16].
auto array_len = tdescTy.getArrayLength();
SmallVector<int64_t> stacked2DShape(tdescShape);
SmallVector<int64_t> threeDShape(tdescShape);
if (array_len > 1 && !tdescShape.empty()) {
stacked2DShape[0] *= array_len;
threeDShape.insert(threeDShape.begin(), array_len);
}
if (valueShape != stacked2DShape && valueShape != threeDShape)
return emitOpError() << "Result shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< tdescTy;
int64_t tDescRank = tdescTy.getRank();
int64_t offsetSize = getMixedOffsets().size();
if (offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
if (auto layout = getAnchorLayout()) {
if (!layout.isDistributable(getShapeOf(tdescTy)))
return emitOpError(
"TensorDesc shape is not distributable with the layout");
}
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
Value tensorDesc, ArrayRef<OpFoldResult> offsets,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint,
xegpu::DistributeLayoutAttr layout) {
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
}
LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
if (!valTy)
return emitOpError("Expecting a VectorType result.\n");
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto array_len = dstTy.getArrayLength();
if (array_len > 1)
return emitOpError("array length is not supported by store_nd.\n");
auto tdescElems = dstTy.getNumElements();
auto valueElems = valTy.getNumElements();
// Similar to LoadNdOp, if the value vector is 1D and has less elements than
// the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
// in tensor_desc is not needed.
if (valTy.getRank() == 1 && valueElems < tdescElems) {
// SIMT mode doesn't need LayoutAttr.
if (dstTy.getLayoutAttr())
return emitOpError()
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
if (tdescElems % valueElems)
return emitOpError()
<< "Value shape " << makeString(getShapeOf(valTy))
<< " is not a valid distribution for tensor descriptor " << dstTy;
return success();
}
// SIMD code should have the same shape as the tensor descriptor.
auto tdescShape = getShapeOf(dstTy);
auto valueShape = getShapeOf(valTy);
if (tdescShape != valueShape)
return emitOpError() << "Value shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< dstTy;
int64_t tDescRank = dstTy.getRank();
int64_t offsetSize = getMixedOffsets().size();
if (offsetSize != tDescRank)
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
if (auto layout = getAnchorLayout()) {
if (!layout.isDistributable(tdescShape))
return emitOpError(
"TensorDesc shape is not distributable with the layout");
}
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto srcTy = getSourceType();
if (srcTy.isInteger() && !getOffsetAlignByteAttr())
return emitOpError("offset_align_byte is required with integer source.");
if (getOffsetAlignByteAttr() && !srcTy.isInteger())
return emitOpError("offset_align_byte only allowed with integer source.");
if (auto layout = getAnchorLayout()) {
// get the offset operand and its shape
auto offsetsTy = getOffsets().getType();
if (llvm::isa<VectorType>(offsetsTy) &&
!layout.isDistributable(getShapeOf(offsetsTy)))
return emitOpError("offset shape is not distributable with the layout");
}
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto srcTy = getSourceType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
if (auto layout = getAnchorLayout()) {
if (!layout.isDistributable(getShapeOf(valueTy)))
return emitOpError("Value shape is not distributable with the layout");
}
auto offsetsTy = getOffsets().getType();
return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
Type valueType, Value source,
ArrayRef<OpFoldResult> offsets, Value mask,
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
auto loc = source.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
Type valueType, Value source,
ArrayRef<OpFoldResult> offsets, Value mask,
IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint,
DistributeLayoutAttr layout) {
auto loc = source.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
l2_hint, l3_hint, layout);
}
//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto destTy = getDestType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
if (auto layout = getAnchorLayout()) {
if (!layout.isDistributable(getShapeOf(valueTy)))
return emitOpError("Value shape is not distributable with the layout");
}
auto offsetsTy = getOffsets().getType();
return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
Value value, Value dest,
ArrayRef<OpFoldResult> offsets, Value mask,
IntegerAttr chunk_size,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
auto loc = dest.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
l3_hint, /*anchor_layout=*/nullptr);
}
void StoreScatterOp::build(
OpBuilder &builder, OperationState &state, Value value, Value dest,
ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
auto loc = dest.getLoc();
int64_t size = static_cast<int64_t>(offsets.size());
auto type = VectorType::get(size, builder.getIndexType());
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
auto offset = vector::FromElementsOp::create(builder, loc, type, values);
// Call the correct builder overload that does not expect result types.
build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
l3_hint, layout);
}
//===----------------------------------------------------------------------===//
// DPAS Common Verification Helpers
//===----------------------------------------------------------------------===//
// Helper to verify layout distributability for a value
static LogicalResult
verifyLayoutDistributable(Operation *op,
std::optional<DistributeLayoutAttr> layout,
ArrayRef<int64_t> shape, StringRef operandName) {
if (layout && !layout->isDistributable(
SmallVector<int64_t>(shape.begin(), shape.end())))
return op->emitOpError(operandName)
<< " shape is not distributable with the layout";
return success();
}
// Helper to verify M, N, K dimensions match between A, B, and result matrices
static LogicalResult verifyDpasDimensions(Operation *op,
ArrayRef<int64_t> aShape,
ArrayRef<int64_t> bShape,
ArrayRef<int64_t> resShape) {
auto aRank = aShape.size();
auto bRank = bShape.size();
auto resRank = resShape.size();
if (aRank == 1 && bRank == 1 && resRank == 1)
return success();
// A must be at least 2D, B must be 2D or 3D (innermost dims), result at
// least 2D.
if (aRank < 2)
return op->emitOpError("A operand must be at least a 2D vector.");
if (bRank < 2)
return op->emitOpError("B operand must be at least a 2D vector.");
if (resRank < 2)
return op->emitOpError("Result must be at least a 2D vector.");
// FIXME: B may have one extra trailing dim for VNNI packing
// (B[batch..., K/vnni, N, vnni]). We plan to drop VNNI packing support, so
// rather than properly verifying the packed dimensions, we simply accept
// the packed form here and skip the detailed verification. This branch
// should be removed once VNNI packing support is dropped.
if (bRank == aRank + 1)
return success();
// All operands have the same rank. They share the same batch dimensions,
// with the last two dims being the core matmul dims: A[batch..., M, K],
// B[batch..., K, N], result[batch..., M, N].
if (aRank != bRank || aRank != resRank)
return op->emitOpError("Rank mismatch among A, B, and result.");
int64_t batchRank = aRank - 2;
// Verify batch dimensions match.
for (int64_t i = 0; i < batchRank; ++i) {
if (aShape[i] != resShape[i])
return op->emitOpError("Batch dimension mismatch at dim ")
<< i << ": A has " << aShape[i] << " but result has "
<< resShape[i] << ".";
if (aShape[i] != bShape[i])
return op->emitOpError("Batch dimension mismatch at dim ")
<< i << ": A has " << aShape[i] << " but B has " << bShape[i]
<< ".";
}
// Core matmul dimensions (last two dims of each operand).
int64_t aM = aShape[batchRank];
int64_t aK = aShape[batchRank + 1];
int64_t bK = bShape[batchRank];
int64_t bN = bShape[batchRank + 1];
int64_t resM = resShape[batchRank];
int64_t resN = resShape[batchRank + 1];
// Verify K dimension match between A and B
if (bK != aK)
return op->emitOpError("K-dimension mismatch: A has K=")
<< aK << " but B has K=" << bK << ".";
// Verify M dimension match between A and result
if (aM != resM)
return op->emitOpError("M-dimension mismatch: A has M=")
<< aM << " but result has M=" << resM << ".";
// Verify N dimension match between B and result
if (bN != resN)
return op->emitOpError("N-dimension mismatch: B has N=")
<< bN << " but result has N=" << resN << ".";
return success();
}
// Helper to verify accumulator matches result type
static LogicalResult verifyDpasAccumulator(Operation *op, Type accType,
Type resultType) {
if (accType != resultType)
return op->emitOpError("Accumulator type must match result type.");
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
auto lhsShape = getLhsType().getShape();
auto rhsShape = getRhsType().getShape();
auto resShape = getResultType().getShape();
// Verify layout distributability
if (failed(
verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
return failure();
if (failed(verifyLayoutDistributable(*this, getLayoutA(), lhsShape, "A")))
return failure();
if (failed(verifyLayoutDistributable(*this, getLayoutB(), rhsShape, "B")))
return failure();
// Verify accumulator if present
if (getAcc() &&
failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
return failure();
return verifyDpasDimensions(*this, lhsShape, rhsShape, resShape);
}
//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::verify() {
auto srcLayout = getInputLayout();
auto resLayout = getTargetLayout();
if (!srcLayout)
return emitOpError("expected input layout.");
if (!resLayout)
return emitOpError("expected target layout.");
// both input and target layouts should be WgLayout or SgLayout at the same
// time.
if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
(!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
return emitOpError("expected input layout and target layout be WgLayout or "
"SgLayout at the same time.");
Type srcType = getSource().getType();
if (llvm::isa<VectorType>(srcType)) {
SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
if (!srcLayout.isDistributable(shape))
return emitOpError(
"invalid input layout, data cannot be evenly distributed.");
if (!resLayout.isDistributable(shape))
return emitOpError(
"invalid target layout, data cannot be evenly distributed.");
}
return mlir::success();
}
//===----------------------------------------------------------------------===//
// XeGPU_LoadMatrixOp
//===----------------------------------------------------------------------===//
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
TypedValue<MemDescType> memDesc,
llvm::ArrayRef<OpFoldResult> offsets,
DistributeLayoutAttr layout) {
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
// Call the generated builder with all parameters (including optional ones as
// nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
/*subgroup_block_io=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
auto resTy = dyn_cast<VectorType>(getRes().getType());
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
getLayoutAttr(), [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
// XeGPU_StoreMatrixOp
//===----------------------------------------------------------------------===//
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
TypedValue<MemDescType> memDesc,
llvm::ArrayRef<OpFoldResult> offsets,
DistributeLayoutAttr layout) {
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
/*subgroup_block_io=*/nullptr, layout);
}
LogicalResult StoreMatrixOp::verify() {
auto dataTy = dyn_cast<VectorType>(getData().getType());
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
getLayoutAttr(), [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
// XeGPU_TruncfOp
//===----------------------------------------------------------------------===//
LogicalResult TruncfOp::verify() {
auto sourceVecType = dyn_cast<VectorType>(getSource().getType());
auto resultVecType = dyn_cast<VectorType>(getResult().getType());
if (sourceVecType.getElementTypeBitWidth() <=
resultVecType.getElementTypeBitWidth())
return emitOpError("input type must be wider than result type.");
return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_DpasMxOp
//===----------------------------------------------------------------------===//
LogicalResult DpasMxOp::verify() {
auto aShape = getAType().getShape();
auto bShape = getBType().getShape();
auto resShape = getResultType().getShape();
// Verify layout distributability for A, B, and result
if (failed(
verifyLayoutDistributable(*this, getLayoutCd(), resShape, "Result")))
return failure();
if (failed(verifyLayoutDistributable(*this, getLayoutA(), aShape, "A")))
return failure();
if (failed(verifyLayoutDistributable(*this, getLayoutB(), bShape, "B")))
return failure();
// Verify accumulator if present
if (getAcc() &&
failed(verifyDpasAccumulator(*this, getAcc().getType(), getResultType())))
return failure();
// Verify M, N, K dimensions
if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape)))
return failure();
// Determine batch rank from A operand.
int64_t aBatchRank = aShape.size() - 2;
// Validate scale_a if present
if (getScaleA()) {
auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
// Only validate if scale is a vector (scalars are always valid)
if (scaleAVecType && scaleAVecType.getRank() > 1) {
auto scaleAShape = scaleAVecType.getShape();
if (scaleAVecType.getRank() < 2)
return emitOpError("Scale A must be at least a 2D vector when not a "
"scalar.");
// Verify layout distributability for scale_a
if (failed(verifyLayoutDistributable(*this, getLayoutAScale(),
scaleAShape, "ScaleA")))
return failure();
// Validate M dimension: scale_a's M must match A's M (last-1 dim)
if (scaleAShape[scaleAShape.size() - 2] != aShape[aBatchRank])
return emitOpError("Scale A M dimension [")
<< scaleAShape[scaleAShape.size() - 2]
<< "] must match A M dimension [" << aShape[aBatchRank] << "].";
}
}
// Validate scale_b if present
if (getScaleB()) {
auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
// Only validate if scale is a vector (scalars are always valid)
if (scaleBVecType && scaleBVecType.getRank() > 1) {
auto scaleBShape = scaleBVecType.getShape();
if (scaleBVecType.getRank() < 2)
return emitOpError("Scale B must be at least a 2D vector when not a "
"scalar.");
// Verify layout distributability for scale_b
if (failed(verifyLayoutDistributable(*this, getLayoutBScale(),
scaleBShape, "ScaleB")))
return failure();
// Validate N dimension: scale_b's N (last dim) must match B's N (last
// dim)
if (scaleBShape.back() != bShape.back())
return emitOpError("Scale B N dimension [")
<< scaleBShape.back() << "] must match B N dimension ["
<< bShape.back() << "].";
}
}
// Validate scale K dimension compatibility if both scales are present and
// vectors
if (getScaleA() && getScaleB()) {
auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
auto scaleBVecType = dyn_cast<VectorType>(getScaleBType());
if (scaleAVecType && scaleBVecType && scaleAVecType.getRank() > 1 &&
scaleBVecType.getRank() > 1) {
auto scaleAShape = scaleAVecType.getShape();
auto scaleBShape = scaleBVecType.getShape();
// Validate scale K dimension compatibility: scale_a's last dim must
// match scale_b's second-to-last dim
if (scaleAShape.back() != scaleBShape[scaleBShape.size() - 2])
return emitOpError("Scale K dimension mismatch: scale_a has K=")
<< scaleAShape.back()
<< " but scale_b has K=" << scaleBShape[scaleBShape.size() - 2]
<< ".";
}
}
return success();
}
namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>