blob: 1f2700ac55647e273af3263bf886e0ae2c58e80d [file] [log] [blame]
//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
#include "DXILDataScalarization.h"
#include "DirectX.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Type.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#define DEBUG_TYPE "dxil-data-scalarization"
// Inline capacity of the SmallVectors used in this file. Presumably chosen to
// match the 4-component vector limit of HLSL; TODO confirm.
static constexpr int MaxVecSize = 4;
using namespace llvm;
/// Legacy pass-manager wrapper; runOnModule forwards to findAndReplaceVectors.
class DXILDataScalarizationLegacy : public ModulePass {
public:
  static char ID; // Pass identification.

  DXILDataScalarizationLegacy() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;
};
// Declared up front so the friend declaration inside DataScalarizerVisitor
// refers to this file-local function.
static bool findAndReplaceVectors(Module &M);

/// Instruction visitor that redirects references to vector-typed globals onto
/// their array-typed replacements, as recorded in GlobalMap.
///
/// Handlers return a bool change indicator; the no-op handlers below always
/// return false. findAndReplaceVectors currently ignores the result.
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
public:
  DataScalarizerVisitor() = default;

  bool visit(Instruction &I);

  // Handlers that may rewrite an instruction referencing a replaced global.
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);

  // Every other instruction kind is currently left untouched.
  bool visitInstruction(Instruction &) { return false; }
  bool visitSelectInst(SelectInst &) { return false; }
  bool visitICmpInst(ICmpInst &) { return false; }
  bool visitFCmpInst(FCmpInst &) { return false; }
  bool visitUnaryOperator(UnaryOperator &) { return false; }
  bool visitBinaryOperator(BinaryOperator &) { return false; }
  bool visitCastInst(CastInst &) { return false; }
  bool visitBitCastInst(BitCastInst &) { return false; }
  bool visitInsertElementInst(InsertElementInst &) { return false; }
  bool visitExtractElementInst(ExtractElementInst &) { return false; }
  bool visitShuffleVectorInst(ShuffleVectorInst &) { return false; }
  bool visitPHINode(PHINode &) { return false; }
  bool visitCallInst(CallInst &) { return false; }
  bool visitFreezeInst(FreezeInst &) { return false; }

  // findAndReplaceVectors fills GlobalMap directly.
  friend bool findAndReplaceVectors(llvm::Module &M);

