| //===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements a translation between the MLIR `ptr` dialect and |
| // LLVM IR. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h" |
| #include "mlir/Dialect/Ptr/IR/PtrOps.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/Type.h" |
| #include "llvm/IR/Value.h" |
| |
| using namespace mlir; |
| using namespace mlir::ptr; |
| |
| namespace { |
| |
| /// Converts ptr::AtomicOrdering to llvm::AtomicOrdering |
| static llvm::AtomicOrdering |
| translateAtomicOrdering(ptr::AtomicOrdering ordering) { |
| switch (ordering) { |
| case ptr::AtomicOrdering::not_atomic: |
| return llvm::AtomicOrdering::NotAtomic; |
| case ptr::AtomicOrdering::unordered: |
| return llvm::AtomicOrdering::Unordered; |
| case ptr::AtomicOrdering::monotonic: |
| return llvm::AtomicOrdering::Monotonic; |
| case ptr::AtomicOrdering::acquire: |
| return llvm::AtomicOrdering::Acquire; |
| case ptr::AtomicOrdering::release: |
| return llvm::AtomicOrdering::Release; |
| case ptr::AtomicOrdering::acq_rel: |
| return llvm::AtomicOrdering::AcquireRelease; |
| case ptr::AtomicOrdering::seq_cst: |
| return llvm::AtomicOrdering::SequentiallyConsistent; |
| } |
| llvm_unreachable("Unknown atomic ordering"); |
| } |
| |
| /// Translate ptr.ptr_add operation to LLVM IR. |
| static LogicalResult |
| translatePtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase()); |
| llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset()); |
| |
| if (!basePtr || !offset) |
| return ptrAddOp.emitError("Failed to lookup operands"); |
| |
| // Create the GEP flags |
| llvm::GEPNoWrapFlags gepFlags; |
| switch (ptrAddOp.getFlags()) { |
| case ptr::PtrAddFlags::none: |
| break; |
| case ptr::PtrAddFlags::nusw: |
| gepFlags = llvm::GEPNoWrapFlags::noUnsignedSignedWrap(); |
| break; |
| case ptr::PtrAddFlags::nuw: |
| gepFlags = llvm::GEPNoWrapFlags::noUnsignedWrap(); |
| break; |
| case ptr::PtrAddFlags::inbounds: |
| gepFlags = llvm::GEPNoWrapFlags::inBounds(); |
| break; |
| } |
| |
| // Create GEP instruction for pointer arithmetic |
| llvm::Value *gep = |
| builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset}, "", gepFlags); |
| |
| moduleTranslation.mapValue(ptrAddOp.getResult(), gep); |
| return success(); |
| } |
| |
| /// Translate ptr.load operation to LLVM IR. |
| static LogicalResult |
| translateLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr()); |
| if (!ptr) |
| return loadOp.emitError("Failed to lookup pointer operand"); |
| |
| // Translate result type to LLVM type |
| llvm::Type *resultType = |
| moduleTranslation.convertType(loadOp.getValue().getType()); |
| if (!resultType) |
| return loadOp.emitError("Failed to translate result type"); |
| |
| // Create the load instruction. |
| llvm::MaybeAlign alignment(loadOp.getAlignment().value_or(0)); |
| llvm::LoadInst *loadInst = builder.CreateAlignedLoad( |
| resultType, ptr, alignment, loadOp.getVolatile_()); |
| |
| // Set op flags and metadata. |
| loadInst->setAtomic(translateAtomicOrdering(loadOp.getOrdering())); |
| // Set sync scope if specified |
| if (loadOp.getSyncscope().has_value()) { |
| llvm::LLVMContext &ctx = builder.getContext(); |
| llvm::SyncScope::ID syncScope = |
| ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value()); |
| loadInst->setSyncScopeID(syncScope); |
| } |
| |
| // Set metadata for nontemporal, invariant, and invariant_group |
| if (loadOp.getNontemporal()) { |
| llvm::MDNode *nontemporalMD = |
| llvm::MDNode::get(builder.getContext(), |
| llvm::ConstantAsMetadata::get(builder.getInt32(1))); |
| loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD); |
| } |
| |
| if (loadOp.getInvariant()) { |
| llvm::MDNode *invariantMD = llvm::MDNode::get(builder.getContext(), {}); |
| loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD); |
| } |
| |
| if (loadOp.getInvariantGroup()) { |
| llvm::MDNode *invariantGroupMD = |
| llvm::MDNode::get(builder.getContext(), {}); |
| loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group, |
| invariantGroupMD); |
| } |
| |
| moduleTranslation.mapValue(loadOp.getResult(), loadInst); |
| return success(); |
| } |
| |
| /// Translate ptr.store operation to LLVM IR. |
| static LogicalResult |
| translateStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue()); |
| llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr()); |
| |
| if (!value || !ptr) |
| return storeOp.emitError("Failed to lookup operands"); |
| |
| // Create the store instruction. |
| llvm::MaybeAlign alignment(storeOp.getAlignment().value_or(0)); |
| llvm::StoreInst *storeInst = |
| builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_()); |
| |
| // Set op flags and metadata. |
| storeInst->setAtomic(translateAtomicOrdering(storeOp.getOrdering())); |
| // Set sync scope if specified |
| if (storeOp.getSyncscope().has_value()) { |
| llvm::LLVMContext &ctx = builder.getContext(); |
| llvm::SyncScope::ID syncScope = |
| ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value()); |
| storeInst->setSyncScopeID(syncScope); |
| } |
| |
| // Set metadata for nontemporal and invariant_group |
| if (storeOp.getNontemporal()) { |
| llvm::MDNode *nontemporalMD = |
| llvm::MDNode::get(builder.getContext(), |
| llvm::ConstantAsMetadata::get(builder.getInt32(1))); |
| storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD); |
| } |
| |
| if (storeOp.getInvariantGroup()) { |
| llvm::MDNode *invariantGroupMD = |
| llvm::MDNode::get(builder.getContext(), {}); |
| storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group, |
| invariantGroupMD); |
| } |
| |
| return success(); |
| } |
| |
| /// Translate ptr.type_offset operation to LLVM IR. |
| static LogicalResult |
| translateTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| // Translate the element type to LLVM type |
| llvm::Type *elementType = |
| moduleTranslation.convertType(typeOffsetOp.getElementType()); |
| if (!elementType) |
| return typeOffsetOp.emitError("Failed to translate the element type"); |
| |
| // Translate result type |
| llvm::Type *resultType = |
| moduleTranslation.convertType(typeOffsetOp.getResult().getType()); |
| if (!resultType) |
| return typeOffsetOp.emitError("Failed to translate the result type"); |
| |
| // Use GEP with null pointer to compute type size/offset. |
| llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0)); |
| llvm::Value *offsetPtr = |
| builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)}); |
| llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType); |
| |
| moduleTranslation.mapValue(typeOffsetOp.getResult(), offset); |
| return success(); |
| } |
| |
| /// Translate ptr.gather operation to LLVM IR. |
| static LogicalResult |
| translateGatherOp(GatherOp gatherOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *ptrs = moduleTranslation.lookupValue(gatherOp.getPtrs()); |
| llvm::Value *mask = moduleTranslation.lookupValue(gatherOp.getMask()); |
| llvm::Value *passthrough = |
| moduleTranslation.lookupValue(gatherOp.getPassthrough()); |
| |
| if (!ptrs || !mask || !passthrough) |
| return gatherOp.emitError("Failed to lookup operands"); |
| |
| // Translate result type to LLVM type. |
| llvm::Type *resultType = |
| moduleTranslation.convertType(gatherOp.getResult().getType()); |
| if (!resultType) |
| return gatherOp.emitError("Failed to translate result type"); |
| |
| // Get the alignment. |
| llvm::MaybeAlign alignment(gatherOp.getAlignment().value_or(0)); |
| |
| // Create the masked gather intrinsic call. |
| llvm::Value *result = builder.CreateMaskedGather( |
| resultType, ptrs, alignment.valueOrOne(), mask, passthrough); |
| |
| moduleTranslation.mapValue(gatherOp.getResult(), result); |
| return success(); |
| } |
| |
| /// Translate ptr.masked_load operation to LLVM IR. |
| static LogicalResult |
| translateMaskedLoadOp(MaskedLoadOp maskedLoadOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *ptr = moduleTranslation.lookupValue(maskedLoadOp.getPtr()); |
| llvm::Value *mask = moduleTranslation.lookupValue(maskedLoadOp.getMask()); |
| llvm::Value *passthrough = |
| moduleTranslation.lookupValue(maskedLoadOp.getPassthrough()); |
| |
| if (!ptr || !mask || !passthrough) |
| return maskedLoadOp.emitError("Failed to lookup operands"); |
| |
| // Translate result type to LLVM type. |
| llvm::Type *resultType = |
| moduleTranslation.convertType(maskedLoadOp.getResult().getType()); |
| if (!resultType) |
| return maskedLoadOp.emitError("Failed to translate result type"); |
| |
| // Get the alignment. |
| llvm::MaybeAlign alignment(maskedLoadOp.getAlignment().value_or(0)); |
| |
| // Create the masked load intrinsic call. |
| llvm::Value *result = builder.CreateMaskedLoad( |
| resultType, ptr, alignment.valueOrOne(), mask, passthrough); |
| |
| moduleTranslation.mapValue(maskedLoadOp.getResult(), result); |
| return success(); |
| } |
| |
| /// Translate ptr.masked_store operation to LLVM IR. |
| static LogicalResult |
| translateMaskedStoreOp(MaskedStoreOp maskedStoreOp, |
| llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *value = moduleTranslation.lookupValue(maskedStoreOp.getValue()); |
| llvm::Value *ptr = moduleTranslation.lookupValue(maskedStoreOp.getPtr()); |
| llvm::Value *mask = moduleTranslation.lookupValue(maskedStoreOp.getMask()); |
| |
| if (!value || !ptr || !mask) |
| return maskedStoreOp.emitError("Failed to lookup operands"); |
| |
| // Get the alignment. |
| llvm::MaybeAlign alignment(maskedStoreOp.getAlignment().value_or(0)); |
| |
| // Create the masked store intrinsic call. |
| builder.CreateMaskedStore(value, ptr, alignment.valueOrOne(), mask); |
| return success(); |
| } |
| |
| /// Translate ptr.scatter operation to LLVM IR. |
| static LogicalResult |
| translateScatterOp(ScatterOp scatterOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *value = moduleTranslation.lookupValue(scatterOp.getValue()); |
| llvm::Value *ptrs = moduleTranslation.lookupValue(scatterOp.getPtrs()); |
| llvm::Value *mask = moduleTranslation.lookupValue(scatterOp.getMask()); |
| |
| if (!value || !ptrs || !mask) |
| return scatterOp.emitError("Failed to lookup operands"); |
| |
| // Get the alignment. |
| llvm::MaybeAlign alignment(scatterOp.getAlignment().value_or(0)); |
| |
| // Create the masked scatter intrinsic call. |
| builder.CreateMaskedScatter(value, ptrs, alignment.valueOrOne(), mask); |
| return success(); |
| } |
| |
| /// Translate ptr.constant operation to LLVM IR. |
| static LogicalResult |
| translateConstantOp(ConstantOp constantOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| // Translate result type to LLVM type |
| llvm::PointerType *resultType = dyn_cast_or_null<llvm::PointerType>( |
| moduleTranslation.convertType(constantOp.getResult().getType())); |
| if (!resultType) |
| return constantOp.emitError("Expected a valid pointer type"); |
| |
| llvm::Value *result = nullptr; |
| |
| TypedAttr value = constantOp.getValue(); |
| if (auto nullAttr = dyn_cast<ptr::NullAttr>(value)) { |
| // Create a null pointer constant |
| result = llvm::ConstantPointerNull::get(resultType); |
| } else if (auto addressAttr = dyn_cast<ptr::AddressAttr>(value)) { |
| // Create an integer constant and translate it to pointer |
| llvm::APInt addressValue = addressAttr.getValue(); |
| |
| // Determine the integer type width based on the target's pointer size |
| llvm::DataLayout dataLayout = |
| moduleTranslation.getLLVMModule()->getDataLayout(); |
| unsigned pointerSizeInBits = |
| dataLayout.getPointerSizeInBits(resultType->getAddressSpace()); |
| |
| // Extend or truncate the address value to match pointer size if needed |
| if (addressValue.getBitWidth() != pointerSizeInBits) { |
| if (addressValue.getBitWidth() > pointerSizeInBits) { |
| constantOp.emitWarning() |
| << "Truncating address value to fit pointer size"; |
| } |
| addressValue = addressValue.getBitWidth() < pointerSizeInBits |
| ? addressValue.zext(pointerSizeInBits) |
| : addressValue.trunc(pointerSizeInBits); |
| } |
| |
| // Create integer constant and translate to pointer |
| llvm::Type *intType = builder.getIntNTy(pointerSizeInBits); |
| llvm::Value *intValue = llvm::ConstantInt::get(intType, addressValue); |
| result = builder.CreateIntToPtr(intValue, resultType); |
| } else { |
| return constantOp.emitError("Unsupported constant attribute type"); |
| } |
| |
| moduleTranslation.mapValue(constantOp.getResult(), result); |
| return success(); |
| } |
| |
| /// Translate ptr.ptr_diff operation operation to LLVM IR. |
| static LogicalResult |
| translatePtrDiffOp(PtrDiffOp ptrDiffOp, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Value *lhs = moduleTranslation.lookupValue(ptrDiffOp.getLhs()); |
| llvm::Value *rhs = moduleTranslation.lookupValue(ptrDiffOp.getRhs()); |
| |
| if (!lhs || !rhs) |
| return ptrDiffOp.emitError("Failed to lookup operands"); |
| |
| // Translate result type to LLVM type |
| llvm::Type *resultType = |
| moduleTranslation.convertType(ptrDiffOp.getResult().getType()); |
| if (!resultType) |
| return ptrDiffOp.emitError("Failed to translate result type"); |
| |
| PtrDiffFlags flags = ptrDiffOp.getFlags(); |
| |
| // Convert both pointers to integers using ptrtoaddr, and compute the |
| // difference: lhs - rhs |
| llvm::Value *llLhs = builder.CreatePtrToAddr(lhs); |
| llvm::Value *llRhs = builder.CreatePtrToAddr(rhs); |
| llvm::Value *result = builder.CreateSub( |
| llLhs, llRhs, /*Name=*/"", |
| /*HasNUW=*/(flags & PtrDiffFlags::nuw) == PtrDiffFlags::nuw, |
| /*HasNSW=*/(flags & PtrDiffFlags::nsw) == PtrDiffFlags::nsw); |
| |
| // Convert the difference to the expected result type by truncating or |
| // extending. |
| if (result->getType() != resultType) |
| result = builder.CreateIntCast(result, resultType, /*isSigned=*/true); |
| |
| moduleTranslation.mapValue(ptrDiffOp.getResult(), result); |
| return success(); |
| } |
| |
| /// Implementation of the dialect interface that translates operations belonging |
| /// to the `ptr` dialect to LLVM IR. |
| class PtrDialectLLVMIRTranslationInterface |
| : public LLVMTranslationDialectInterface { |
| public: |
| using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; |
| |
| /// Translates the given operation to LLVM IR using the provided IR builder |
| /// and saving the state in `moduleTranslation`. |
| LogicalResult |
| convertOperation(Operation *op, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) const final { |
| |
| return llvm::TypeSwitch<Operation *, LogicalResult>(op) |
| .Case([&](ConstantOp constantOp) { |
| return translateConstantOp(constantOp, builder, moduleTranslation); |
| }) |
| .Case([&](PtrAddOp ptrAddOp) { |
| return translatePtrAddOp(ptrAddOp, builder, moduleTranslation); |
| }) |
| .Case([&](PtrDiffOp ptrDiffOp) { |
| return translatePtrDiffOp(ptrDiffOp, builder, moduleTranslation); |
| }) |
| .Case([&](LoadOp loadOp) { |
| return translateLoadOp(loadOp, builder, moduleTranslation); |
| }) |
| .Case([&](StoreOp storeOp) { |
| return translateStoreOp(storeOp, builder, moduleTranslation); |
| }) |
| .Case([&](TypeOffsetOp typeOffsetOp) { |
| return translateTypeOffsetOp(typeOffsetOp, builder, |
| moduleTranslation); |
| }) |
| .Case([&](GatherOp gatherOp) { |
| return translateGatherOp(gatherOp, builder, moduleTranslation); |
| }) |
| .Case([&](MaskedLoadOp maskedLoadOp) { |
| return translateMaskedLoadOp(maskedLoadOp, builder, |
| moduleTranslation); |
| }) |
| .Case([&](MaskedStoreOp maskedStoreOp) { |
| return translateMaskedStoreOp(maskedStoreOp, builder, |
| moduleTranslation); |
| }) |
| .Case([&](ScatterOp scatterOp) { |
| return translateScatterOp(scatterOp, builder, moduleTranslation); |
| }) |
| .Default([&](Operation *op) { |
| return op->emitError("Translation for operation '") |
| << op->getName() << "' is not implemented."; |
| }); |
| } |
| |
| /// Attaches module-level metadata for functions marked as kernels. |
| LogicalResult |
| amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, |
| NamedAttribute attribute, |
| LLVM::ModuleTranslation &moduleTranslation) const final { |
| // No special amendments needed for ptr dialect operations |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::registerPtrDialectTranslation(DialectRegistry ®istry) { |
| registry.insert<ptr::PtrDialect>(); |
| registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { |
| dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>(); |
| }); |
| } |
| |
| void mlir::registerPtrDialectTranslation(MLIRContext &context) { |
| DialectRegistry registry; |
| registerPtrDialectTranslation(registry); |
| context.appendDialectRegistry(registry); |
| } |