blob: fa8b7dae6c76a4abac49869010092e626c35cbd4 [file] [log] [blame] [edit]
//===- DXILMemIntrinsics.cpp - Eliminate Memory Intrinsics ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "DXILMemIntrinsics.h"
#include "DirectX.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#define DEBUG_TYPE "dxil-mem-intrinsics"
using namespace llvm;
void expandMemSet(MemSetInst *MemSet) {
IRBuilder<> Builder(MemSet);
Value *Dst = MemSet->getDest();
Value *Val = MemSet->getValue();
ConstantInt *LengthCI = dyn_cast<ConstantInt>(MemSet->getLength());
assert(LengthCI && "Expected length to be a ConstantInt");
[[maybe_unused]] const DataLayout &DL =
Builder.GetInsertBlock()->getModule()->getDataLayout();
[[maybe_unused]] uint64_t OrigLength = LengthCI->getZExtValue();
AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);
assert(Alloca && "Expected memset on an Alloca");
assert(OrigLength == Alloca->getAllocationSize(DL)->getFixedValue() &&
"Expected for memset size to match DataLayout size");
Type *AllocatedTy = Alloca->getAllocatedType();
ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
assert(ArrTy && "Expected Alloca for an Array Type");
Type *ElemTy = ArrTy->getElementType();
uint64_t Size = ArrTy->getArrayNumElements();
[[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
assert(ElemSize > 0 && "Size must be set");
assert(OrigLength == ElemSize * Size && "Size in bytes must match");
Value *TypedVal = Val;
if (Val->getType() != ElemTy)
TypedVal = Builder.CreateIntCast(Val, ElemTy, false);
for (uint64_t I = 0; I < Size; ++I) {
Value *Zero = Builder.getInt32(0);
Value *Offset = Builder.getInt32(I);
Value *Ptr = Builder.CreateGEP(ArrTy, Dst, {Zero, Offset}, "gep");
Builder.CreateStore(TypedVal, Ptr);
}
MemSet->eraseFromParent();
}
static Type *getPointeeType(Value *Ptr, const DataLayout &DL) {
if (auto *GV = dyn_cast<GlobalVariable>(Ptr))
return GV->getValueType();
if (auto *AI = dyn_cast<AllocaInst>(Ptr))
return AI->getAllocatedType();
if (auto *II = dyn_cast<IntrinsicInst>(Ptr)) {
if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer) {
Type *Ty = cast<dxil::AnyResourceExtType>(II->getArgOperand(0)->getType())
->getResourceType();
assert(Ty && "getpointer used on untyped resource");
return Ty;
}
}
if (auto *GEP = dyn_cast<GEPOperator>(Ptr)) {
Type *Ty = GEP->getResultElementType();
if (!Ty->isIntegerTy(8))
return Ty;
// We have ptradd, so we have to hope there's enough information to work out
// what we're indexing.
Type *IndexedType = getPointeeType(GEP->getPointerOperand(), DL);
if (auto *AT = dyn_cast<ArrayType>(IndexedType))
return AT->getElementType();
if (auto *ST = dyn_cast<StructType>(IndexedType)) {
// Indexing a struct should always be constant
APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
[[maybe_unused]] bool IsConst =
GEP->accumulateConstantOffset(DL, ConstantOffset);
assert(IsConst && "Non-constant GEP into struct?");
// Now, work out what we'll find at that offset.
const StructLayout *Layout = DL.getStructLayout(ST);
unsigned Idx =
Layout->getElementContainingOffset(ConstantOffset.getZExtValue());
return ST->getTypeAtIndex(Idx);
}
llvm_unreachable("Could not infer type from GEP");
}
llvm_unreachable("Could not calculate pointee type");
}
static size_t flattenTypes(Type *ContainerTy, const DataLayout &DL,
SmallVectorImpl<std::pair<Type *, size_t>> &FlatTys,
size_t NextOffset = 0) {
if (auto *AT = dyn_cast<ArrayType>(ContainerTy)) {
for (uint64_t I = 0, E = AT->getNumElements(); I != E; ++I)
NextOffset = flattenTypes(AT->getElementType(), DL, FlatTys, NextOffset);
return NextOffset;
}
if (auto *ST = dyn_cast<StructType>(ContainerTy)) {
for (Type *Ty : ST->elements())
NextOffset = flattenTypes(Ty, DL, FlatTys, NextOffset);
return NextOffset;
}
FlatTys.emplace_back(ContainerTy, NextOffset);
return NextOffset + DL.getTypeStoreSize(ContainerTy);
}
void expandMemCpy(MemCpyInst *MemCpy) {
IRBuilder<> Builder(MemCpy);
Value *Dst = MemCpy->getDest();
Value *Src = MemCpy->getSource();
ConstantInt *LengthCI = dyn_cast<ConstantInt>(MemCpy->getLength());
assert(LengthCI && "Expected Length to be a ConstantInt");
assert(!MemCpy->isVolatile() && "Handling for volatile not implemented");
uint64_t ByteLength = LengthCI->getZExtValue();
// If length to copy is zero, no memcpy is needed.
if (ByteLength == 0)
return;
const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
SmallVector<std::pair<Type *, size_t>> FlattenedTypes;
[[maybe_unused]] size_t MaxLength =
flattenTypes(getPointeeType(Dst, DL), DL, FlattenedTypes);
assert(MaxLength >= ByteLength && "Dst not large enough for memcpy");
LLVM_DEBUG({
// Check if Src is layout compatible with Dst. This should always be true
// unless the frontend did something wrong.
SmallVector<std::pair<Type *, size_t>> SrcTypes;
size_t SrcLength = flattenTypes(getPointeeType(Src, DL), DL, SrcTypes);
assert(SrcLength >= ByteLength && "Src not large enough for memcpy");
for (const auto &[LHS, RHS] : zip(FlattenedTypes, SrcTypes)) {
auto &[DstTy, DstOffset] = LHS;
auto &[SrcTy, SrcOffset] = RHS;
assert(DstTy == SrcTy && "Mismatched types for memcpy");
assert(DstOffset == SrcOffset && "Incompatible layouts for memcpy");
if (DstOffset >= ByteLength)
break;
}
});
for (const auto &[Ty, Offset] : FlattenedTypes) {
if (Offset >= ByteLength)
break;
// TODO: Should we skip padding types here?
Type *Int8Ty = Builder.getInt8Ty();
Value *ByteOffset = Builder.getInt32(Offset);
Value *SrcPtr = Builder.CreateInBoundsGEP(Int8Ty, Src, ByteOffset);
Value *SrcVal = Builder.CreateLoad(Ty, SrcPtr);
Value *DstPtr = Builder.CreateInBoundsGEP(Int8Ty, Dst, ByteOffset);
Builder.CreateStore(SrcVal, DstPtr);
}
MemCpy->eraseFromParent();
}
void expandMemMove(MemMoveInst *MemMove) {
report_fatal_error("memmove expansion is not implemented yet.");
}
static bool eliminateMemIntrinsics(Module &M) {
bool HadMemIntrinsicUses = false;
for (auto &F : make_early_inc_range(M.functions())) {
Intrinsic::ID IID = F.getIntrinsicID();
switch (IID) {
case Intrinsic::memcpy:
case Intrinsic::memcpy_inline:
case Intrinsic::memmove:
case Intrinsic::memset:
case Intrinsic::memset_inline:
break;
default:
continue;
}
for (User *U : make_early_inc_range(F.users())) {
HadMemIntrinsicUses = true;
if (auto *MemSet = dyn_cast<MemSetInst>(U))
expandMemSet(MemSet);
else if (auto *MemCpy = dyn_cast<MemCpyInst>(U))
expandMemCpy(MemCpy);
else if (auto *MemMove = dyn_cast<MemMoveInst>(U))
expandMemMove(MemMove);
else
llvm_unreachable("Unhandled memory intrinsic");
}
assert(F.user_empty() && "Mem intrinsic not eliminated?");
F.eraseFromParent();
}
return HadMemIntrinsicUses;
}
PreservedAnalyses DXILMemIntrinsics::run(Module &M, ModuleAnalysisManager &) {
if (eliminateMemIntrinsics(M))
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
class DXILMemIntrinsicsLegacy : public ModulePass {
public:
bool runOnModule(Module &M) override { return eliminateMemIntrinsics(M); }
DXILMemIntrinsicsLegacy() : ModulePass(ID) {}
static char ID; // Pass identification.
};
char DXILMemIntrinsicsLegacy::ID = 0;
INITIALIZE_PASS_BEGIN(DXILMemIntrinsicsLegacy, DEBUG_TYPE,
"DXIL Memory Intrinsic Elimination", false, false)
INITIALIZE_PASS_END(DXILMemIntrinsicsLegacy, DEBUG_TYPE,
"DXIL Memory Intrinsic Elimination", false, false)
ModulePass *llvm::createDXILMemIntrinsicsLegacyPass() {
return new DXILMemIntrinsicsLegacy();
}