private:
  GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);

  /// Original vector-typed global -> its array-typed replacement.
  DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
};
bool DataScalarizerVisitor::visit(Instruction &I) {
assert(!GlobalMap.empty());
return InstVisitor::visit(I);
}
// Returns the replacement registered for CurrOperand when it is a global that
// is being scalarized, or nullptr for any other value.
GlobalVariable *
DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
  auto *Candidate = dyn_cast<GlobalVariable>(CurrOperand);
  if (!Candidate)
    return nullptr;
  // DenseMap::lookup yields a value-initialized (null) pointer when absent.
  return GlobalMap.lookup(Candidate);
}
/// Rewrites a load whose pointer operand refers to a replaced global.
///
/// Two pointer forms are handled:
///  * the old global itself, which is swapped for its replacement;
///  * a constant-expression GEP based on the old global. That GEP is
///    materialized as an instruction right before the load, the load is
///    pointed at it, and visitGetElementPtrInst then re-bases it.
///
/// The load is modified in place rather than recreated, so its name,
/// volatility, atomic ordering/syncscope, alignment and metadata survive.
/// Returns true if the load was changed.
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  const unsigned PtrIdx = LI.getPointerOperandIndex();
  Value *Ptr = LI.getPointerOperand();

  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(Ptr)) {
    LI.setOperand(PtrIdx, NewGlobal);
    return true;
  }

  auto *CE = dyn_cast<ConstantExpr>(Ptr);
  if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
    return false;
  // Leave GEP expressions that are not based on a replaced global alone.
  if (!lookupReplacementGlobal(cast<GEPOperator>(CE)->getPointerOperand()))
    return false;

  auto *GEP = cast<GetElementPtrInst>(CE->getAsInstruction());
  GEP->insertBefore(LI.getIterator());
  LI.setOperand(PtrIdx, GEP);
  // Re-bases GEP on the replacement global; its RAUW updates LI's operand.
  visitGetElementPtrInst(*GEP);
  return true;
}
/// Rewrites a store whose operands refer to a replaced global.
///
/// Both operands are inspected: the pointer operand (the destination address)
/// and the value operand (which may itself be a pointer into a replaced
/// global). For each operand:
///  * if it is the old global, it is swapped for the replacement;
///  * if it is a constant-expression GEP based on the old global, the GEP is
///    materialized as an instruction right before the store, that same operand
///    is pointed at it, and visitGetElementPtrInst re-bases it.
///
/// The store is modified in place and every operand keeps its role. In
/// particular, a GEP appearing as the stored value is never turned into the
/// destination address. Volatility, atomic ordering, alignment and metadata
/// are preserved. Returns true if the store was changed.
bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  bool Changed = false;
  for (unsigned OpIdx = 0, E = SI.getNumOperands(); OpIdx != E; ++OpIdx) {
    Value *Op = SI.getOperand(OpIdx);

    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(Op)) {
      SI.setOperand(OpIdx, NewGlobal);
      Changed = true;
      continue;
    }

    auto *CE = dyn_cast<ConstantExpr>(Op);
    if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
      continue;
    // Leave GEP expressions that are not based on a replaced global alone.
    if (!lookupReplacementGlobal(cast<GEPOperator>(CE)->getPointerOperand()))
      continue;

    auto *GEP = cast<GetElementPtrInst>(CE->getAsInstruction());
    GEP->insertBefore(SI.getIterator());
    SI.setOperand(OpIdx, GEP);
    // Re-bases GEP on the replacement global; its RAUW updates SI's operand.
    visitGetElementPtrInst(*GEP);
    Changed = true;
  }
  return Changed;
}
/// Re-bases a GEP whose pointer operand is a replaced global onto the
/// replacement global.
///
/// A new GEP with the same indices, name and no-wrap flags is created before
/// GEPI, all uses of GEPI are redirected to it, and GEPI is erased. Returns
/// true if the GEP was rewritten, or false if its pointer operand is not a
/// replaced global.
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  // Only the pointer operand can be a global; GEP indices are integers.
  Value *PtrOperand = GEPI.getPointerOperand();
  GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand);
  if (!NewGlobal)
    return false;

  // The source element type decides how the indices are scaled. Switch it to
  // the scalarized type only when the GEP indexes through the old global's own
  // value type (e.g. [2 x <4 x i32>] -> [2 x [4 x i32]]). Any other source
  // type, for example an i8 byte-offset GEP, is kept as is: re-scaling its
  // indices by the whole global's type would change the computed address.
  Type *SrcElemTy = GEPI.getSourceElementType();
  if (SrcElemTy == cast<GlobalVariable>(PtrOperand)->getValueType())
    SrcElemTy = NewGlobal->getValueType();

  IRBuilder<> Builder(&GEPI);
  SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
  // The builder may constant-fold the GEP, hence Value * rather than an
  // instruction.
  Value *NewGEP = Builder.CreateGEP(SrcElemTy, NewGlobal, Indices,
                                    GEPI.getName(), GEPI.getNoWrapFlags());
  GEPI.replaceAllUsesWith(NewGEP);
  GEPI.eraseFromParent();
  return true;
}
// Recursively creates an array-like version of the given vector-like type:
// a fixed-width vector <N x T> becomes [N x T], and arrays are rebuilt around
// their converted element type. Every other type is returned unchanged. That
// includes scalable vectors, which have no fixed element count to convert.
// Ctx is currently unused.
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
  if (auto *VecTy = dyn_cast<FixedVectorType>(T))
    return ArrayType::get(VecTy->getElementType(), VecTy->getNumElements());
  if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
    Type *NewElementType =
        replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
    // ArrayType::get is uniqued, so an unchanged element type yields T again.
    return ArrayType::get(NewElementType, ArrayTy->getNumElements());
  }
  return T;
}
/// Converts the initializer Init of a global whose value type is OrigType
/// into an equivalent constant of type NewType. NewType is expected to be
/// OrigType with its vectors replaced by arrays of the same element type and
/// length.
///
/// Zero, undef and poison initializers map onto the matching constant of
/// NewType. A vector initializer becomes a ConstantArray holding the same
/// elements in order. Arrays (possibly nested) of vectors are converted
/// element by element. Any other initializer is returned unchanged.
/// Ctx is only forwarded to the recursive calls.
Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
                               LLVMContext &Ctx) {
  if (isa<ConstantAggregateZero>(Init))
    return ConstantAggregateZero::get(NewType);
  // PoisonValue derives from UndefValue, so test it first. That keeps poison
  // initializers as poison instead of weakening them to undef.
  if (isa<PoisonValue>(Init))
    return PoisonValue::get(NewType);
  if (isa<UndefValue>(Init))
    return UndefValue::get(NewType);

  // Vector -> array: copy the elements in order.
  if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
    auto *NewArrayTy = cast<ArrayType>(NewType);
    SmallVector<Constant *, MaxVecSize> ArrayElements;
    ArrayElements.reserve(NewArrayTy->getNumElements());
    for (unsigned I = 0; I < NewArrayTy->getNumElements(); ++I) {
      // getAggregateElement reads ConstantVector and ConstantDataVector
      // elements uniformly, and (where the LLVM version supports it) splat
      // ConstantInt/ConstantFP vectors as well.
      Constant *Elt = Init->getAggregateElement(I);
      assert(Elt && "Unsupported constant for vector initializer!");
      ArrayElements.push_back(Elt);
    }
    return ConstantArray::get(NewArrayTy, ArrayElements);
  }

  // Array of vectors (or of nested arrays of vectors): recurse per element.
  if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
    auto *NewArrayTy = cast<ArrayType>(NewType);
    SmallVector<Constant *, MaxVecSize> NewArrayElements;
    NewArrayElements.reserve(ArrayTy->getNumElements());
    for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
      Constant *ElemInit = Init->getAggregateElement(I);
      assert(ElemInit && "Expected a constant aggregate for array initializer!");
      NewArrayElements.push_back(
          transformInitializer(ElemInit, ArrayTy->getElementType(),
                               NewArrayTy->getElementType(), Ctx));
    }
    return ConstantArray::get(NewArrayTy, NewArrayElements);
  }

  // Neither a vector nor an array: nothing to convert.
  return Init;
}
/// Replaces every global whose value type contains vectors with a new
/// "<name>.scalarized" global of the equivalent array type. It then rewrites
/// the instructions that reference the old global, either directly or through
/// a constant-expression user, and finally erases the old globals.
/// Returns true if any global was replaced.
static bool findAndReplaceVectors(Module &M) {
  LLVMContext &Ctx = M.getContext();
  DataScalarizerVisitor Impl;

  for (GlobalVariable &G : M.globals()) {
    Type *OrigType = G.getValueType();
    Type *NewType = replaceVectorWithArray(OrigType, Ctx);
    if (OrigType == NewType)
      continue;

    // Create the replacement right before G, so this loop never visits it.
    // The initializer is converted separately below.
    GlobalVariable *NewGlobal = new GlobalVariable(
        M, NewType, G.isConstant(), G.getLinkage(),
        /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
        G.getThreadLocalMode(), G.getAddressSpace(),
        G.isExternallyInitialized());
    // Copy relevant attributes. G.getAlign() is empty when G has no explicit
    // alignment, which leaves NewGlobal unaligned as well.
    NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
    NewGlobal->setAlignment(G.getAlign());
    if (G.hasInitializer())
      NewGlobal->setInitializer(
          transformInitializer(G.getInitializer(), OrigType, NewType, Ctx));

    // G.replaceAllUsesWith(NewGlobal) cannot be used because it requires
    // equal types. Rewrite the users through the visitor instead.
    Impl.GlobalMap[&G] = NewGlobal;
    for (User *U : make_early_inc_range(G.users())) {
      if (auto *CE = dyn_cast<ConstantExpr>(U)) {
        for (User *UCE : make_early_inc_range(CE->users()))
          if (auto *Inst = dyn_cast<Instruction>(UCE))
            Impl.visit(*Inst);
      } else if (auto *Inst = dyn_cast<Instruction>(U)) {
        Impl.visit(*Inst);
      }
    }
  }

  // Remove the old globals now that iteration over M.globals() is finished.
  bool MadeChange = false;
  for (auto &Entry : Impl.GlobalMap) {
    GlobalVariable *Old = Entry.first;
    // Materializing GEP constant expressions (getAsInstruction) leaves the
    // original expressions alive but unused. They would still hold a use of
    // Old, so drop them before erasing Old.
    Old->removeDeadConstantUsers();
    Old->eraseFromParent();
    MadeChange = true;
  }
  return MadeChange;
}
// New pass manager entry point. Every analysis is preserved when nothing was
// rewritten; otherwise none are.
PreservedAnalyses DXILDataScalarization::run(Module &M,
                                             ModuleAnalysisManager &) {
  if (findAndReplaceVectors(M))
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}
// Legacy pass manager entry point. It shares its implementation with the new
// pass manager and reports whether the module was modified.
bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
  const bool Changed = findAndReplaceVectors(M);
  return Changed;
}
// Identification token for the legacy pass. LLVM's legacy pass manager keys
// passes by the address of this member; its value is irrelevant.
char DXILDataScalarizationLegacy::ID = 0;
// Register the legacy pass under DEBUG_TYPE with the description
// "DXIL Data Scalarization". The trailing arguments are the INITIALIZE_PASS
// "is CFG-only" and "is analysis" flags; both are false here.
INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
"DXIL Data Scalarization", false, false)
INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
"DXIL Data Scalarization", false, false)
// Factory for the legacy pass. Returns a freshly allocated instance; ownership
// presumably passes to the legacy PassManager it is added to (verify at call
// sites).
ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
  auto *Pass = new DXILDataScalarizationLegacy();
  return Pass;
}