| //===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/LLVMIR/XeVMDialect.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Interfaces/ViewLikeInterface.h" |
| |
| #include "llvm/Support/Debug.h" |
| |
| #define DEBUG_TYPE "xegpu" |
| |
| using namespace mlir; |
| using namespace mlir::xegpu; |
| |
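// Returns true if the given memref is allocated in shared local memory (SLM),
// i.e., its memory space is the numeric GPU shared address space (3), the
// XeGPU SLM space, the XeVM shared address space, or the GPU dialect
// workgroup address space.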
| static bool isSharedMemory(const MemRefType &memrefTy) { |
| Attribute attr = memrefTy.getMemorySpace(); |
| if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) |
| return intAttr.getInt() == 3; |
| if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr)) |
| return memrefSpace.getValue() == MemorySpace::SLM; |
| if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr)) |
| return xevmSpace.getValue() == xevm::AddrSpace::SHARED; |
| return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr); |
| } |
| |
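// Formats an array as a bracketed, comma-separated list for diagnostics,
// e.g., [8, 16]. When `breakline` is set, a line break is inserted after
// each element.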
| template <typename T> |
| static std::string makeString(T array, bool breakline = false) { |
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  if (!array.empty())
    os << array.back();
  os << "]";
| return buf; |
| } |
| |
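// Returns the shape of `type` if it is shaped; scalars are treated as having
// shape [1].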
| static SmallVector<int64_t> getShapeOf(Type type) { |
| SmallVector<int64_t> shape; |
| if (auto ty = llvm::dyn_cast<ShapedType>(type)) |
| shape = SmallVector<int64_t>(ty.getShape()); |
| else |
| shape.push_back(1); |
| return shape; |
| } |
| |
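// Returns true if `attr` is absent or is a cache policy valid for reads
// (CACHED, UNCACHED, STREAMING, or READ_INVALIDATE).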
| static bool isReadHintOrNone(const CachePolicyAttr &attr) { |
| if (!attr) |
| return true; |
| auto kind = attr.getValue(); |
| return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || |
| kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE; |
| } |
| |
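// Returns true if `attr` is absent or is a cache policy valid for writes
// (CACHED, UNCACHED, WRITE_BACK, or WRITE_THROUGH).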
| static bool isWriteHintOrNone(const CachePolicyAttr &attr) { |
| if (!attr) |
| return true; |
| auto kind = attr.getValue(); |
| return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || |
| kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; |
| } |
| |
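// Verifies the mask/value/TensorDesc combination of a gather/scatter op on a
// scattered TensorDesc: a scalar result requires chunk size 1; a 1-D value of
// exactly chunk-size elements is treated as a SIMT distribution (and must not
// carry a LayoutAttr); otherwise the value shape must match the TensorDesc
// shape, and the mask shape must match it minus the trailing chunk-size dim.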
| static LogicalResult |
| isValidGatherScatterParams(Type maskTy, VectorType valueTy, |
| TensorDescType tdescTy, |
| function_ref<InFlightDiagnostic()> emitError) { |
| |
| if (!tdescTy.isScattered()) |
| return emitError() << "Expects a scattered TensorDesc."; |
| |
| auto chunkSize = tdescTy.getChunkSizeAsInt(); |
| if (!valueTy) { |
| if (chunkSize > 1) |
| return emitError() << "Expecting chunk size == 1 for scalar result"; |
    if (isa<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
| return success(); |
| } |
| |
| auto maskShape = getShapeOf(maskTy); |
| auto valueShape = getShapeOf(valueTy); |
| auto tdescShape = getShapeOf(tdescTy); |
| |
| if (valueTy.getElementType() != tdescTy.getElementType()) |
| return emitError() |
| << "Value should have the same element type as TensorDesc."; |
| |
| llvm::SmallVector<int64_t> expectedMaskShape(tdescShape); |
| if (chunkSize > 1) |
| expectedMaskShape.pop_back(); |
| if (expectedMaskShape != maskShape) |
| return emitError() |
| << "Mask should match TensorDesc except the chunk size dim."; |
| |
  // A valid shape for the SIMT case.
| if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) { |
| if (tdescTy.getLayoutAttr()) |
| return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code"; |
| return success(); |
| } |
| |
| if (tdescShape != valueShape) |
| return emitError() << "Value shape " << makeString(valueShape) |
| << " is neither a valid distribution for SIMT nor " |
| "consistent with the tensor descriptor for SIMD " |
| << tdescTy; |
| return success(); |
| } |
| |
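// Verifies the offsets/mask/value combination of a gather/scatter op that
// addresses a plain buffer (memref or pointer) rather than a TensorDesc.
// Scalar offsets and mask indicate SIMT mode, where the value must hold
// exactly `chunkSize` elements; vector offsets and mask indicate SIMD mode,
// where the mask shape must match the value shape minus the trailing
// chunk-size dim.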
| static LogicalResult |
| isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, |
| VectorType valueTy, int64_t chunkSize, |
| function_ref<InFlightDiagnostic()> emitError) { |
| |
| auto maskVecTy = dyn_cast<VectorType>(maskTy); |
| auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy); |
| if (!valueTy) { |
| if (chunkSize > 1) |
| return emitError() << "Expecting chunk size == 1 for scalar result"; |
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
| return success(); |
| } |
| |
| auto valueSize = valueTy.getNumElements(); |
| // SIMT mode with scalar mask and offsets. |
| if (!maskVecTy && !offsetsVecTy) { |
| if (valueSize != chunkSize) |
| return emitError() << "value elements must match chunk size " |
| << chunkSize; |
| return success(); |
| } |
| auto maskShape = getShapeOf(maskTy); |
| auto valueShape = getShapeOf(valueTy); |
| |
| if (!maskVecTy) |
| return emitError() << "Expecting a vector type mask."; |
| int64_t maskSize = maskVecTy.getNumElements(); |
| |
| if (chunkSize > 1) { |
| if ((valueTy.getRank() == 1) && (valueSize != chunkSize)) |
| return emitError() << "value elements must match chunk size " |
| << chunkSize; |
| } else { |
| if (valueSize != maskSize) |
| return emitError() |
| << "Mask should match value except the chunk size dim."; |
| } |
| llvm::SmallVector<int64_t> expectedMaskShape(valueShape); |
| if (maskSize == 1) |
| return success(); |
| if (chunkSize > 1) |
| expectedMaskShape.pop_back(); |
| if (expectedMaskShape != maskShape) |
| return emitError() << "Mask should match value except the chunk size dim."; |
| |
| return success(); |
| } |
| |
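// Verifies the data/mem_desc combination of the load_matrix/store_matrix ops,
// including the additional constraints imposed by the subgroup_block_io
// attribute (contiguous lane data, block shape matching the lane layout, and
// a row-major mem_desc).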
static LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
                      function_ref<InFlightDiagnostic()> emitError) {

  if (!dataTy) {
    if (subgroup_block_io)
      return emitError() << "subgroup_block_io "
                            "is only allowed when the result is a VectorType.";
| else |
| return success(); |
| } |
| |
| if (mdescTy.getRank() < 2) |
| return emitError() << "mem_desc must be 2D or greater."; |
| |
| ArrayRef<int64_t> dataShape = dataTy.getShape(); |
| ArrayRef<int64_t> mdescShape = mdescTy.getShape(); |
| |
| SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); |
| ArrayAttr strideAttr = mdescTy.getStrideAttr(); |
| SmallVector<int64_t> strides; |
| for (Attribute attr : strideAttr.getValue()) { |
| strides.push_back(cast<IntegerAttr>(attr).getInt()); |
| } |
| if (subgroup_block_io && layout) { |
| auto laneData = layout.getEffectiveLaneDataAsInt(); |
| auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); |
| if (!laneData.empty()) { |
| bool isLaneDataContiguous = |
| std::all_of(laneData.begin(), std::prev(laneData.end()), |
| [](int x) { return x == 1; }); |
| if (!isLaneDataContiguous) |
| return emitError() << "With subgroup_block_io, accessed data must be " |
| "contiguous and coalesced."; |
| for (size_t i = 0; i < laneData.size(); ++i) { |
| if (laneLayout[i] != blockShape[i]) |
| return emitError() << "With subgroup_block_io, the block shape must " |
| "match the lane layout."; |
| if (laneLayout[i] != 1 && strides[i] != 1) |
| return emitError() << "With subgroup_block_io, the distributed " |
| "dimensions must be contiguous."; |
| } |
| } |
| } |
| if (dataShape.size() == 2) { |
| if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), |
| [](auto p) { return std::get<0>(p) > std::get<1>(p); })) |
| return emitError() << "data shape must not exceed mem_desc shape."; |
| } else { |
    // If the subgroup_block_io attribute is set, mdescTy must have a block
    // attribute.
    if (subgroup_block_io && blockShape.empty())
      return emitError() << "mem_desc must have block attribute when "
                            "subgroup_block_io is set.";
    // If the subgroup_block_io attribute is set, the mem_desc must be
    // row-major.
    if (subgroup_block_io && mdescTy.isColMajor())
      return emitError() << "mem_desc should be row major when "
                            "subgroup_block_io is set.";
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_CreateNdDescOp |
| //===----------------------------------------------------------------------===// |
| |
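// As an illustration, a typical use of these builders produces IR like the
// following (illustrative only; the tile and memref sizes are arbitrary):
//   %tdesc = xegpu.create_nd_tdesc %src[0, 0]
//       : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>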
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, |
| Type tdesc, TypedValue<MemRefType> source) { |
| [[maybe_unused]] auto ty = source.getType(); |
| assert(ty.hasStaticShape() && "expecting a memref with static shape"); |
| |
| build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */, |
| ValueRange({}) /* empty dynamic shape */, |
| ValueRange({}) /* empty dynamic strides */, |
| DenseI64ArrayAttr({}) /* const offsets */, |
| DenseI64ArrayAttr({}) /* empty const shape*/, |
| DenseI64ArrayAttr({}) /* empty const strides*/); |
| } |
| |
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, |
| Type tdesc, Value source, |
| llvm::ArrayRef<OpFoldResult> shape, |
| llvm::ArrayRef<OpFoldResult> strides) { |
| Type srcTy = source.getType(); |
| assert((isa<IntegerType, MemRefType>(srcTy)) && |
| "Source has to be either int or memref."); |
| |
| llvm::SmallVector<Value> dynamicShape; |
| llvm::SmallVector<Value> dynamicStrides; |
| |
| llvm::SmallVector<int64_t> staticShape; |
| llvm::SmallVector<int64_t> staticStrides; |
| |
| dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); |
| dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); |
| |
| auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); |
| auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); |
| |
| if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) { |
| auto memrefShape = memrefTy.getShape(); |
| auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); |
| |
    // If the shape and strides are taken from the memref, we don't need
    // attributes for them, to keep the printed IR clean. Only do so in the
    // fully static case; otherwise the printer would fail trying to print an
    // empty array-attr.
| if (staticShape == memrefShape && staticStrides == memrefStrides && |
| dynamicShape.empty() && dynamicStrides.empty()) { |
| staticShapeAttr = DenseI64ArrayAttr(); |
| staticStridesAttr = DenseI64ArrayAttr(); |
| } |
| } |
| |
| build(builder, state, tdesc, source, ValueRange({}), dynamicShape, |
| dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, |
| staticStridesAttr); |
| } |
| |
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, |
| Type tdesc, TypedValue<MemRefType> source, |
| llvm::ArrayRef<OpFoldResult> offsets) { |
| [[maybe_unused]] auto ty = source.getType(); |
| assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank()); |
| |
| llvm::SmallVector<int64_t> staticOffsets; |
| llvm::SmallVector<Value> dynamicOffsets; |
| dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
| |
| build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, |
| ValueRange({}) /* empty dynamic shape */, |
| ValueRange({}) /* empty dynamic strides */, |
| builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */, |
| {} /* empty const shape*/, {} /* empty const strides*/); |
| } |
| |
| void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, |
| Type tdesc, Value source, |
| llvm::ArrayRef<OpFoldResult> offsets, |
| llvm::ArrayRef<OpFoldResult> shape, |
| llvm::ArrayRef<OpFoldResult> strides) { |
| assert(!shape.empty() && !offsets.empty() && !strides.empty() && |
| shape.size() == strides.size() && shape.size() == offsets.size()); |
| |
| Type srcTy = source.getType(); |
| assert((isa<IntegerType, MemRefType>(srcTy)) && |
| "Source has to be either int or memref."); |
| |
| llvm::SmallVector<Value> dynamicOffsets; |
| llvm::SmallVector<Value> dynamicShape; |
| llvm::SmallVector<Value> dynamicStrides; |
| |
| llvm::SmallVector<int64_t> staticOffsets; |
| llvm::SmallVector<int64_t> staticShape; |
| llvm::SmallVector<int64_t> staticStrides; |
| |
| dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
| dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); |
| dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); |
| |
| auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); |
| auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); |
| auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); |
| |
| if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) { |
| auto memrefShape = memrefTy.getShape(); |
| auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); |
| |
    // If the shape and strides are taken from the memref, we don't need
    // attributes for them, to keep the printed IR clean. Only do so in the
    // fully static case; otherwise the printer would fail trying to print an
    // empty array-attr.
| if (staticShape == memrefShape && staticStrides == memrefStrides && |
| dynamicShape.empty() && dynamicStrides.empty()) { |
| staticShapeAttr = DenseI64ArrayAttr(); |
| staticStridesAttr = DenseI64ArrayAttr(); |
| } |
| } |
| |
| build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, |
| dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); |
| } |
| |
| LogicalResult CreateNdDescOp::verify() { |
| size_t rank = getMixedSizes().size(); |
| bool invalidRank = rank != getMixedStrides().size(); |
| bool invalidElemTy = false; |
| |
| // Memory space of created TensorDesc should match with the source. |
| // Both source and TensorDesc are considered for global memory by default, |
| // if the memory scope attr is not specified. If source is an integer, |
| // it is considered as ptr to global memory. |
| auto srcMemorySpace = getSourceMemorySpace(); |
| auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace()); |
| if (srcMemorySpace != tdescMemorySpace) |
| return emitOpError("Memory space mismatch.") |
| << " Source: " << srcMemorySpace |
| << ", TensorDesc: " << tdescMemorySpace; |
| |
| if (size_t offsetRank = getMixedOffsets().size()) |
| invalidRank |= (offsetRank != rank); |
| |
  // If the source is a memref, it should have the same element type as the
  // TensorDesc.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
    invalidElemTy |= memrefTy.getElementType() != getElementType();

  if (llvm::isa<IntegerType>(getSourceType())) {
    // Strides and shape must be present for an integer source.
    if (getMixedStrides().empty() || getMixedSizes().empty())
      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }
| } |
| |
  if (invalidRank)
    return emitOpError(
        "Expecting the ranks of shape, strides, offsets, and source (if "
        "source is a memref) to match with each other.");

  // check result TensorDesc rank
  if (getType().getRank() > (int64_t)rank)
    return emitOpError(
        "Expecting the TensorDesc rank to be no greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");
| |
| if (getType().isScattered()) |
| return emitOpError("Expects a non-scattered TensorDesc.\n"); |
| |
| return success(); |
| } |
| |
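// Parses an optional bracketed list that may mix integer constants and SSA
// values, e.g. [2, %n, 4]; SSA values are collected in `values` and recorded
// in `integers` as ShapedType::kDynamic. If `valueTypes` is provided, each
// SSA value must be followed by a colon and its type.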
| static ParseResult parseOptionalDynamicIndexList( |
| OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
| DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr, |
| AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { |
| |
| SmallVector<int64_t, 4> integerVals; |
| auto parseIntegerOrValue = [&]() { |
| OpAsmParser::UnresolvedOperand operand; |
| auto res = parser.parseOptionalOperand(operand); |
| |
| if (res.has_value() && succeeded(res.value())) { |
| values.push_back(operand); |
| integerVals.push_back(ShapedType::kDynamic); |
| if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) |
| return failure(); |
| } else { |
| int64_t integer; |
| if (failed(parser.parseInteger(integer))) |
| return failure(); |
| integerVals.push_back(integer); |
| } |
| return success(); |
| }; |
| |
  // If the optional list is present, it must start with a left bracket.
| if (parser.parseOptionalLSquare().succeeded()) { |
| if (parser.parseCommaSeparatedList(parseIntegerOrValue) || |
| parser.parseRSquare()) |
| return parser.emitError(parser.getNameLoc()) |
| << "expected a list of SSA values or integers"; |
| integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); |
| return success(); |
| } |
| |
| return success(); |
| } |
| |
| static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, |
| OperandRange values, |
| DenseI64ArrayAttr integers) { |
| if (!integers || integers.empty()) |
| return; |
| printDynamicIndexList(printer, op, values, integers, |
| /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square); |
| } |
| //===----------------------------------------------------------------------===// |
| // XeGPU_PrefetchNdOp |
| //===----------------------------------------------------------------------===// |
| |
| void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, |
| Value tensorDesc, xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| |
| return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), |
| l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); |
| } |
| |
| void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, |
| Value tensorDesc, ArrayRef<OpFoldResult> offsets, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint, |
| xegpu::DistributeLayoutAttr layout) { |
| SmallVector<Value> dynamicOffsets; |
| SmallVector<int64_t> staticOffsets; |
| dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
| |
| auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); |
| |
| build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, |
| l2_hint, l3_hint, /*anchor_layout=*/layout); |
| } |
| |
| LogicalResult PrefetchNdOp::verify() { |
| auto tdescTy = getTensorDescType(); |
| if (tdescTy.isScattered()) |
| return emitOpError("Expects a non-scattered TensorDesc.\n"); |
| |
| if (!isReadHintOrNone(getL1HintAttr())) |
| return emitOpError("invalid l1_hint: ") << getL1HintAttr(); |
| |
| if (!isReadHintOrNone(getL2HintAttr())) |
| return emitOpError("invalid l2_hint: ") << getL2HintAttr(); |
| |
| if (!isReadHintOrNone(getL3HintAttr())) |
| return emitOpError("invalid l3_hint: ") << getL3HintAttr(); |
| |
| int64_t tDescRank = tdescTy.getRank(); |
| int64_t offsetSize = getMixedOffsets().size(); |
| if (offsetSize != 0 && offsetSize != tDescRank) |
| return emitOpError( |
| "Mismatched ranks between offsets and tensor descriptor"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_LoadNdOp |
| //===----------------------------------------------------------------------===// |
| |
| void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, |
| Value tensorDesc, UnitAttr packed, |
| DenseI64ArrayAttr transpose, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| |
| return build(builder, state, retType, tensorDesc, ValueRange(), |
| DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, |
| l3_hint, /*anchor_layout=*/nullptr); |
| } |
| |
| void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, |
| Value tensorDesc, ArrayRef<OpFoldResult> offsets, |
| UnitAttr packed, DenseI64ArrayAttr transpose, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint, |
| xegpu::DistributeLayoutAttr layout) { |
| SmallVector<Value> dynamicOffsets; |
| SmallVector<int64_t> staticOffsets; |
| dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
| |
| auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); |
| |
| build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, |
| packed, transpose, l1_hint, l2_hint, l3_hint, |
| /*anchor_layout=*/layout); |
| } |
| |
| LogicalResult LoadNdOp::verify() { |
| auto tdescTy = getTensorDescType(); |
| auto valueTy = getType(); |
| |
| if (tdescTy.isScattered()) |
| return emitOpError("Expects a non-scattered TensorDesc.\n"); |
| |
| if (tdescTy.getRank() > 2) |
| return emitOpError("Expects a 1D or 2D TensorDesc.\n"); |
| |
| if (!valueTy) |
| return emitOpError("Invalid result, it should be a VectorType.\n"); |
| |
| if (!isReadHintOrNone(getL1HintAttr())) |
| return emitOpError("invalid l1_hint: ") << getL1HintAttr(); |
| |
| if (!isReadHintOrNone(getL2HintAttr())) |
| return emitOpError("invalid l2_hint: ") << getL2HintAttr(); |
| |
| if (!isReadHintOrNone(getL3HintAttr())) |
| return emitOpError("invalid l3_hint: ") << getL3HintAttr(); |
| |
| int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength(); |
| int valueElems = valueTy.getNumElements(); |
| |
  // If the result vector is 1D and has fewer elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // tensor_desc is not needed.
| if (valueElems < tdescElems && valueTy.getRank() == 1) { |
| // SIMT mode doesn't need LayoutAttr. |
| if (tdescTy.getLayoutAttr()) |
| return emitOpError() |
| << "TensorDesc doesn't need LayoutAttr for SIMT code"; |
| |
    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since the subgroup size is architecture-dependent, we only
    // check for even distribution here.
| if (tdescElems % valueElems) |
| return emitOpError() |
| << "Result shape " << makeString(getShapeOf(valueTy)) |
| << " is not a valid distribution for tensor descriptor " |
| << tdescTy; |
| |
| return success(); |
| } |
| |
| // Check SIMD mode. |
| auto tdescShape = getShapeOf(tdescTy); |
| auto valueShape = getShapeOf(valueTy); |
| |
| if (getTranspose()) { |
| auto trans = getTranspose().value(); |
| // Make sure the transpose value is valid, and apply it |
| if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); })) |
| tdescShape = applyPermutation(tdescShape, trans); |
| else |
| mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored."; |
| } |
| |
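  // With the packed (VNNI) attribute, a 2D tile is expected in a packed
  // layout whose trailing dimension is the VNNI factor; e.g. (illustrative),
  // a tensor_desc<16x16xf16> would correspond to a vector<8x16x2xf16>.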
| if (getPacked()) { |
| if (tdescTy.getRank() == 2) { |
| const int axis = 0; |
| auto vnni_factor = valueShape.back(); |
| tdescShape[axis] /= vnni_factor; |
| tdescShape.push_back(vnni_factor); |
| } else { |
| mlir::emitWarning(getLoc()) |
| << "Invalid Packed Attr. It is ignored (available for 2D " |
| "TensorDesc only)."; |
| } |
| } |
| |
| auto array_len = tdescTy.getArrayLength(); |
| if (array_len > 1) |
| tdescShape.insert(tdescShape.begin(), array_len); |
| |
| if (tdescShape != valueShape) |
| return emitOpError() << "Result shape " << makeString(valueShape) |
| << " is not consistent with tensor descriptor " |
| << tdescTy; |
| |
| int64_t tDescRank = tdescTy.getRank(); |
| int64_t offsetSize = getMixedOffsets().size(); |
| if (offsetSize != 0 && offsetSize != tDescRank) |
| return emitOpError( |
| "Mismatched ranks between offsets and tensor descriptor"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_StoreNdOp |
| //===----------------------------------------------------------------------===// |
| |
| void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, |
| Value tensorDesc, xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| |
| return build(builder, state, value, tensorDesc, ValueRange(), |
| DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint, |
| /*anchor_layout=*/nullptr); |
| } |
| |
| void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, |
| Value tensorDesc, ArrayRef<OpFoldResult> offsets, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint, |
| xegpu::DistributeLayoutAttr layout) { |
| SmallVector<Value> dynamicOffsets; |
| SmallVector<int64_t> staticOffsets; |
| dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
| |
| auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); |
| |
| build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, |
| l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout); |
| } |
| |
| LogicalResult StoreNdOp::verify() { |
| auto dstTy = getTensorDescType(); // Tile |
| auto valTy = getValueType(); // Vector |
| |
| if (dstTy.isScattered()) |
| return emitOpError("Expects a non-scattered TensorDesc.\n"); |
| |
| if (dstTy.getRank() > 2) |
| return emitOpError("Expects a 1D or 2D TensorDesc.\n"); |
| |
| if (!valTy) |
| return emitOpError("Expecting a VectorType result.\n"); |
| |
| if (!isWriteHintOrNone(getL1HintAttr())) |
| return emitOpError("invalid l1_hint: ") << getL1HintAttr(); |
| |
| if (!isWriteHintOrNone(getL2HintAttr())) |
| return emitOpError("invalid l2_hint: ") << getL2HintAttr(); |
| |
| if (!isWriteHintOrNone(getL3HintAttr())) |
| return emitOpError("invalid l3_hint: ") << getL3HintAttr(); |
| |
| auto array_len = dstTy.getArrayLength(); |
| if (array_len > 1) |
| return emitOpError("array length is not supported by store_nd.\n"); |
| |
| auto tdescElems = dstTy.getNumElements(); |
| auto valueElems = valTy.getNumElements(); |
| |
  // Similar to LoadNdOp, if the value vector is 1D and has fewer elements
  // than the tensor descriptor, it is supposed to be a SIMT op. The layout
  // attribute in tensor_desc is not needed.
| if (valTy.getRank() == 1 && valueElems < tdescElems) { |
| // SIMT mode doesn't need LayoutAttr. |
| if (dstTy.getLayoutAttr()) |
| return emitOpError() |
| << "TensorDesc doesn't need LayoutAttr for SIMT code"; |
| |
| if (tdescElems % valueElems) |
| return emitOpError() |
| << "Value shape " << makeString(getShapeOf(valTy)) |
| << " is not a valid distribution for tensor descriptor " << dstTy; |
| |
| return success(); |
| } |
| |
| // SIMD code should have the same shape as the tensor descriptor. |
| auto tdescShape = getShapeOf(dstTy); |
| auto valueShape = getShapeOf(valTy); |
| if (tdescShape != valueShape) |
| return emitOpError() << "Value shape " << makeString(valueShape) |
| << " is not consistent with tensor descriptor " |
| << dstTy; |
| |
| int64_t tDescRank = dstTy.getRank(); |
| int64_t offsetSize = getMixedOffsets().size(); |
| if (offsetSize != 0 && offsetSize != tDescRank) |
| return emitOpError( |
| "Mismatched ranks between offsets and tensor descriptor"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_UpdateNDOffsetOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult UpdateNdOffsetOp::verify() { |
| auto ty = getTensorDescType(); |
| if (ty.isScattered()) |
| return emitOpError("Expects a non-scattered TensorDesc.\n"); |
| |
| // number of offsets specified must match the rank of the tensor descriptor |
| if (ty.getRank() != (int64_t)getNumOffsets()) { |
| return emitOpError("Invalid number of offsets."); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_CreateDescOp |
| //===----------------------------------------------------------------------===// |
| |
| void CreateDescOp::build(OpBuilder &builder, OperationState &state, |
| TensorDescType TensorDesc, Value source, |
| llvm::ArrayRef<OpFoldResult> offsets) { |
| auto loc = source.getLoc(); |
| int64_t size = static_cast<int64_t>(offsets.size()); |
| auto type = VectorType::get(size, builder.getIndexType()); |
| auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); |
| auto offset = vector::FromElementsOp::create(builder, loc, type, values); |
| build(builder, state, TensorDesc, source, offset); |
| } |
| |
| void CreateDescOp::build(OpBuilder &builder, OperationState &state, |
| TensorDescType TensorDesc, Value source, |
| llvm::ArrayRef<int64_t> offsets) { |
| auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets); |
| build(builder, state, TensorDesc, source, ofrs); |
| } |
| |
| LogicalResult CreateDescOp::verify() { |
| auto tdescTy = getTensorDescType(); |
| |
| if (!tdescTy.isScattered()) |
| return emitOpError("Expects a scattered TensorDesc.\n"); |
| |
| // Memory space of created TensorDesc should match with the source. |
| // Both source and TensorDesc are considered for global memory by default, |
| // if the memory scope attr is not specified. If source is an integer, |
| // it is considered as ptr to global memory. |
| auto srcMemorySpace = getSourceMemorySpace(); |
| auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace()); |
| if (srcMemorySpace != tdescMemorySpace) |
| return emitOpError("Memory space mismatch.") |
| << " Source: " << srcMemorySpace |
| << ", TensorDesc: " << tdescMemorySpace; |
| |
| // check total size |
| auto chunkSize = tdescTy.getChunkSizeAsInt(); |
| SmallVector<int64_t> shape(getOffsetsType().getShape()); |
| if (chunkSize != 1) |
| shape.push_back(chunkSize); |
| |
| auto tdescShape = getShapeOf(tdescTy); |
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected " << makeString(shape) << "\n";
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_PrefetchOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult PrefetchOp::verify() { |
| auto tdescTy = getTensorDescType(); |
| |
| if (!tdescTy && !getOffsets()) |
| return emitOpError("Expects offsets."); |
| |
| if (tdescTy && getOffsets()) |
| return emitOpError("offsets not allowed."); |
| |
| if (tdescTy && !tdescTy.isScattered()) |
| return emitOpError("Expects a scattered TensorDesc."); |
| |
| if (!isReadHintOrNone(getL1HintAttr())) |
| return emitOpError("invalid l1_hint: ") << getL1HintAttr(); |
| |
| if (!isReadHintOrNone(getL2HintAttr())) |
| return emitOpError("invalid l2_hint: ") << getL2HintAttr(); |
| |
| if (!isReadHintOrNone(getL3HintAttr())) |
| return emitOpError("invalid l3_hint: ") << getL3HintAttr(); |
| |
| auto srcTy = getSourceType(); |
| if (srcTy.isInteger() && !getOffsetAlignByteAttr()) |
| return emitOpError("offset_align_byte is required with integer source."); |
| |
| if (getOffsetAlignByteAttr() && !srcTy.isInteger()) |
| return emitOpError("offset_align_byte only allowed with integer source."); |
| |
| return success(); |
| } |
| |
| void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, |
| IntegerAttr{}, /*anchor_layout=*/nullptr); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_LoadGatherOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult LoadGatherOp::verify() { |
| auto tdescTy = getTensorDescType(); |
| auto maskTy = getMaskType(); |
| auto valueTy = getValueType(); |
| |
| if (!tdescTy && !getOffsets()) |
| return emitOpError("Expects offsets."); |
| |
| if (tdescTy && getOffsets()) |
| return emitOpError("offsets not allowed."); |
| |
| if (tdescTy && !tdescTy.isScattered()) |
| return emitOpError("Expects a scattered TensorDesc."); |
| |
| if (!isReadHintOrNone(getL1HintAttr())) |
| return emitOpError("invalid l1_hint: ") << getL1HintAttr(); |
| |
| if (!isReadHintOrNone(getL2HintAttr())) |
| return emitOpError("invalid l2_hint: ") << getL2HintAttr(); |
| |
| if (!isReadHintOrNone(getL3HintAttr())) |
| return emitOpError("invalid l3_hint: ") << getL3HintAttr(); |
| |
| if (tdescTy) |
| return isValidGatherScatterParams(maskTy, valueTy, tdescTy, |
| [&]() { return emitOpError(); }); |
| auto srcTy = getSourceType(); |
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
| auto memTy = dyn_cast<MemRefType>(srcTy); |
| |
| if (memTy && (getElementType() != memTy.getElementType())) |
| return emitError() << "Value should have the same element type as MemRef."; |
| |
| auto offsetsTy = getOffsets().getType(); |
| return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize, |
| [&]() { return emitOpError(); }); |
| } |
| |
| void LoadGatherOp::build(OpBuilder &builder, OperationState &state, |
| Type valueType, Value source, Value mask, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| build(builder, state, valueType, source, Value(), mask, IntegerAttr(), |
| l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); |
| } |
| |
| void LoadGatherOp::build(OpBuilder &builder, OperationState &state, |
| Type valueType, Value source, |
| ArrayRef<OpFoldResult> offsets, Value mask, |
| IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| auto loc = source.getLoc(); |
| int64_t size = static_cast<int64_t>(offsets.size()); |
| auto type = VectorType::get(size, builder.getIndexType()); |
| auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); |
| auto offset = vector::FromElementsOp::create(builder, loc, type, values); |
| |
| build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, |
| l2_hint, l3_hint, /*anchor_layout=*/nullptr); |
| } |
| |
| void LoadGatherOp::build(OpBuilder &builder, OperationState &state, |
| Type valueType, Value source, |
| ArrayRef<OpFoldResult> offsets, Value mask, |
| IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint, |
| DistributeLayoutAttr layout) { |
| auto loc = source.getLoc(); |
| int64_t size = static_cast<int64_t>(offsets.size()); |
| auto type = VectorType::get(size, builder.getIndexType()); |
| auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); |
| auto offset = vector::FromElementsOp::create(builder, loc, type, values); |
| |
| build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, |
| l2_hint, l3_hint, layout); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_StoreScatterOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult StoreScatterOp::verify() { |
| auto tdescTy = getTensorDescType(); |
| auto maskTy = getMaskType(); |
| auto valueTy = getValueType(); |
| |
| if (!tdescTy && !getOffsets()) |
| return emitOpError("Expects offsets."); |
| |
| if (tdescTy && getOffsets()) |
| return emitOpError("offsets not allowed."); |
| |
| if (tdescTy && !tdescTy.isScattered()) |
| return emitOpError("Expects a scattered TensorDesc."); |
| |
| if (!isWriteHintOrNone(getL1HintAttr())) |
| return emitOpError("invalid l1_hint: ") << getL1HintAttr(); |
| |
| if (!isWriteHintOrNone(getL2HintAttr())) |
| return emitOpError("invalid l2_hint: ") << getL2HintAttr(); |
| |
| if (!isWriteHintOrNone(getL3HintAttr())) |
| return emitOpError("invalid l3_hint: ") << getL3HintAttr(); |
| |
| if (tdescTy) |
| return isValidGatherScatterParams(maskTy, valueTy, tdescTy, |
| [&]() { return emitOpError(); }); |
| |
| auto destTy = getDestType(); |
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
| auto memTy = dyn_cast<MemRefType>(destTy); |
| |
| if (memTy && (getElementType() != memTy.getElementType())) |
| return emitError() << "Value should have the same element type as MemRef."; |
| |
| auto offsetsTy = getOffsets().getType(); |
| return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize, |
| [&]() { return emitOpError(); }); |
| } |
| |
| void StoreScatterOp::build(OpBuilder &builder, OperationState &state, |
| Value value, Value dest, Value mask, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, |
| l2_hint, l3_hint, /*anchor_layout=*/nullptr); |
| } |
| |
| void StoreScatterOp::build(OpBuilder &builder, OperationState &state, |
| Value value, Value dest, |
| ArrayRef<OpFoldResult> offsets, Value mask, |
| IntegerAttr chunk_size, |
| xegpu::CachePolicyAttr l1_hint, |
| xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint) { |
| auto loc = dest.getLoc(); |
| int64_t size = static_cast<int64_t>(offsets.size()); |
| auto type = VectorType::get(size, builder.getIndexType()); |
| auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); |
| auto offset = vector::FromElementsOp::create(builder, loc, type, values); |
| |
| // Call the correct builder overload that does not expect result types. |
| build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, |
| l3_hint, /*anchor_layout=*/nullptr); |
| } |
| |
| void StoreScatterOp::build( |
| OpBuilder &builder, OperationState &state, Value value, Value dest, |
| ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size, |
| xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, |
| xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) { |
| auto loc = dest.getLoc(); |
| int64_t size = static_cast<int64_t>(offsets.size()); |
| auto type = VectorType::get(size, builder.getIndexType()); |
| auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); |
| auto offset = vector::FromElementsOp::create(builder, loc, type, values); |
| |
| // Call the correct builder overload that does not expect result types. |
| build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, |
| l3_hint, layout); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_UpdateOffsetOp |
| //===----------------------------------------------------------------------===// |
| void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, |
| mlir::Value tensorDesc, |
| llvm::ArrayRef<OpFoldResult> offsets) { |
| auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType()); |
| assert(tdescTy && "Expecting the source is a TensorDescType value."); |
| auto loc = tensorDesc.getLoc(); |
| int64_t size = static_cast<int64_t>(offsets.size()); |
| auto type = VectorType::get({size}, builder.getIndexType()); |
| auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); |
| auto offset = vector::FromElementsOp::create(builder, loc, type, values); |
| build(builder, state, tdescTy, tensorDesc, offset); |
| } |
| |
| void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, |
| Value tensorDesc, llvm::ArrayRef<int64_t> offsets) { |
| auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets); |
| build(builder, state, tensorDesc, ofrs); |
| } |
| |
| LogicalResult UpdateOffsetOp::verify() { |
| auto tdescTy = getTensorDescType(); |
| if (!tdescTy.isScattered()) |
| return emitOpError("Expects a scattered TensorDesc.\n"); |
| |
| SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy); |
| SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType()); |
| if (tdescTy.getChunkSizeAsInt() > 1) |
| expectedOffsetShape.pop_back(); |
| |
| if (expectedOffsetShape != offsetShape) |
| return emitOpError( |
| "Offsets should match TensorDesc except the chunk size dim."); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_DpasOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult DpasOp::verify() { |
| int64_t lhsRank = getLhsType().getRank(); |
| int64_t rhsRank = getRhsType().getRank(); |
| int64_t resRank = getResultType().getRank(); |
| auto lhsShape = getLhsType().getShape(); |
| auto rhsShape = getRhsType().getShape(); |
| auto resShape = getResultType().getShape(); |
| |
| if (getAcc() && getAcc().getType() != getResultType()) |
| return emitOpError("Expecting the acc type to be the same as result."); |
| |
  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The semantic check is skipped due to the lack of architecture
  // information; users need to ensure correctness.
| if (lhsRank == 1 && rhsRank == 1 && resRank == 1) { |
| auto numElems = getRhsType().getNumElements(); |
| auto elemTy = getRhsType().getElementType(); |
| auto factor = 32 / elemTy.getIntOrFloatBitWidth(); |
| if (numElems % factor != 0) |
| return emitOpError("Expecting B operand to be a multiple of 32 bits."); |
| return success(); |
| } |
| |
| // SIMD code |
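  // E.g. (illustrative): lhs vector<8x16xf16>, rhs vector<8x16x2xf16> (packed
  // with a VNNI factor of 2, so K = 8 * 2 = 16), res vector<8x16xf32>.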
| if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2) |
| return emitOpError( |
| "expecting lhs and result to be a 2D vector, and rhs to be either " |
| "2D or 3D (packed) vector."); |
| auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0]; |
| if (bK != lhsShape[1]) |
| return emitOpError("K-dimension mismatch."); |
| if (lhsShape[0] != resShape[0]) |
| return emitOpError("M-dimension mismatch."); |
| if (rhsShape[1] != resShape[1]) |
| return emitOpError("N-dimension mismatch."); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_ConvertLayoutOp |
| //===----------------------------------------------------------------------===// |
| LogicalResult ConvertLayoutOp::verify() { |
| auto srcLayout = getInputLayout(); |
| auto resLayout = getTargetLayout(); |
| if (!srcLayout) |
| return emitOpError("expected input layout."); |
| if (!resLayout) |
| return emitOpError("expected target layout."); |
| |
  // Both the input and target layouts must be WgLayouts, or both must be
  // SgLayouts.
  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input layout and target layout to be "
                       "WgLayout or SgLayout at the same time.");
| |
| auto shape = getSource().getType().getShape(); |
| if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout)) |
| return emitOpError( |
| "invalid input layout, data cannot be evenly distributed."); |
| |
| if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout)) |
| return emitOpError( |
| "invalid target layout, data cannot be evenly distributed."); |
| |
| return mlir::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_LoadMatrixOp |
| //===----------------------------------------------------------------------===// |
| void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, |
| TypedValue<MemDescType> memDesc, |
| llvm::ArrayRef<OpFoldResult> offsets, |
| DistributeLayoutAttr layout) { |
| llvm::SmallVector<Value> dynamicOffsets; |
| llvm::SmallVector<int64_t> staticOffsets; |
| dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
| auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); |
| // Call the generated builder with all parameters (including optional ones as |
| // nullptr/empty) |
| build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, |
| /*subgroup_block_io=*/nullptr, layout); |
| } |
| |
| LogicalResult LoadMatrixOp::verify() { |
| |
| auto resTy = dyn_cast<VectorType>(getRes().getType()); |
| UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); |
| MemDescType mdescTy = getMemDesc().getType(); |
| |
| return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, |
| getLayoutAttr(), [&]() { return emitError(); }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_StoreMatrixOp |
| //===----------------------------------------------------------------------===// |
| void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, |
| TypedValue<MemDescType> memDesc, |
| llvm::ArrayRef<OpFoldResult> offsets, |
| DistributeLayoutAttr layout) { |
| llvm::SmallVector<Value> dynamicOffsets; |
| llvm::SmallVector<int64_t> staticOffsets; |
| dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
| auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); |
| build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, |
| /*subgroup_block_io=*/nullptr, layout); |
| } |
| |
| LogicalResult StoreMatrixOp::verify() { |
| |
| auto dataTy = dyn_cast<VectorType>(getData().getType()); |
| UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); |
| MemDescType mdescTy = getMemDesc().getType(); |
| return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, |
| getLayoutAttr(), [&]() { return emitError(); }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_TruncfOp |
| //===----------------------------------------------------------------------===// |
| |
LogicalResult TruncfOp::verify() {
  auto sourceVecType = dyn_cast<VectorType>(getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(getResult().getType());

  // Guard against non-vector operands before querying element bit widths.
  if (!sourceVecType || !resultVecType)
    return emitOpError("expecting vector source and result types.");

  if (sourceVecType.getElementTypeBitWidth() <=
      resultVecType.getElementTypeBitWidth())
    return emitOpError("input type must be wider than result type.");

  return success();
}
| |
| //===----------------------------------------------------------------------===// |
| // XeGPU_DpasMxOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult DpasMxOp::verify() { |
| if (getAcc() && getAcc().getType() != getResultType()) |
| return emitOpError("Expecting the acc type to be the same as result."); |
| |
| return success(); |
| } |
| |
| namespace mlir { |
| #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> |
| } // namespace mlir |
| #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc> |
| #define GET_OP_CLASSES |
| #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> |