| //===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements patterns to convert MemRef dialect to SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
| |
| #define DEBUG_TYPE "memref-to-spirv-pattern" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // Utility functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns the offset of the value in `targetBits` representation. |
| /// |
| /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. |
| /// It's assumed to be non-negative. |
| /// |
| /// When accessing an element in the array treating as having elements of |
| /// `targetBits`, multiple values are loaded in the same time. The method |
| /// returns the offset where the `srcIdx` locates in the value. For example, if |
| /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is |
| /// located at (x % 4) * 8. Because there are four elements in one i32, and one |
| /// element has 8 bits. |
| static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, |
| int targetBits, OpBuilder &builder) { |
| assert(targetBits % sourceBits == 0); |
| IntegerType targetType = builder.getIntegerType(targetBits); |
| IntegerAttr idxAttr = |
| builder.getIntegerAttr(targetType, targetBits / sourceBits); |
| auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr); |
| IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); |
| auto srcBitsValue = |
| builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr); |
| auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx); |
| return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue); |
| } |
| |
| /// Returns an adjusted spirv::AccessChainOp. Based on the |
| /// extension/capabilities, certain integer bitwidths `sourceBits` might not be |
| /// supported. During conversion if a memref of an unsupported type is used, |
| /// load/stores to this memref need to be modified to use a supported higher |
| /// bitwidth `targetBits` and extracting the required bits. For an accessing a |
| /// 1D array (spv.array or spv.rt_array), the last index is modified to load the |
| /// bits needed. The extraction of the actual bits needed are handled |
| /// separately. Note that this only works for a 1-D tensor. |
| static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, |
| spirv::AccessChainOp op, |
| int sourceBits, int targetBits, |
| OpBuilder &builder) { |
| assert(targetBits % sourceBits == 0); |
| const auto loc = op.getLoc(); |
| IntegerType targetType = builder.getIntegerType(targetBits); |
| IntegerAttr attr = |
| builder.getIntegerAttr(targetType, targetBits / sourceBits); |
| auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr); |
| auto lastDim = op->getOperand(op.getNumOperands() - 1); |
| auto indices = llvm::to_vector<4>(op.indices()); |
| // There are two elements if this is a 1-D tensor. |
| assert(indices.size() == 2); |
| indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx); |
| Type t = typeConverter.convertType(op.component_ptr().getType()); |
| return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices); |
| } |
| |
| /// Returns the shifted `targetBits`-bit value with the given offset. |
| static Value shiftValue(Location loc, Value value, Value offset, Value mask, |
| int targetBits, OpBuilder &builder) { |
| Type targetType = builder.getIntegerType(targetBits); |
| Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask); |
| return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result, |
| offset); |
| } |
| |
| /// Returns true if the allocations of type `t` can be lowered to SPIR-V. |
| static bool isAllocationSupported(MemRefType t) { |
| // Currently only support workgroup local memory allocations with static |
| // shape and int or float or vector of int or float element type. |
| if (!(t.hasStaticShape() && |
| SPIRVTypeConverter::getMemorySpaceForStorageClass( |
| spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt())) |
| return false; |
| Type elementType = t.getElementType(); |
| if (auto vecType = elementType.dyn_cast<VectorType>()) |
| elementType = vecType.getElementType(); |
| return elementType.isIntOrFloat(); |
| } |
| |
| /// Returns the scope to use for atomic operations use for emulating store |
| /// operations of unsupported integer bitwidths, based on the memref |
| /// type. Returns None on failure. |
| static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) { |
| Optional<spirv::StorageClass> storageClass = |
| SPIRVTypeConverter::getStorageClassForMemorySpace( |
| t.getMemorySpaceAsInt()); |
| if (!storageClass) |
| return {}; |
| switch (*storageClass) { |
| case spirv::StorageClass::StorageBuffer: |
| return spirv::Scope::Device; |
| case spirv::StorageClass::Workgroup: |
| return spirv::Scope::Workgroup; |
| default: { |
| } |
| } |
| return {}; |
| } |
| |
| /// Casts the given `srcInt` into a boolean value. |
| static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { |
| if (srcInt.getType().isInteger(1)) |
| return srcInt; |
| |
| auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder); |
| return builder.create<spirv::IEqualOp>(loc, srcInt, one); |
| } |
| |
| /// Casts the given `srcBool` into an integer of `dstType`. |
| static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, |
| OpBuilder &builder) { |
| assert(srcBool.getType().isInteger(1)); |
| if (dstType.isInteger(1)) |
| return srcBool; |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, builder); |
| return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Operation conversion |
| //===----------------------------------------------------------------------===// |
| |
// Note that DRR cannot be used for the patterns in this file: they may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
| |
| namespace { |
| |
| /// Converts an allocation operation to SPIR-V. Currently only supports lowering |
| /// to Workgroup memory when the size is constant. Note that this pattern needs |
| /// to be applied in a pass that runs at least at spv.module scope since it wil |
| /// ladd global variables into the spv.module. |
| class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> { |
| public: |
| using OpConversionPattern<memref::AllocOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Removed a deallocation if it is a supported allocation. Currently only |
| /// removes deallocation if the memory space is workgroup memory. |
| class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> { |
| public: |
| using OpConversionPattern<memref::DeallocOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts memref.load to spv.Load. |
| class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> { |
| public: |
| using OpConversionPattern<memref::LoadOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts memref.load to spv.Load. |
| class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> { |
| public: |
| using OpConversionPattern<memref::LoadOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts memref.store to spv.Store on integers. |
| class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> { |
| public: |
| using OpConversionPattern<memref::StoreOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts memref.store to spv.Store. |
| class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> { |
| public: |
| using OpConversionPattern<memref::StoreOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // AllocOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| MemRefType allocType = operation.getType(); |
| if (!isAllocationSupported(allocType)) |
| return operation.emitError("unhandled allocation type"); |
| |
| // Get the SPIR-V type for the allocation. |
| Type spirvType = getTypeConverter()->convertType(allocType); |
| |
| // Insert spv.GlobalVariable for this allocation. |
| Operation *parent = |
| SymbolTable::getNearestSymbolTable(operation->getParentOp()); |
| if (!parent) |
| return failure(); |
| Location loc = operation.getLoc(); |
| spirv::GlobalVariableOp varOp; |
| { |
| OpBuilder::InsertionGuard guard(rewriter); |
| Block &entryBlock = *parent->getRegion(0).begin(); |
| rewriter.setInsertionPointToStart(&entryBlock); |
| auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>(); |
| std::string varName = |
| std::string("__workgroup_mem__") + |
| std::to_string(std::distance(varOps.begin(), varOps.end())); |
| varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName, |
| /*initializer=*/nullptr); |
| } |
| |
| // Get pointer to global variable at the current scope. |
| rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DeallocOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, |
| OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| MemRefType deallocType = operation.memref().getType().cast<MemRefType>(); |
| if (!isAllocationSupported(deallocType)) |
| return operation.emitError("unhandled deallocation type"); |
| rewriter.eraseOp(operation); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LoadOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto loc = loadOp.getLoc(); |
| auto memrefType = loadOp.memref().getType().cast<MemRefType>(); |
| if (!memrefType.getElementType().isSignlessInteger()) |
| return failure(); |
| |
| auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
| spirv::AccessChainOp accessChainOp = |
| spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), |
| adaptor.indices(), loc, rewriter); |
| |
| if (!accessChainOp) |
| return failure(); |
| |
| int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); |
| bool isBool = srcBits == 1; |
| if (isBool) |
| srcBits = typeConverter.getOptions().boolNumBits; |
| Type pointeeType = typeConverter.convertType(memrefType) |
| .cast<spirv::PointerType>() |
| .getPointeeType(); |
| Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0); |
| Type dstType; |
| if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>()) |
| dstType = arrayType.getElementType(); |
| else |
| dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType(); |
| |
| int dstBits = dstType.getIntOrFloatBitWidth(); |
| assert(dstBits % srcBits == 0); |
| |
| // If the rewrited load op has the same bit width, use the loading value |
| // directly. |
| if (srcBits == dstBits) { |
| Value loadVal = |
| rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult()); |
| if (isBool) |
| loadVal = castIntNToBool(loc, loadVal, rewriter); |
| rewriter.replaceOp(loadOp, loadVal); |
| return success(); |
| } |
| |
| // Assume that getElementPtr() works linearizely. If it's a scalar, the method |
| // still returns a linearized accessing. If the accessing is not linearized, |
| // there will be offset issues. |
| assert(accessChainOp.indices().size() == 2); |
| Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, |
| srcBits, dstBits, rewriter); |
| Value spvLoadOp = rewriter.create<spirv::LoadOp>( |
| loc, dstType, adjustedPtr, |
| loadOp->getAttrOfType<spirv::MemoryAccessAttr>( |
| spirv::attributeName<spirv::MemoryAccess>()), |
| loadOp->getAttrOfType<IntegerAttr>("alignment")); |
| |
| // Shift the bits to the rightmost. |
| // ____XXXX________ -> ____________XXXX |
| Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); |
| Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); |
| Value result = rewriter.create<spirv::ShiftRightArithmeticOp>( |
| loc, spvLoadOp.getType(), spvLoadOp, offset); |
| |
| // Apply the mask to extract corresponding bits. |
| Value mask = rewriter.create<spirv::ConstantOp>( |
| loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); |
| result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask); |
| |
| // Apply sign extension on the loading value unconditionally. The signedness |
| // semantic is carried in the operator itself, we relies other pattern to |
| // handle the casting. |
| IntegerAttr shiftValueAttr = |
| rewriter.getIntegerAttr(dstType, dstBits - srcBits); |
| Value shiftValue = |
| rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr); |
| result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result, |
| shiftValue); |
| result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result, |
| shiftValue); |
| |
| if (isBool) { |
| dstType = typeConverter.convertType(loadOp.getType()); |
| mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); |
| result = rewriter.create<spirv::IEqualOp>(loc, result, mask); |
| } else if (result.getType().getIntOrFloatBitWidth() != |
| static_cast<unsigned>(dstBits)) { |
| result = rewriter.create<spirv::SConvertOp>(loc, dstType, result); |
| } |
| rewriter.replaceOp(loadOp, result); |
| |
| assert(accessChainOp.use_empty()); |
| rewriter.eraseOp(accessChainOp); |
| |
| return success(); |
| } |
| |
| LogicalResult |
| LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto memrefType = loadOp.memref().getType().cast<MemRefType>(); |
| if (memrefType.getElementType().isSignlessInteger()) |
| return failure(); |
| auto loadPtr = spirv::getElementPtr( |
| *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(), |
| adaptor.indices(), loadOp.getLoc(), rewriter); |
| |
| if (!loadPtr) |
| return failure(); |
| |
| rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr); |
| return success(); |
| } |
| |
| LogicalResult |
| IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto memrefType = storeOp.memref().getType().cast<MemRefType>(); |
| if (!memrefType.getElementType().isSignlessInteger()) |
| return failure(); |
| |
| auto loc = storeOp.getLoc(); |
| auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
| spirv::AccessChainOp accessChainOp = |
| spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), |
| adaptor.indices(), loc, rewriter); |
| |
| if (!accessChainOp) |
| return failure(); |
| |
| int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); |
| |
| bool isBool = srcBits == 1; |
| if (isBool) |
| srcBits = typeConverter.getOptions().boolNumBits; |
| |
| Type pointeeType = typeConverter.convertType(memrefType) |
| .cast<spirv::PointerType>() |
| .getPointeeType(); |
| Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0); |
| Type dstType; |
| if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>()) |
| dstType = arrayType.getElementType(); |
| else |
| dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType(); |
| |
| int dstBits = dstType.getIntOrFloatBitWidth(); |
| assert(dstBits % srcBits == 0); |
| |
| if (srcBits == dstBits) { |
| Value storeVal = adaptor.value(); |
| if (isBool) |
| storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::StoreOp>( |
| storeOp, accessChainOp.getResult(), storeVal); |
| return success(); |
| } |
| |
| // Since there are multi threads in the processing, the emulation will be done |
| // with atomic operations. E.g., if the storing value is i8, rewrite the |
| // StoreOp to |
| // 1) load a 32-bit integer |
| // 2) clear 8 bits in the loading value |
| // 3) store 32-bit value back |
| // 4) load a 32-bit integer |
| // 5) modify 8 bits in the loading value |
| // 6) store 32-bit value back |
| // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step |
| // 4 to step 6 are done by AtomicOr as another atomic step. |
| assert(accessChainOp.indices().size() == 2); |
| Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); |
| Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); |
| |
| // Create a mask to clear the destination. E.g., if it is the second i8 in |
| // i32, 0xFFFF00FF is created. |
| Value mask = rewriter.create<spirv::ConstantOp>( |
| loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); |
| Value clearBitsMask = |
| rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset); |
| clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask); |
| |
| Value storeVal = adaptor.value(); |
| if (isBool) |
| storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); |
| storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); |
| Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, |
| srcBits, dstBits, rewriter); |
| Optional<spirv::Scope> scope = getAtomicOpScope(memrefType); |
| if (!scope) |
| return failure(); |
| Value result = rewriter.create<spirv::AtomicAndOp>( |
| loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, |
| clearBitsMask); |
| result = rewriter.create<spirv::AtomicOrOp>( |
| loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, |
| storeVal); |
| |
| // The AtomicOrOp has no side effect. Since it is already inserted, we can |
| // just remove the original StoreOp. Note that rewriter.replaceOp() |
| // doesn't work because it only accepts that the numbers of result are the |
| // same. |
| rewriter.eraseOp(storeOp); |
| |
| assert(accessChainOp.use_empty()); |
| rewriter.eraseOp(accessChainOp); |
| |
| return success(); |
| } |
| |
| LogicalResult |
| StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto memrefType = storeOp.memref().getType().cast<MemRefType>(); |
| if (memrefType.getElementType().isSignlessInteger()) |
| return failure(); |
| auto storePtr = spirv::getElementPtr( |
| *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(), |
| adaptor.indices(), storeOp.getLoc(), rewriter); |
| |
| if (!storePtr) |
| return failure(); |
| |
| rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, |
| adaptor.value()); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pattern population |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, |
| RewritePatternSet &patterns) { |
| patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, |
| IntStoreOpPattern, LoadOpPattern, StoreOpPattern>( |
| typeConverter, patterns.getContext()); |
| } |
| } // namespace mlir |