| //===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===// |
| // |
| // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/Dialect/LLVMIR/XeVMDialect.h" |
| |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| #include "mlir/Dialect/Index/IR/IndexOps.h" |
| #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
| #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/FormatVariadic.h" |
| |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Types.h" |
| |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| #include <numeric> |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| using namespace mlir; |
| |
| namespace { |
| |
| // TODO: The values below are uArch-dependent and should not be hardcoded. |
| static constexpr int32_t systolicDepth{8}; |
| static constexpr int32_t executionSize{16}; |
| |
| // Offsets of individual fields within the 8xi32 payload of the nd tensor |
| // descriptor. |
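| // Viewed as 8xi32, the payload layout is (the remaining slots are not used |
| // by this lowering): |
| //   [ basePtr(lo), basePtr(hi), baseShapeW, baseShapeH, basePitch, -, -, - ] |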
| enum class NdTdescOffset : uint32_t { |
| BasePtr = 0, // Base pointer (i64) |
| BaseShapeW = 2, // Base shape width (i32) |
| BaseShapeH = 3, // Base shape height (i32) |
| BasePitch = 4, // Base pitch (i32) |
| }; |
| |
| static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { |
| switch (xeGpuMemspace) { |
| case xegpu::MemorySpace::Global: |
| return static_cast<int>(xevm::AddrSpace::GLOBAL); |
| case xegpu::MemorySpace::SLM: |
| return static_cast<int>(xevm::AddrSpace::SHARED); |
| } |
| llvm_unreachable("Unknown XeGPU memory space"); |
| } |
| |
| /// Checks if the given MemRefType refers to shared memory. |
| static bool isSharedMemRef(const MemRefType &memrefTy) { |
| Attribute attr = memrefTy.getMemorySpace(); |
| if (!attr) |
| return false; |
| if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) |
| return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED); |
| if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr)) |
| return xevmSpace.getValue() == xevm::AddrSpace::SHARED; |
| return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr); |
| } |
| |
| // Get same bitwidth flat vector type of new element type. |
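| // For example, vector<8x16xf16> (2048 bits total) becomes vector<64xi32>. |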
| static VectorType encodeVectorTypeTo(VectorType currentVecType, |
| Type toElemType) { |
| auto elemType = currentVecType.getElementType(); |
| auto currentBitWidth = elemType.getIntOrFloatBitWidth(); |
| auto newBitWidth = toElemType.getIntOrFloatBitWidth(); |
| const int size = |
| currentVecType.getNumElements() * currentBitWidth / newBitWidth; |
| return VectorType::get(size, toElemType); |
| } |
| |
| static xevm::LoadCacheControl |
| translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint, |
| std::optional<xegpu::CachePolicy> L3hint) { |
| auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED); |
| auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED); |
| switch (L1hintVal) { |
| case xegpu::CachePolicy::CACHED: |
| if (L3hintVal == xegpu::CachePolicy::CACHED) |
| return xevm::LoadCacheControl::L1C_L2UC_L3C; |
| else if (L3hintVal == xegpu::CachePolicy::UNCACHED) |
| return xevm::LoadCacheControl::L1C_L2UC_L3UC; |
| else |
| llvm_unreachable("Unsupported cache control."); |
| case xegpu::CachePolicy::UNCACHED: |
| if (L3hintVal == xegpu::CachePolicy::CACHED) |
| return xevm::LoadCacheControl::L1UC_L2UC_L3C; |
| else if (L3hintVal == xegpu::CachePolicy::UNCACHED) |
| return xevm::LoadCacheControl::L1UC_L2UC_L3UC; |
| else |
| llvm_unreachable("Unsupported cache control."); |
| case xegpu::CachePolicy::STREAMING: |
| if (L3hintVal == xegpu::CachePolicy::CACHED) |
| return xevm::LoadCacheControl::L1S_L2UC_L3C; |
| else if (L3hintVal == xegpu::CachePolicy::UNCACHED) |
| return xevm::LoadCacheControl::L1S_L2UC_L3UC; |
| else |
| llvm_unreachable("Unsupported cache control."); |
| case xegpu::CachePolicy::READ_INVALIDATE: |
| return xevm::LoadCacheControl::INVALIDATE_READ; |
| default: |
| llvm_unreachable("Unsupported cache control."); |
| } |
| } |
| |
| static xevm::StoreCacheControl |
| translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint, |
| std::optional<xegpu::CachePolicy> L3hint) { |
| auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED); |
| auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED); |
| switch (L1hintVal) { |
| case xegpu::CachePolicy::UNCACHED: |
| if (L3hintVal == xegpu::CachePolicy::UNCACHED) |
| return xevm::StoreCacheControl::L1UC_L2UC_L3UC; |
| else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) |
| return xevm::StoreCacheControl::L1UC_L2UC_L3WB; |
| else |
| llvm_unreachable("Unsupported cache control."); |
| case xegpu::CachePolicy::STREAMING: |
| if (L3hintVal == xegpu::CachePolicy::UNCACHED) |
| return xevm::StoreCacheControl::L1S_L2UC_L3UC; |
| else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) |
| return xevm::StoreCacheControl::L1S_L2UC_L3WB; |
| else |
| llvm_unreachable("Unsupported cache control."); |
| case xegpu::CachePolicy::WRITE_BACK: |
| if (L3hintVal == xegpu::CachePolicy::UNCACHED) |
| return xevm::StoreCacheControl::L1WB_L2UC_L3UC; |
| else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) |
| return xevm::StoreCacheControl::L1WB_L2UC_L3WB; |
| else |
| llvm_unreachable("Unsupported cache control."); |
| case xegpu::CachePolicy::WRITE_THROUGH: |
| if (L3hintVal == xegpu::CachePolicy::UNCACHED) |
| return xevm::StoreCacheControl::L1WT_L2UC_L3UC; |
| else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) |
| return xevm::StoreCacheControl::L1WT_L2UC_L3WB; |
| else |
| llvm_unreachable("Unsupported cache control."); |
| default: |
| llvm_unreachable("Unsupported cache control."); |
| } |
| } |
| |
| // |
| // Note: |
| // Block operations on tiles with sub-byte element types are handled by |
| // emulating them with larger element types. |
| // Tensor descriptors are kept intact; only the ops consuming them are |
| // emulated. |
| // |
| |
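| // Illustrative lowering sketch (IR abbreviated; exact syntax and types may |
| // differ): |
| //   %td = xegpu.create_nd_tdesc %src : memref<64x128xf16> |
| //           -> !xegpu.tensor_desc<8x16xf16> |
| // becomes a vector<8xi32> payload built with arith.constant, vector.bitcast |
| // and vector.insert, holding the base pointer, base shape (W/H) and pitch. |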
| class CreateNdDescToXeVMPattern |
| : public OpConversionPattern<xegpu::CreateNdDescOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::CreateNdDescOp op, |
| xegpu::CreateNdDescOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets(); |
| if (!mixedOffsets.empty()) |
| return rewriter.notifyMatchFailure(op, "Offsets not supported."); |
| auto loc = op.getLoc(); |
| auto source = op.getSource(); |
| // The op is lowered to a code sequence that populates the payload. |
| // The payload is an 8xi32 vector; offsets of individual fields are defined |
| // in the NdTdescOffset enum. |
| Type payloadElemTy = rewriter.getI32Type(); |
| VectorType payloadTy = VectorType::get(8, payloadElemTy); |
| Type i64Ty = rewriter.getI64Type(); |
| // 4xi64 view is used for inserting the base pointer. |
| VectorType payloadI64Ty = VectorType::get(4, i64Ty); |
| // Initialize payload to zero. |
| Value payload = arith::ConstantOp::create( |
| rewriter, loc, |
| DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0))); |
| |
| Value baseAddr; |
| Value baseShapeW; |
| Value baseShapeH; |
| |
| // Source can be a memref or a pointer (ui64, ui32, i64 or i32). |
| SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes(); |
| SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides(); |
| // Descriptor shape is expected to be 2D. |
| int64_t rank = mixedSizes.size(); |
| auto sourceTy = source.getType(); |
| auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); |
| // If the source is a memref, extract the aligned pointer as an index. |
| // The pointer type is passed as i32 or i64 by the type converter. |
| if (sourceMemrefTy) { |
| if (!sourceMemrefTy.hasRank()) { |
| return rewriter.notifyMatchFailure(op, "Expected ranked Memref."); |
| } |
| // Access the adaptor after the failure check to avoid rolling back the |
| // code generated for the materialization cast. |
| baseAddr = adaptor.getSource(); |
| } else { |
| baseAddr = adaptor.getSource(); |
| if (baseAddr.getType() != i64Ty) { |
| // Pointer type may be i32. Cast to i64 if needed. |
| baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); |
| } |
| } |
| // 1D tensor descriptor is just the base address. |
| if (rank == 1) { |
| rewriter.replaceOp(op, baseAddr); |
| return success(); |
| } |
| // Utility for creating offset values from op fold result. |
| auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec, |
| unsigned idx) -> Value { |
| Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]); |
| val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); |
| return val; |
| }; |
| // Get shape values from op fold results. |
| baseShapeW = createOffset(mixedSizes, 1); |
| baseShapeH = createOffset(mixedSizes, 0); |
| // Get pitch value from op fold results. |
| Value basePitch = createOffset(mixedStrides, 0); |
| // Populate payload. |
| Value payLoadAsI64 = |
| vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); |
| payLoadAsI64 = |
| vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64, |
| static_cast<int>(NdTdescOffset::BasePtr)); |
| payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64); |
| payload = |
| vector::InsertOp::create(rewriter, loc, baseShapeW, payload, |
| static_cast<int>(NdTdescOffset::BaseShapeW)); |
| payload = |
| vector::InsertOp::create(rewriter, loc, baseShapeH, payload, |
| static_cast<int>(NdTdescOffset::BaseShapeH)); |
| payload = |
| vector::InsertOp::create(rewriter, loc, basePitch, payload, |
| static_cast<int>(NdTdescOffset::BasePitch)); |
| rewriter.replaceOp(op, payload); |
| return success(); |
| } |
| }; |
| |
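| // Illustrative lowering sketch for a 2D load (IR abbreviated; operand order |
| // and attribute names elided): |
| //   %v = xegpu.load_nd %td[%y, %x] |
| //          : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> |
| // becomes an xevm.blockload2d taking the base pointer (llvm.ptr), the |
| // surface width and pitch in bytes, the surface height, and the (x, y) |
| // element offsets; the flat result is vector.bitcast back to the original |
| // element type. 1D descriptors lower to xevm.blockload / xevm.blockstore on |
| // a pointer computed as base + offset * elemByteSize. |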
| template < |
| typename OpType, |
| typename = std::enable_if_t<llvm::is_one_of< |
| OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>> |
| class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { |
| using OpConversionPattern<OpType>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto mixedOffsets = op.getMixedOffsets(); |
| int64_t opOffsetsSize = mixedOffsets.size(); |
| auto loc = op.getLoc(); |
| auto ctxt = rewriter.getContext(); |
| |
| auto tdesc = adaptor.getTensorDesc(); |
| auto tdescTy = op.getTensorDescType(); |
| auto tileRank = tdescTy.getRank(); |
| if (opOffsetsSize != tileRank) |
| return rewriter.notifyMatchFailure( |
| op, "Expected offset rank to match descriptor rank."); |
| auto elemType = tdescTy.getElementType(); |
| auto elemBitSize = elemType.getIntOrFloatBitWidth(); |
| bool isSubByte = elemBitSize < 8; |
| uint64_t wScaleFactor = 1; |
| |
| if (!isSubByte && (elemBitSize % 8 != 0)) |
| return rewriter.notifyMatchFailure( |
| op, "Expected element type bit width to be multiple of 8."); |
| auto tileW = tdescTy.getDimSize(tileRank - 1); |
| // For sub-byte types, only 4-bit elements are currently supported. |
| if (isSubByte) { |
| if (elemBitSize != 4) |
| return rewriter.notifyMatchFailure( |
| op, "Only 4-bit sub-byte element types are supported."); |
| if (tileRank != 2) |
| return rewriter.notifyMatchFailure( |
| op, "Sub byte types are only supported for 2D tensor descriptors."); |
| auto subByteFactor = 8 / elemBitSize; |
| auto tileH = tdescTy.getDimSize(0); |
| // Handle special case for packed load. |
| if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { |
| if (op.getPacked().value_or(false)) { |
| // Packed load is emulated with packed loads of 8-bit elements. |
| if (tileH == systolicDepth * 4 && |
| tileW == executionSize * subByteFactor) { |
| // Usage case: loading as matrix B with a pack request. |
| // The source is assumed to be pre-packed into 8-bit elements. |
| // Emulate with 8-bit loads with a pack request |
| // (scaled tileW = executionSize). |
| elemType = rewriter.getIntegerType(8); |
| tileW = executionSize; |
| wScaleFactor = subByteFactor; |
| } |
| } |
| } |
| // If not handled by packed load case above, handle other cases. |
| if (wScaleFactor == 1) { |
| auto sub16BitFactor = subByteFactor * 2; |
| if (tileW == executionSize * sub16BitFactor) { |
| // Usage case: loading as the matrix A operand. |
| // Emulate with 16-bit loads/stores (scaled tileW = executionSize). |
| elemType = rewriter.getIntegerType(16); |
| tileW = executionSize; |
| wScaleFactor = sub16BitFactor; |
| } else { |
| return rewriter.notifyMatchFailure( |
| op, "Unsupported tile shape for sub byte types."); |
| } |
| } |
| // Recompute the element bit size for emulation. |
| elemBitSize = elemType.getIntOrFloatBitWidth(); |
| } |
| |
| // Get address space from tensor descriptor memory space. |
| auto ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); |
| if (tileRank == 2) { |
| // Compute element byte size. |
| Value elemByteSize = arith::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); |
| VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); |
| Value payLoadAsI64 = |
| vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); |
| Value basePtr = |
| vector::ExtractOp::create(rewriter, loc, payLoadAsI64, |
| static_cast<int>(NdTdescOffset::BasePtr)); |
| Value baseShapeW = vector::ExtractOp::create( |
| rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); |
| Value baseShapeH = vector::ExtractOp::create( |
| rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); |
| Value basePitch = vector::ExtractOp::create( |
| rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch)); |
| // Offsets are provided by the op; convert them to i32. |
| Value offsetW = |
| getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); |
| offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, |
| rewriter.getI32Type(), offsetW); |
| Value offsetH = |
| getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); |
| offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, |
| rewriter.getI32Type(), offsetH); |
| // Convert base pointer (i64) to LLVM pointer type. |
| Value basePtrLLVM = |
| LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); |
| // FIXME: The width (pitch) is not the same as baseShapeW; it should be the |
| // stride of the second-to-last dimension in row-major layout. |
| // Compute width in bytes. |
| Value baseShapeWInBytes = |
| arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); |
| // Compute pitch in bytes. |
| Value basePitchBytes = |
| arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize); |
| |
| if (wScaleFactor > 1) { |
| // Scale offsetW, baseShapeWInBytes for sub byte emulation. |
| // Note: tileW is already scaled above. |
| Value wScaleFactorValLog2 = arith::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor)); |
| baseShapeWInBytes = arith::ShRSIOp::create( |
| rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2); |
| basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes, |
| wScaleFactorValLog2); |
| offsetW = |
| arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2); |
| } |
| // Get tile height from the tensor descriptor type. |
| auto tileH = tdescTy.getDimSize(0); |
| // Get vblocks from the tensor descriptor type. |
| int32_t vblocks = tdescTy.getArrayLength(); |
| if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { |
| Value src = adaptor.getValue(); |
| // If the store value is a scalar, get the value from the op instead of the |
| // adaptor; the adaptor may have folded away a single-element vector. |
| if (src.getType().isIntOrFloat()) { |
| src = op.getValue(); |
| } |
| VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); |
| if (!srcVecTy) |
| return rewriter.notifyMatchFailure( |
| op, "Expected store value to be a vector type."); |
| // Get flat vector type of integer type with matching element bit size. |
| VectorType newSrcVecTy = |
| encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); |
| if (srcVecTy != newSrcVecTy) |
| src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); |
| auto storeCacheControl = |
| translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); |
| xevm::BlockStore2dOp::create( |
| rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, |
| basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src, |
| xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); |
| rewriter.eraseOp(op); |
| } else { |
| auto loadCacheControl = |
| translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); |
| if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { |
| xevm::BlockPrefetch2dOp::create( |
| rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, |
| basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, |
| vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); |
| rewriter.eraseOp(op); |
| } else { |
| VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); |
| const bool vnni = op.getPacked().value_or(false); |
| auto transposeValue = op.getTranspose(); |
| bool transpose = |
| transposeValue.has_value() && transposeValue.value()[0] == 1; |
| VectorType loadedTy = encodeVectorTypeTo( |
| dstVecTy, vnni ? rewriter.getI32Type() |
| : rewriter.getIntegerType(elemBitSize)); |
| |
| Value resultFlatVec = xevm::BlockLoad2dOp::create( |
| rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes, |
| baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW, |
| tileH, vblocks, transpose, vnni, |
| xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); |
| resultFlatVec = vector::BitCastOp::create( |
| rewriter, loc, |
| encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), |
| resultFlatVec); |
| rewriter.replaceOp(op, resultFlatVec); |
| } |
| } |
| } else { |
| // 1D tensor descriptor: `tdesc` holds the base address as i64. |
| // The offset is in number of elements; multiply by the element byte size |
| // to get the byte offset: byteOffset = offset * elemByteSize. |
| Value offset = |
| getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); |
| offset = getValueOrCreateCastToIndexLike(rewriter, loc, |
| rewriter.getI64Type(), offset); |
| // Compute element byte size. |
| Value elemByteSize = arith::ConstantIntOp::create( |
| rewriter, loc, rewriter.getI64Type(), elemBitSize / 8); |
| Value byteOffset = |
| rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize); |
| // Final address = basePtr + byteOffset |
| Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>( |
| loc, tdesc, |
| getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(), |
| byteOffset)); |
| // Convert base pointer (i64) to LLVM pointer type. |
| Value finalPtrLLVM = |
| LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64); |
| if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { |
| Value src = adaptor.getValue(); |
| // If the store value is a scalar, get the value from the op instead of the |
| // adaptor; the adaptor may have folded away a single-element vector. |
| if (src.getType().isIntOrFloat()) { |
| src = op.getValue(); |
| } |
| VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); |
| if (!srcVecTy) |
| return rewriter.notifyMatchFailure( |
| op, "Expected store value to be a vector type."); |
| // Get flat vector type of integer type with matching element bit size. |
| VectorType newSrcVecTy = |
| encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); |
| if (srcVecTy != newSrcVecTy) |
| src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); |
| auto storeCacheControl = |
| translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); |
| rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>( |
| op, finalPtrLLVM, src, |
| xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); |
| } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { |
| auto loadCacheControl = |
| translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); |
| VectorType resTy = cast<VectorType>(op.getValue().getType()); |
| VectorType loadedTy = |
| encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize)); |
| Value load = xevm::BlockLoadOp::create( |
| rewriter, loc, loadedTy, finalPtrLLVM, |
| xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); |
| if (loadedTy != resTy) |
| load = vector::BitCastOp::create(rewriter, loc, resTy, load); |
| rewriter.replaceOp(op, load); |
| } else { |
| return rewriter.notifyMatchFailure( |
| op, "Unsupported operation: xegpu.prefetch_nd with tensor " |
| "descriptor rank == 1"); |
| } |
| } |
| return success(); |
| } |
| }; |
| |
| // Helper that emits the address computation: baseAddr + offset * elemByteSize. |
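| // For example, with elemByteSize = 4 this emits baseAddr + offset * 4, |
| // using the integer type of baseAddr. |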
| static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter, |
| Location loc, Value baseAddr, Value offset, |
| int64_t elemByteSize) { |
| Value byteSize = arith::ConstantIntOp::create( |
| rewriter, loc, baseAddr.getType(), elemByteSize); |
| Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); |
| Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); |
| return newAddr; |
| } |
| |
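| // Illustrative lowering sketch for a lane-level gather load (IR abbreviated; |
| // operand syntax elided): |
| //   %r = xegpu.load %src[%off], %mask : ... -> f32 |
| // becomes roughly |
| //   %p = llvm.inttoptr (%src + %off * elemByteSize) |
| //   %r = scf.if %mask -> (f32) { |
| //          %l = llvm.load %p {cache_control = ...} : !llvm.ptr -> f32 |
| //          scf.yield %l |
| //        } else { |
| //          %z = arith.constant 0.0 : f32 |
| //          scf.yield %z |
| //        } |
| // Stores follow the same structure with an llvm.store guarded by scf.if. |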
| template <typename OpType, |
| typename = std::enable_if_t<llvm::is_one_of< |
| OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>> |
| class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { |
| using OpConversionPattern<OpType>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Value offset = adaptor.getOffsets(); |
| if (!offset) |
| return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); |
| auto loc = op.getLoc(); |
| auto ctxt = rewriter.getContext(); |
| auto tdescTy = op.getTensorDescType(); |
| Value basePtrI64; |
| // The load result or store value type can be a vector or a scalar. |
| Type valOrResTy; |
| if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) |
| valOrResTy = |
| this->getTypeConverter()->convertType(op.getResult().getType()); |
| else |
| valOrResTy = adaptor.getValue().getType(); |
| VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy); |
| bool hasScalarVal = !valOrResVecTy; |
| int64_t elemBitWidth = |
| hasScalarVal ? valOrResTy.getIntOrFloatBitWidth() |
| : valOrResVecTy.getElementType().getIntOrFloatBitWidth(); |
| // Element type must be multiple of 8 bits. |
| if (elemBitWidth % 8 != 0) |
| return rewriter.notifyMatchFailure( |
| op, "Expected element type bit width to be multiple of 8."); |
| int64_t elemByteSize = elemBitWidth / 8; |
| // Default memory space is global. |
| LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); |
| // If tensor descriptor is available, we use its memory space. |
| if (tdescTy) |
| ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); |
| // The base pointer can come from the source (load) or the dest (store). |
| // If it is a memref, use its memory space. |
| if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) { |
| basePtrI64 = adaptor.getSource(); |
| if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) { |
| auto addrSpace = memRefTy.getMemorySpaceAsInt(); |
| if (addrSpace != 0) |
| ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); |
| } |
| } else { |
| basePtrI64 = adaptor.getDest(); |
| if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) { |
| auto addrSpace = memRefTy.getMemorySpaceAsInt(); |
| if (addrSpace != 0) |
| ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); |
| } |
| } |
| // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. |
| if (basePtrI64.getType() != rewriter.getI64Type()) { |
| basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), |
| basePtrI64); |
| } |
| Value mask = adaptor.getMask(); |
| if (dyn_cast<VectorType>(offset.getType())) { |
| // The offset needs to be a scalar; a single-element vector is converted |
| // to a scalar by the type converter. |
| return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar."); |
| } else { |
| // Add the offset to the base pointer. The offset is in number of |
| // elements, so scale it by the element byte size. |
| basePtrI64 = |
| addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize); |
| } |
| // Convert base pointer (i64) to LLVM pointer type. |
| Value basePtrLLVM = |
| LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); |
| |
| Value maskForLane; |
| VectorType maskVecTy = dyn_cast<VectorType>(mask.getType()); |
| if (maskVecTy) { |
| // The mask needs to be a scalar; a single-element vector is converted to |
| // a scalar by the type converter. |
| return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar."); |
| } else |
| maskForLane = mask; |
| if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) { |
| scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy}, |
| maskForLane, true, true); |
| // If the mask is true (then clause), load from memory and yield. |
| rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| if (!hasScalarVal) |
| valOrResTy = VectorType::get({valOrResVecTy.getNumElements()}, |
| valOrResVecTy.getElementType()); |
| Value loaded = |
| LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM); |
| // Set cache control attribute on the load operation. |
| loaded.getDefiningOp()->setAttr( |
| "cache_control", xevm::LoadCacheControlAttr::get( |
| ctxt, translateLoadXeGPUCacheHint( |
| op.getL1Hint(), op.getL3Hint()))); |
| scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); |
| rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| // If the mask is false (else clause), yield a zero constant. |
| auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType(); |
| TypedAttr eVal; |
| if (eTy.isFloat()) |
| eVal = FloatAttr::get(eTy, 0.0); |
| else |
| eVal = IntegerAttr::get(eTy, 0); |
| if (hasScalarVal) |
| loaded = arith::ConstantOp::create(rewriter, loc, eVal); |
| else |
| loaded = arith::ConstantOp::create( |
| rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal)); |
| scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); |
| rewriter.replaceOp(op, ifOp.getResult(0)); |
| } else { |
| // If mask is true, perform the store. |
| scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false); |
| auto body = ifOp.getBody(); |
| rewriter.setInsertionPointToStart(body); |
| auto storeOp = |
| LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM); |
| // Set cache control attribute on the store operation. |
| storeOp.getOperation()->setAttr( |
| "cache_control", xevm::StoreCacheControlAttr::get( |
| ctxt, translateStoreXeGPUCacheHint( |
| op.getL1Hint(), op.getL3Hint()))); |
| rewriter.eraseOp(op); |
| } |
| return success(); |
| } |
| }; |
| |
| class CreateMemDescOpPattern final |
| : public OpConversionPattern<xegpu::CreateMemDescOp> { |
| public: |
| using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| rewriter.replaceOp(op, adaptor.getSource()); |
| return success(); |
| } |
| }; |
| |
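| // Illustrative lowering sketch (IR abbreviated): xegpu.load_matrix / |
| // xegpu.store_matrix compute a linear SLM offset from the mem_desc layout, |
| // add it to the i32 base address, and then either |
| //   - issue xevm.blockload / xevm.blockstore when 'subgroup_block_io' is |
| //     set (with a bitcast to/from an integer element type of equal width), |
| //   - or issue a plain llvm.load / llvm.store of the (possibly scalar) |
| //     value. |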
| template <typename OpType, |
| typename = std::enable_if_t<llvm::is_one_of< |
| OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> |
| class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { |
| using OpConversionPattern<OpType>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| |
| SmallVector<OpFoldResult> offsets = op.getMixedOffsets(); |
| if (offsets.empty()) |
| return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); |
| |
| auto loc = op.getLoc(); |
| auto ctxt = rewriter.getContext(); |
| Value baseAddr32 = adaptor.getMemDesc(); |
| Value mdescVal = op.getMemDesc(); |
| // The load result or store value type can be a vector or a scalar. |
| Type dataTy; |
| if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { |
| Type resType = op.getResult().getType(); |
| // Some transforms may leave unit dimensions in the 2D vector; the adaptor |
| // does not catch this for results. |
| if (auto vecType = dyn_cast<VectorType>(resType)) { |
| assert(llvm::count_if(vecType.getShape(), |
| [](int64_t d) { return d != 1; }) <= 1 && |
| "Expected either 1D vector or nD with unit dimensions"); |
| resType = VectorType::get({vecType.getNumElements()}, |
| vecType.getElementType()); |
| } |
| dataTy = resType; |
| } else |
| dataTy = adaptor.getData().getType(); |
| VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy); |
| if (!valOrResVecTy) |
| valOrResVecTy = VectorType::get(1, dataTy); |
| |
| int64_t elemBitWidth = |
| valOrResVecTy.getElementType().getIntOrFloatBitWidth(); |
| // Element type must be multiple of 8 bits. |
| if (elemBitWidth % 8 != 0) |
| return rewriter.notifyMatchFailure( |
| op, "Expected element type bit width to be multiple of 8."); |
| int64_t elemByteSize = elemBitWidth / 8; |
| |
| // Default memory space is SLM. |
| LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); |
| |
| auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); |
| |
| Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); |
| linearOffset = arith::IndexCastUIOp::create( |
| rewriter, loc, rewriter.getI32Type(), linearOffset); |
| Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, |
| linearOffset, elemByteSize); |
| |
| // Convert the base pointer (i32) to an LLVM pointer type. |
| Value basePtrLLVM = |
| LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); |
| |
| if (op.getSubgroupBlockIoAttr()) { |
| // If the 'subgroup_block_io' attribute is set, lower to xevm.blockload / |
| // xevm.blockstore. |
| |
| Type intElemTy = rewriter.getIntegerType(elemBitWidth); |
| VectorType intVecTy = |
| VectorType::get(valOrResVecTy.getShape(), intElemTy); |
| |
| if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { |
| Value loadOp = |
| xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); |
| if (intVecTy != valOrResVecTy) { |
| loadOp = |
| vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); |
| } |
| rewriter.replaceOp(op, loadOp); |
| } else { |
| Value dataToStore = adaptor.getData(); |
| if (valOrResVecTy != intVecTy) { |
| dataToStore = |
| vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); |
| } |
| xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, |
| nullptr); |
| rewriter.eraseOp(op); |
| } |
| return success(); |
| } |
| |
| if (valOrResVecTy.getNumElements() >= 1) { |
| auto chipOpt = xegpu::getChipStr(op); |
| if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { |
| // The chunk-load lowering only works for pvc and bmg. |
| return rewriter.notifyMatchFailure( |
| op, "The lowering is specific to pvc or bmg."); |
| } |
| } |
| |
| if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { |
| // If valOrResVecTy has a single element, lower to a scalar load/store; |
| // LLVM load/store does not support single-element vectors, so this case |
| // is handled separately. |
| auto scalarTy = valOrResVecTy.getElementType(); |
| LLVM::LoadOp loadOp; |
| if (valOrResVecTy.getNumElements() == 1) |
| loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); |
| else |
| loadOp = |
| LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); |
| rewriter.replaceOp(op, loadOp); |
| } else { |
| LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); |
| rewriter.eraseOp(op); |
| } |
| return success(); |
| } |
| }; |
| |
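| // Illustrative lowering sketch (IR abbreviated): xegpu.prefetch computes |
| // base + offset * elemByteSize (when an offset is present), converts the |
| // result with llvm.inttoptr, and emits xevm.prefetch carrying the |
| // translated cache control attribute. |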
| class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| auto ctxt = rewriter.getContext(); |
| auto tdescTy = op.getTensorDescType(); |
| Value basePtrI64 = adaptor.getSource(); |
| // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. |
| if (basePtrI64.getType() != rewriter.getI64Type()) |
| basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), |
| basePtrI64); |
| Value offsets = adaptor.getOffsets(); |
| if (offsets) { |
| VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType()); |
| if (offsetsVecTy) { |
| // The offset needs to be a scalar. |
| return rewriter.notifyMatchFailure(op, |
| "Expected offsets to be a scalar."); |
| } else { |
| int64_t elemBitWidth{0}; |
| int64_t elemByteSize; |
| // Element byte size can come from three sources: |
| if (tdescTy) { |
| // If tensor descriptor is available, we use its element type to |
| // determine element byte size. |
| elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth(); |
| } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) { |
| // If memref is available, we use its element type to |
| // determine element byte size. |
| elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth(); |
| } else { |
| // Otherwise, we use the provided offset byte alignment. |
| elemByteSize = *op.getOffsetAlignByte(); |
| } |
| if (elemBitWidth != 0) { |
| if (elemBitWidth % 8 != 0) |
| return rewriter.notifyMatchFailure( |
| op, "Expected element type bit width to be multiple of 8."); |
| elemByteSize = elemBitWidth / 8; |
| } |
| basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets, |
| elemByteSize); |
| } |
| } |
| // Default memory space is global. |
| LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); |
| // If tensor descriptor is available, we use its memory space. |
| if (tdescTy) |
| ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); |
| // If source is a memref, we use its memory space. |
| if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) { |
| auto addrSpace = memRefTy.getMemorySpaceAsInt(); |
| if (addrSpace != 0) |
| ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); |
| } |
| // Convert base pointer (i64) to LLVM pointer type. |
| Value ptrLLVM = |
| LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); |
| // Create the prefetch op with cache control attribute. |
| xevm::PrefetchOp::create( |
| rewriter, loc, ptrLLVM, |
| xevm::LoadCacheControlAttr::get( |
| ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()))); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
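| // xegpu.fence lowers to xevm.memfence with the scope and address space |
| // mapped as: Workgroup -> WORKGROUP, GPU -> DEVICE; Global -> GLOBAL, |
| // SLM -> SHARED. |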
| class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| xevm::MemScope memScope{xevm::MemScope::WORKGROUP}; |
| switch (op.getFenceScope()) { |
| case xegpu::FenceScope::Workgroup: |
| memScope = xevm::MemScope::WORKGROUP; |
| break; |
| case xegpu::FenceScope::GPU: |
| memScope = xevm::MemScope::DEVICE; |
| break; |
| } |
| xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL}; |
| switch (op.getMemoryKind()) { |
| case xegpu::MemorySpace::Global: |
| addrSpace = xevm::AddrSpace::GLOBAL; |
| break; |
| case xegpu::MemorySpace::SLM: |
| addrSpace = xevm::AddrSpace::SHARED; |
| break; |
| } |
| xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
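| // Illustrative lowering sketch (per-lane SIMT form, f16 operands assumed, |
| // IR abbreviated): |
| //   %d = xegpu.dpas %a, %b, %c |
| //          : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> |
| // becomes an xevm.mma with an MMAShapeAttr of (M = 8, N = executionSize, |
| // K = systolicDepth * opsPerDword = 16 for f16) and an MMATypesAttr holding |
| // the element kinds; a zero accumulator is materialized if none is given. |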
| class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| auto ctxt = rewriter.getContext(); |
| auto aTy = cast<VectorType>(op.getLhs().getType()); |
| auto bTy = cast<VectorType>(op.getRhs().getType()); |
| auto resultType = cast<VectorType>(op.getResultType()); |
| |
| auto encodePrecision = [&](Type type) -> xevm::ElemType { |
| if (type == rewriter.getBF16Type()) |
| return xevm::ElemType::BF16; |
| else if (type == rewriter.getF16Type()) |
| return xevm::ElemType::F16; |
| else if (type == rewriter.getTF32Type()) |
| return xevm::ElemType::TF32; |
| else if (type.isInteger(8)) { |
| if (type.isUnsignedInteger()) |
| return xevm::ElemType::U8; |
| return xevm::ElemType::S8; |
| } else if (type == rewriter.getF32Type()) |
| return xevm::ElemType::F32; |
| else if (type.isInteger(32)) |
| return xevm::ElemType::S32; |
| llvm_unreachable("add more support for ElemType"); |
| }; |
| xevm::ElemType precATy = encodePrecision(aTy.getElementType()); |
| xevm::ElemType precBTy = encodePrecision(bTy.getElementType()); |
| Value c = op.getAcc(); |
| if (!c) { |
| auto elementTy = resultType.getElementType(); |
| Attribute initValueAttr; |
| if (isa<FloatType>(elementTy)) |
| initValueAttr = FloatAttr::get(elementTy, 0.0); |
| else |
| initValueAttr = IntegerAttr::get(elementTy, 0); |
| c = arith::ConstantOp::create( |
| rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr)); |
| } |
| |
| Value aVec = op.getLhs(); |
| Value bVec = op.getRhs(); |
| auto cvecty = cast<VectorType>(c.getType()); |
| xevm::ElemType precCTy = encodePrecision(cvecty.getElementType()); |
| xevm::ElemType precDTy = encodePrecision(resultType.getElementType()); |
| VectorType cNty = |
| VectorType::get(cvecty.getNumElements(), cvecty.getElementType()); |
| if (cvecty != cNty) |
| c = vector::ShapeCastOp::create(rewriter, loc, cNty, c); |
| Value dpasRes = xevm::MMAOp::create( |
| rewriter, loc, cNty, aVec, bVec, c, |
| xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize, |
| systolicDepth * |
| getNumOperandsPerDword(precATy)), |
| xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy)); |
| if (cvecty != cNty) |
| dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes); |
| rewriter.replaceOp(op, dpasRes); |
| return success(); |
| } |
| |
| private: |
| static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { |
| switch (pTy) { |
| case xevm::ElemType::TF32: |
| return 1; |
| case xevm::ElemType::BF16: |
| case xevm::ElemType::F16: |
| return 2; |
| case xevm::ElemType::U8: |
| case xevm::ElemType::S8: |
| return 4; |
| default: |
| llvm_unreachable("unsupported xevm::ElemType"); |
| } |
| } |
| }; |
| |
| static std::optional<LLVM::AtomicBinOp> |
| matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) { |
| switch (arithKind) { |
| case arith::AtomicRMWKind::addf: |
| return LLVM::AtomicBinOp::fadd; |
| case arith::AtomicRMWKind::addi: |
| return LLVM::AtomicBinOp::add; |
| case arith::AtomicRMWKind::assign: |
| return LLVM::AtomicBinOp::xchg; |
| case arith::AtomicRMWKind::maximumf: |
| return LLVM::AtomicBinOp::fmax; |
| case arith::AtomicRMWKind::maxs: |
| return LLVM::AtomicBinOp::max; |
| case arith::AtomicRMWKind::maxu: |
| return LLVM::AtomicBinOp::umax; |
| case arith::AtomicRMWKind::minimumf: |
| return LLVM::AtomicBinOp::fmin; |
| case arith::AtomicRMWKind::mins: |
| return LLVM::AtomicBinOp::min; |
| case arith::AtomicRMWKind::minu: |
| return LLVM::AtomicBinOp::umin; |
| case arith::AtomicRMWKind::ori: |
| return LLVM::AtomicBinOp::_or; |
| case arith::AtomicRMWKind::andi: |
| return LLVM::AtomicBinOp::_and; |
| default: |
| return std::nullopt; |
| } |
| } |
| |
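| // Illustrative lowering sketch (IR abbreviated): xegpu.atomic_rmw on an |
| // n-element vector unrolls into n llvm.atomicrmw ops (seq_cst ordering), |
| // each addressing a lane computed with llvm.getelementptr from the |
| // descriptor base address; the results are reassembled with vector.insert. |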
| class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| auto ctxt = rewriter.getContext(); |
| auto tdesc = op.getTensorDesc().getType(); |
| auto ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace())); |
| Value basePtrI64 = arith::IndexCastOp::create( |
| rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc()); |
| Value basePtrLLVM = |
| LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); |
| VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType()); |
| VectorType srcOrDstFlatVecTy = VectorType::get( |
| srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); |
| Value srcFlatVec = vector::ShapeCastOp::create( |
| rewriter, loc, srcOrDstFlatVecTy, op.getValue()); |
| auto atomicKind = matchSimpleAtomicOp(op.getKind()); |
| assert(atomicKind.has_value() && "Unsupported atomic RMW kind."); |
| Value resVec = srcFlatVec; |
| for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) { |
| auto val = vector::ExtractOp::create(rewriter, loc, resVec, i); |
| Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), |
| rewriter.getIndexAttr(i)); |
| Value currPtr = |
| LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM, |
| srcOrDstVecTy.getElementType(), basePtrLLVM, idx); |
| Value newVal = |
| LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr, |
| val, LLVM::AtomicOrdering::seq_cst); |
| resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i); |
| } |
| rewriter.replaceOp(op, resVec); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Pass Definition |
| //===----------------------------------------------------------------------===// |
| |
| struct ConvertXeGPUToXeVMPass |
| : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> { |
| using Base::Base; |
| |
| void runOnOperation() override { |
| LLVMTypeConverter typeConverter(&getContext()); |
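| // Illustrative type conversions set up below (sketch; exact results depend |
| // on the input IR): |
| //   vector<8x16xf32>             -> vector<128xf32> |
| //   vector<1xf16>                -> f16 |
| //   vector<2xindex>              -> vector<2xi64> |
| //   !xegpu.tensor_desc<16xf32>   -> i64            (1D: plain base address) |
| //   !xegpu.tensor_desc<8x16xf16> -> vector<8xi32>  (payload) |
| //   !xegpu.mem_desc<...>         -> i32            (SLM address) |
| //   memref<...> (global)         -> i64, shared memref -> i32 |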
| typeConverter.addConversion([&](VectorType type) -> Type { |
| unsigned rank = type.getRank(); |
| auto elemType = type.getElementType(); |
| // If the element type is index, convert it to i64. |
| if (llvm::isa<IndexType>(elemType)) |
| elemType = IntegerType::get(&getContext(), 64); |
| // If the vector is rank 0 or has a single element, return the element type. |
| if (rank < 1 || type.getNumElements() == 1) |
| return elemType; |
| // Otherwise, convert the vector to a flat vector type. |
| int64_t numElements = llvm::product_of(type.getShape()); |
| return VectorType::get(numElements, elemType); |
| }); |
| typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { |
| // Scattered descriptors are not supported in XeVM lowering. |
| if (type.isScattered()) |
| return {}; |
| if (type.getRank() == 1) |
| return IntegerType::get(&getContext(), 64); |
| auto i32Type = IntegerType::get(&getContext(), 32); |
| return VectorType::get(8, i32Type); |
| }); |
| // Convert MemDescType into i32 for SLM |
| typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { |
| return IntegerType::get(&getContext(), 32); |
| }); |
| |
| typeConverter.addConversion([&](MemRefType type) -> Type { |
| return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64)); |
| }); |
| |
| // The LLVM type converter inserts unrealized casts for the following cases; |
| // add materialization casts to handle them. |
| |
| // Materialization to convert a memref to i64 (global) or i32 (SLM). |
| auto memrefMaterializationCast = [](OpBuilder &builder, Type type, |
| ValueRange inputs, |
| Location loc) -> Value { |
| if (inputs.size() != 1) |
| return {}; |
| auto input = inputs.front(); |
| if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) { |
| unsigned rank = memrefTy.getRank(); |
| Type indexType = builder.getIndexType(); |
| |
| int64_t intOffsets; |
| SmallVector<int64_t> intStrides; |
| Value addr; |
| Value offset; |
| if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) && |
| ShapedType::isStatic(intOffsets)) { |
| addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, |
| input); |
| offset = arith::ConstantOp::create(builder, loc, |
| builder.getIndexAttr(intOffsets)); |
| } else { |
| |
| // Result types: [base_memref, offset, size0, ..., sizeN-1, |
| // stride0, ..., strideN-1] |
| SmallVector<Type> resultTypes{ |
| MemRefType::get({}, memrefTy.getElementType(), |
| MemRefLayoutAttrInterface(), |
| memrefTy.getMemorySpace()), |
| indexType}; |
| // sizes + strides |
| resultTypes.append(2 * rank, indexType); |
| |
| auto meta = memref::ExtractStridedMetadataOp::create( |
| builder, loc, resultTypes, input); |
| |
| addr = memref::ExtractAlignedPointerAsIndexOp::create( |
| builder, loc, meta.getBaseBuffer()); |
| offset = meta.getOffset(); |
| } |
| |
| auto addrCasted = |
| arith::IndexCastUIOp::create(builder, loc, type, addr); |
| auto offsetCasted = |
| arith::IndexCastUIOp::create(builder, loc, type, offset); |
| |
| // Compute the final address: base address + byte offset |
| auto byteSize = arith::ConstantOp::create( |
| builder, loc, type, |
| builder.getIntegerAttr(type, |
| memrefTy.getElementTypeBitWidth() / 8)); |
| auto byteOffset = |
| arith::MulIOp::create(builder, loc, offsetCasted, byteSize); |
| auto addrWithOffset = |
| arith::AddIOp::create(builder, loc, addrCasted, byteOffset); |
| |
| return addrWithOffset.getResult(); |
| } |
| return {}; |
| }; |
| |
| // Materialization to convert ui64 to i64 |
| auto ui64MaterializationCast = [](OpBuilder &builder, Type type, |
| ValueRange inputs, |
| Location loc) -> Value { |
| if (inputs.size() != 1) |
| return {}; |
| auto input = inputs.front(); |
| if (input.getType() == builder.getIntegerType(64, false)) { |
| Value cast = |
| index::CastUOp::create(builder, loc, builder.getIndexType(), input) |
| .getResult(); |
| return arith::IndexCastUIOp::create(builder, loc, type, cast) |
| .getResult(); |
| } |
| return {}; |
| }; |
| |
| // Materialization to convert ui32 to i32 |
| auto ui32MaterializationCast = [](OpBuilder &builder, Type type, |
| ValueRange inputs, |
| Location loc) -> Value { |
| if (inputs.size() != 1) |
| return {}; |
| auto input = inputs.front(); |
| if (input.getType() == builder.getIntegerType(32, false)) { |
| Value cast = |
| index::CastUOp::create(builder, loc, builder.getIndexType(), input) |
| .getResult(); |
| return arith::IndexCastUIOp::create(builder, loc, type, cast) |
| .getResult(); |
| } |
| return {}; |
| }; |
| |
| // Materialization to convert: |
| // - a single-element 1D vector to a scalar |
| // - a vector to a same-rank vector of a different element type (bitcast) |
| // - a vector to a different-rank vector of the same element type |
| //   (shape_cast) |
| auto vectorMaterializationCast = [](OpBuilder &builder, Type type, |
| ValueRange inputs, |
| Location loc) -> Value { |
| if (inputs.size() != 1) |
| return {}; |
| auto input = inputs.front(); |
| if (auto vecTy = dyn_cast<VectorType>(input.getType())) { |
| if (vecTy.getNumElements() == 1) { |
| // If the vector has a single element, extract that element. |
| Value cast = |
| vector::ExtractOp::create(builder, loc, input, 0).getResult(); |
| if (vecTy.getElementType() == builder.getIndexType()) |
| cast = arith::IndexCastUIOp::create(builder, loc, type, cast) |
| .getResult(); |
| return cast; |
| } else if (auto targetVecTy = dyn_cast<VectorType>(type)) { |
| // If the target type is a vector of the same rank, bitcast to it. |
| if (targetVecTy.getRank() == vecTy.getRank()) |
| return vector::BitCastOp::create(builder, loc, targetVecTy, input) |
| .getResult(); |
| else if (targetVecTy.getElementType() == vecTy.getElementType()) { |
| // If the target type is a vector of different rank but same element |
| // type, reshape to the target type. |
| return vector::ShapeCastOp::create(builder, loc, targetVecTy, input) |
| .getResult(); |
| } |
| } |
| } |
| return {}; |
| }; |
| |
| // If the result type of the original op is a single-element vector and the |
| // lowered type is a scalar, this materialization cast recreates the vector |
| // by broadcasting the scalar value. |
| auto singleElementVectorMaterializationCast = |
| [](OpBuilder &builder, Type type, ValueRange inputs, |
| Location loc) -> Value { |
| if (inputs.size() != 1) |
| return {}; |
| auto input = inputs.front(); |
| if (input.getType().isIntOrIndexOrFloat()) { |
| // If the input is a scalar and the target type is a single-element vector, |
| // create the vector by broadcasting. |
| if (auto vecTy = dyn_cast<VectorType>(type)) { |
| if (vecTy.getNumElements() == 1) { |
| return vector::BroadcastOp::create(builder, loc, vecTy, input) |
| .getResult(); |
| } |
| } |
| } |
| return {}; |
| }; |
| typeConverter.addSourceMaterialization( |
| singleElementVectorMaterializationCast); |
| typeConverter.addSourceMaterialization(vectorMaterializationCast); |
| typeConverter.addTargetMaterialization(memrefMaterializationCast); |
| typeConverter.addTargetMaterialization(ui32MaterializationCast); |
| typeConverter.addTargetMaterialization(ui64MaterializationCast); |
| typeConverter.addTargetMaterialization(vectorMaterializationCast); |
| ConversionTarget target(getContext()); |
| target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect, |
| vector::VectorDialect, arith::ArithDialect, |
| memref::MemRefDialect, gpu::GPUDialect, |
| index::IndexDialect>(); |
| target.addIllegalDialect<xegpu::XeGPUDialect>(); |
| |
| RewritePatternSet patterns(&getContext()); |
| populateXeGPUToXeVMConversionPatterns(typeConverter, patterns); |
| scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, |
| patterns, target); |
| if (failed(applyPartialConversion(getOperation(), target, |
| std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Pattern Population |
| //===----------------------------------------------------------------------===// |
| void mlir::populateXeGPUToXeVMConversionPatterns( |
| const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
| patterns.add<CreateNdDescToXeVMPattern, |
| LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>, |
| LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>, |
| LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>( |
| typeConverter, patterns.getContext()); |
| patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern, |
| LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, |
| LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( |
| typeConverter, patterns.getContext()); |
| patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>, |
| LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>, |
| CreateMemDescOpPattern>(typeConverter, patterns.getContext()); |
| patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter, |
| patterns.getContext()); |
| } |