blob: 5e8e1c55d91c64c7003242827d413f82f56d0b1e [file] [log] [blame]
//===------------ SPIRVIRMapping.h - SPIR-V IR Mapping ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// General infrastructure for keeping track of the values that, according to
// the SPIR-V binary layout, should be global to the whole module.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include <type_traits>
namespace llvm {
namespace SPIRV {
inline size_t to_hash(const MachineInstr *MI) {
hash_code H = llvm::hash_combine(MI->getOpcode(), MI->getNumOperands());
for (unsigned I = MI->getNumDefs(); I < MI->getNumOperands(); ++I) {
const MachineOperand &MO = MI->getOperand(I);
if (MO.getType() == MachineOperand::MO_CImmediate)
H = llvm::hash_combine(H, MO.getType(), MO.getCImm());
else if (MO.getType() == MachineOperand::MO_FPImmediate)
H = llvm::hash_combine(H, MO.getType(), MO.getFPImm());
else
H = llvm::hash_combine(H, MO.getType());
}
return H;
}
/// Snapshot of a tracked definition: the defining instruction, the register in
/// its operand 0, and its to_hash fingerprint at the time of recording.
using MIHandle = std::tuple<const MachineInstr *, Register, size_t>;
/// Builds the MIHandle snapshot for \p MI (operand 0 is read as a register).
inline MIHandle getMIKey(const MachineInstr *MI) {
  const Register DefReg = MI->getOperand(0).getReg();
  const size_t Fingerprint = SPIRV::to_hash(MI);
  return MIHandle(MI, DefReg, Fingerprint);
}
/// Key identifying an LLVM/GlobalISel entity: an object pointer, an unsigned
/// discriminator, and a SpecialTypeKind value.
using IRHandle = std::tuple<const void *, unsigned, unsigned>;
/// An IRHandle scoped to one machine function.
using IRHandleMF = std::pair<IRHandle, const MachineFunction *>;
/// Combines \p Handle and \p MF into a per-function lookup key.
inline IRHandleMF getIRHandleMF(IRHandle Handle, const MachineFunction *MF) {
  return IRHandleMF(Handle, MF);
}
/// Kind tag stored as the third element of an IRHandle. It keeps handles for
/// different kinds of entities apart even when their pointer and discriminator
/// fields coincide.
enum SpecialTypeKind {
  STK_Empty = 0,
  STK_Image = 1,
  STK_SampledImage = 2,
  STK_Sampler = 3,
  STK_Pipe = 4,
  STK_DeviceEvent = 5,
  STK_ElementPointer = 6,
  STK_Type = 7,
  STK_Value = 8,
  STK_MachineInstr = 9,
  STK_VkBuffer = 10,
  STK_Last = -1
};
// Packs the image-type operands (Dim, Depth, Arrayed, MS, Sampled,
// ImageFormat and access qualifier) into one unsigned word, `Val`, so that
// they can serve as the second element of an IRHandle (see irhandle_image and
// irhandle_sampled_image).
// NOTE(review): writing `Flags` and then reading `Val` is union type punning.
// GCC/Clang support it, but ISO C++ does not define it, and bit-field placement
// inside `Val` is implementation-defined. This is fine for an in-process hash
// key; do not rely on the exact bit layout.
union ImageAttrs {
// Bit-field view of the packed word (3+2+1+1+2+6+2 = 17 bits in total).
struct BitFlags {
unsigned Dim : 3;
unsigned Depth : 2;
unsigned Arrayed : 1;
unsigned MS : 1;
unsigned Sampled : 2;
unsigned ImageFormat : 6;
unsigned AQ : 2;
} Flags;
// The packed word read back by the handle builders.
unsigned Val;
// Stores each argument in its bit-field. Unsigned bit-fields keep only the low
// `width` bits of the assigned value, so arguments wider than their field are
// silently truncated and distinct out-of-range inputs may produce equal words.
ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
// Zero the whole word first so that bits not covered by Flags are zero when
// Val is read back. This must precede the bit-field writes.
Val = 0;
Flags.Dim = Dim;
Flags.Depth = Depth;
Flags.Arrayed = Arrayed;
Flags.MS = MS;
Flags.Sampled = Sampled;
Flags.ImageFormat = ImageFormat;
Flags.AQ = AQ;
}
};
/// Handle for an image type with sampled type \p SampledTy. The remaining
/// image operands are packed into a single word via ImageAttrs.
inline IRHandle irhandle_image(const Type *SampledTy, unsigned Dim,
                               unsigned Depth, unsigned Arrayed, unsigned MS,
                               unsigned Sampled, unsigned ImageFormat,
                               unsigned AQ = 0) {
  const ImageAttrs Attrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ);
  return IRHandle(SampledTy, Attrs.Val, SpecialTypeKind::STK_Image);
}
/// Handle for a sampled-image type built from \p SampledTy and the image type
/// instruction \p ImageTy, which must be an OpTypeImage. Immediate operands
/// 2..7 of \p ImageTy are packed via ImageAttrs. Operand 8 supplies the access
/// qualifier when present; otherwise AccessQualifier None is used.
inline IRHandle irhandle_sampled_image(const Type *SampledTy,
                                       const MachineInstr *ImageTy) {
  assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
  // Reads immediate operand Idx of the image type as an unsigned value.
  auto ImmAt = [ImageTy](unsigned Idx) -> unsigned {
    return static_cast<unsigned>(ImageTy->getOperand(Idx).getImm());
  };
  const unsigned AccessQual =
      ImageTy->getNumOperands() > 8
          ? ImmAt(8)
          : static_cast<unsigned>(AccessQualifier::AccessQualifier::None);
  const ImageAttrs Attrs(ImmAt(2), ImmAt(3), ImmAt(4), ImmAt(5), ImmAt(6),
                         ImmAt(7), AccessQual);
  return IRHandle(SampledTy, Attrs.Val, SpecialTypeKind::STK_SampledImage);
}
/// Handle for the sampler type. It takes no parameters, so all calls yield the
/// same key.
inline IRHandle irhandle_sampler() {
  return IRHandle(nullptr, 0U, SpecialTypeKind::STK_Sampler);
}
/// Handle for a pipe type, distinguished only by its access qualifier \p AQ.
inline IRHandle irhandle_pipe(uint8_t AQ) {
  const unsigned Qualifier = AQ;
  return IRHandle(nullptr, Qualifier, SpecialTypeKind::STK_Pipe);
}
/// Handle for the device-event type. It takes no parameters, so all calls
/// yield the same key.
inline IRHandle irhandle_event() {
  return IRHandle(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
}
/// Handle for a pointer type, keyed by the unified pointee type (via
/// unifyPtrType) and the address space.
inline IRHandle irhandle_pointee(const Type *ElementType,
                                 unsigned AddressSpace) {
  const Type *Pointee = unifyPtrType(ElementType);
  return IRHandle(Pointee, AddressSpace, SpecialTypeKind::STK_ElementPointer);
}
/// Generic handle builder: (\p Ptr, \p Arg, \p STK) with the kind tag stored
/// as an unsigned value.
inline IRHandle irhandle_ptr(const void *Ptr, unsigned Arg,
                             enum SpecialTypeKind STK) {
  const unsigned Kind = static_cast<unsigned>(STK);
  return IRHandle(Ptr, Arg, Kind);
}
/// Handle for a Vulkan buffer type with element type \p ElementType. The
/// storage class and the writability flag share one word: the flag occupies
/// bit 0 and the storage class sits above it.
inline IRHandle irhandle_vkbuffer(const Type *ElementType,
                                  StorageClass::StorageClass SC,
                                  bool IsWriteable) {
  const unsigned WriteBit = IsWriteable ? 1U : 0U;
  const unsigned Packed = (static_cast<unsigned>(SC) << 1) | WriteBit;
  return IRHandle(ElementType, Packed, SpecialTypeKind::STK_VkBuffer);
}
/// Handle for an LLVM type. It is keyed by the unified form of \p Ty (via
/// unifyPtrType) together with the TypeID of \p Ty itself.
inline IRHandle handle(const Type *Ty) {
  const Type *Unified = unifyPtrType(Ty);
  const unsigned Id = Ty->getTypeID();
  return irhandle_ptr(Unified, Id, STK_Type);
}
/// Handle for an LLVM value, keyed by its address and its value ID.
inline IRHandle handle(const Value *V) {
  const unsigned Id = V->getValueID();
  return irhandle_ptr(V, Id, STK_Value);
}
/// Handle for a machine instruction, keyed by its address and its to_hash
/// fingerprint narrowed to `unsigned`.
inline IRHandle handle(const MachineInstr *KeyMI) {
  const auto Fingerprint = static_cast<unsigned>(SPIRV::to_hash(KeyMI));
  return irhandle_ptr(KeyMI, Fingerprint, STK_MachineInstr);
}
} // namespace SPIRV
// Bi-directional mappings between LLVM entities and (v-reg, machine function)
// pairs support management of unique SPIR-V definitions per machine function
// per an LLVM/GlobalISel entity (e.g., Type, Constant, Machine Instruction).
class SPIRVIRMapping {
DenseMap<SPIRV::IRHandleMF, SPIRV::MIHandle> Vregs;
DenseMap<const MachineInstr *, SPIRV::IRHandleMF> Defs;
public:
bool add(SPIRV::IRHandle Handle, const MachineInstr *MI) {
if (auto DefIt = Defs.find(MI); DefIt != Defs.end()) {
auto [ExistHandle, ExistMF] = DefIt->second;
if (Handle == ExistHandle && MI->getMF() == ExistMF)
return false; // already exists
// invalidate the record
Vregs.erase(DefIt->second);
Defs.erase(DefIt);
}
SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MI->getMF());
SPIRV::MIHandle MIKey = SPIRV::getMIKey(MI);
auto It1 = Vregs.try_emplace(HandleMF, MIKey);
if (!It1.second) {
// there is an expired record that we need to invalidate
Defs.erase(std::get<0>(It1.first->second));
// update the record
It1.first->second = MIKey;
}
[[maybe_unused]] auto It2 = Defs.try_emplace(MI, HandleMF);
assert(It2.second);
return true;
}
bool erase(const MachineInstr *MI) {
bool Res = false;
if (auto It = Defs.find(MI); It != Defs.end()) {
Res = Vregs.erase(It->second);
Defs.erase(It);
}
return Res;
}
const MachineInstr *findMI(SPIRV::IRHandle Handle,
const MachineFunction *MF) {
SPIRV::IRHandleMF HandleMF = SPIRV::getIRHandleMF(Handle, MF);
auto It = Vregs.find(HandleMF);
if (It == Vregs.end())
return nullptr;
auto [MI, Reg, Hash] = It->second;
const MachineInstr *Def = MF->getRegInfo().getVRegDef(Reg);
if (!Def || Def != MI || SPIRV::to_hash(MI) != Hash) {
// there is an expired record that we need to invalidate
erase(MI);
return nullptr;
}
assert(Defs.find(MI) != Defs.end() && Defs.find(MI)->second == HandleMF);
return MI;
}
Register find(SPIRV::IRHandle Handle, const MachineFunction *MF) {
const MachineInstr *MI = findMI(Handle, MF);
return MI ? MI->getOperand(0).getReg() : Register();
}
// helpers
bool add(const Type *PointeeTy, unsigned AddressSpace,
const MachineInstr *MI) {
return add(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MI);
}
Register find(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF) {
return find(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
}
const MachineInstr *findMI(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF) {
return findMI(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
}
template <typename T> bool add(const T *Obj, const MachineInstr *MI) {
return add(SPIRV::handle(Obj), MI);
}
template <typename T> Register find(const T *Obj, const MachineFunction *MF) {
return find(SPIRV::handle(Obj), MF);
}
template <typename T>
const MachineInstr *findMI(const T *Obj, const MachineFunction *MF) {
return findMI(SPIRV::handle(Obj), MF);
}
};
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVIRMAPPING_H