blob: 3ca0b40cac93ede3d21cd0e1e0bdd9cec19a9a2e [file] [log] [blame] [edit]
//===- SPIRVCBufferAccess.cpp - Translate CBuffer Loads ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all accesses to constant buffer global variables with
// accesses to the proper SPIR-V resource.
//
// The pass operates as follows:
// 1. It finds all constant buffers by looking for the `!hlsl.cbs` metadata.
// 2. For each cbuffer, it finds the global variable holding the resource handle
// and the global variables for each of the cbuffer's members.
// 3. For each member variable, it creates a call to the
// `llvm.spv.resource.getpointer` intrinsic. This intrinsic takes the
// resource handle and the member's index within the cbuffer as arguments.
// The result is a pointer to that member within the SPIR-V resource.
// 4. It then replaces all uses of the original member global variable with the
// pointer returned by the `getpointer` intrinsic. This effectively retargets
// all loads and GEPs to the new resource pointer.
// 5. Finally, it cleans up by deleting the original member globals, the handle
// globals together with the stores into them, and the `!hlsl.cbs` metadata.
//
// This approach allows subsequent passes, like SPIRVEmitIntrinsics, to
// correctly handle GEPs that operate on the result of the `getpointer` call,
// folding them into a single OpAccessChain instruction.
//
//===----------------------------------------------------------------------===//
#include "SPIRVCBufferAccess.h"
#include "SPIRV.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ReplaceConstant.h"
#define DEBUG_TYPE "spirv-cbuffer-access"
using namespace llvm;
// Returns the instruction whose result is stored into the handle global
// `HandleVar`; this is typically a call to
// `llvm.spv.resource.handlefrombinding`.
//
// The global's users are scanned in use-list order. The first store whose
// stored value is an instruction wins. Stores of non-instruction values are
// skipped. Returns nullptr when no such store exists.
static Instruction *findHandleDef(GlobalVariable *HandleVar) {
  for (User *HandleUser : HandleVar->users()) {
    auto *Store = dyn_cast<StoreInst>(HandleUser);
    if (!Store)
      continue;
    Value *StoredValue = Store->getValueOperand();
    if (auto *Def = dyn_cast<Instruction>(StoredValue))
      return Def;
  }
  return nullptr;
}
static bool replaceCBufferAccesses(Module &M) {
std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
if (!CBufMD)
return false;
SmallVector<Constant *> CBufferGlobals;
for (const hlsl::CBufferMapping &Mapping : *CBufMD)
for (const hlsl::CBufferMember &Member : Mapping.Members)
CBufferGlobals.push_back(Member.GV);
convertUsersOfConstantsToInstructions(CBufferGlobals);
for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
Instruction *HandleDef = findHandleDef(Mapping.Handle);
if (!HandleDef) {
report_fatal_error("Could not find handle definition for cbuffer: " +
Mapping.Handle->getName());
}
// The handle definition should dominate all uses of the cbuffer members.
// We'll insert our getpointer calls right after it.
IRBuilder<> Builder(HandleDef->getNextNode());
for (uint32_t Index = 0; Index < Mapping.Members.size(); ++Index) {
GlobalVariable *MemberGV = Mapping.Members[Index].GV;
if (MemberGV->use_empty()) {
continue;
}
// Create the getpointer intrinsic call.
Value *IndexVal = Builder.getInt32(Index);
Type *PtrType = MemberGV->getType();
Value *GetPointerCall = Builder.CreateIntrinsic(
PtrType, Intrinsic::spv_resource_getpointer, {HandleDef, IndexVal});
MemberGV->replaceAllUsesWith(GetPointerCall);
}
}
// Now that all uses are replaced, clean up the globals and metadata.
for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
for (const auto &Member : Mapping.Members) {
Member.GV->eraseFromParent();
}
// Erase the stores to the handle variable before erasing the handle itself.
SmallVector<Instruction *, 4> HandleStores;
for (User *U : Mapping.Handle->users()) {
if (auto *SI = dyn_cast<StoreInst>(U)) {
HandleStores.push_back(SI);
}
}
for (Instruction *I : HandleStores) {
I->eraseFromParent();
}
Mapping.Handle->eraseFromParent();
}
CBufMD->eraseFromModule();
return true;
}
// New-pass-manager entry point: performs the cbuffer rewrite on `M` and
// reports which analyses remain valid.
PreservedAnalyses SPIRVCBufferAccess::run(Module &M,
                                          ModuleAnalysisManager &AM) {
  // The module had no cbuffer metadata, so nothing was touched.
  if (!replaceCBufferAccesses(M))
    return PreservedAnalyses::all();
  // Globals, instructions and metadata were changed; preserve nothing.
  return PreservedAnalyses::none();
}
namespace {
/// Legacy-pass-manager wrapper that runs the same cbuffer rewrite as the
/// new-pass-manager SPIRVCBufferAccess pass.
class SPIRVCBufferAccessLegacy : public ModulePass {
public:
  static char ID; // Pass identification; handed to the ModulePass base.

  SPIRVCBufferAccessLegacy() : ModulePass(ID) {}

  StringRef getPassName() const override { return "SPIRV CBuffer Access"; }

  /// Returns true when the module was modified.
  bool runOnModule(Module &M) override {
    const bool Changed = replaceCBufferAccesses(M);
    return Changed;
  }
};
char SPIRVCBufferAccessLegacy::ID = 0;
} // end anonymous namespace
// Registers the legacy pass with the pass registry, using DEBUG_TYPE
// ("spirv-cbuffer-access") as its argument name and "SPIRV CBuffer Access" as
// its description. The two trailing `false` arguments are the CFG-only and
// is-analysis flags of INITIALIZE_PASS (see llvm/PassSupport.h).
INITIALIZE_PASS(SPIRVCBufferAccessLegacy, DEBUG_TYPE, "SPIRV CBuffer Access",
                false, false)
// Factory for the legacy pass. Returns a newly heap-allocated pass; this
// function keeps no reference to it, so the caller is responsible for it.
ModulePass *llvm::createSPIRVCBufferAccessLegacyPass() {
  ModulePass *Pass = new SPIRVCBufferAccessLegacy();
  return Pass;
}