blob: 8ec05271fc8fbfb2cb15d352f1b20d9f05df8605 [file] [log] [blame] [edit]
//===- SPIRVPushConstantAccess.cpp - Translate CBuffer Loads ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass changes the types of all the globals in the PushConstant
// address space into a target extension type, and makes all references
// to this global go though a custom SPIR-V intrinsic.
//
// This allows the backend to properly lower the push constant struct type
// to a fully laid out type, and generate the proper OpAccessChain.
//
//===----------------------------------------------------------------------===//
#include "SPIRVPushConstantAccess.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"
#define DEBUG_TYPE "spirv-pushconstant-access"
using namespace llvm;
static bool replacePushConstantAccesses(Module &M, SPIRVGlobalRegistry *GR) {
bool Changed = false;
for (GlobalVariable &GV : make_early_inc_range(M.globals())) {
if (GV.getAddressSpace() !=
storageClassToAddressSpace(SPIRV::StorageClass::PushConstant))
continue;
GV.removeDeadConstantUsers();
Type *PCType = llvm::TargetExtType::get(
M.getContext(), "spirv.PushConstant", {GV.getValueType()});
GlobalVariable *NewGV =
new GlobalVariable(M, PCType, GV.isConstant(), GV.getLinkage(),
/* initializer= */ nullptr, GV.getName(),
/* InsertBefore= */ &GV, GV.getThreadLocalMode(),
GV.getAddressSpace(), GV.isExternallyInitialized());
NewGV->setVisibility(GV.getVisibility());
for (User *U : make_early_inc_range(GV.users())) {
Instruction *I = cast<Instruction>(U);
IRBuilder<> Builder(I);
Value *GetPointerCall = Builder.CreateIntrinsic(
NewGV->getType(), Intrinsic::spv_pushconstant_getpointer, {NewGV});
GR->buildAssignPtr(Builder, GV.getValueType(), GetPointerCall);
I->replaceUsesOfWith(&GV, GetPointerCall);
}
GV.eraseFromParent();
Changed = true;
}
return Changed;
}
PreservedAnalyses SPIRVPushConstantAccess::run(Module &M,
ModuleAnalysisManager &AM) {
const SPIRVSubtarget *ST = TM.getSubtargetImpl();
SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
return replacePushConstantAccesses(M, GR) ? PreservedAnalyses::none()
: PreservedAnalyses::all();
}
namespace {
class SPIRVPushConstantAccessLegacy : public ModulePass {
SPIRVTargetMachine *TM = nullptr;
public:
bool runOnModule(Module &M) override {
const SPIRVSubtarget *ST = TM->getSubtargetImpl();
SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
return replacePushConstantAccesses(M, GR);
}
StringRef getPassName() const override {
return "SPIRV push constant Access";
}
SPIRVPushConstantAccessLegacy(SPIRVTargetMachine *TM)
: ModulePass(ID), TM(TM) {}
static char ID; // Pass identification.
};
char SPIRVPushConstantAccessLegacy::ID = 0;
} // end anonymous namespace
INITIALIZE_PASS(SPIRVPushConstantAccessLegacy, DEBUG_TYPE,
"SPIRV push constant Access", false, false)
ModulePass *
llvm::createSPIRVPushConstantAccessLegacyPass(SPIRVTargetMachine *TM) {
return new SPIRVPushConstantAccessLegacy(TM);
}