| //===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===// |
| // |
| // The LLVM Compiler Infrastructure |
| // |
| // This file is distributed under the University of Illinois Open Source |
| // License. See LICENSE.TXT for details. |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Copy struct args to local memory. This is needed for kernel functions only. |
| // This is a preparation for handling cases like |
| // |
| // kernel void foo(struct A arg, ...) |
| // { |
| // struct A *p = &arg; |
| // ... |
//   ... = p->field1 ...   (there is no generic address for .param)
//   p->field2 = ...       (there is no write access to .param)
| // } |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "NVPTX.h" |
| #include "NVPTXUtilities.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/IntrinsicInst.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/IR/Type.h" |
| #include "llvm/Pass.h" |
| |
| using namespace llvm; |
| |
| namespace llvm { |
| void initializeNVPTXLowerStructArgsPass(PassRegistry &); |
| } |
| |
| class LLVM_LIBRARY_VISIBILITY NVPTXLowerStructArgs : public FunctionPass { |
| bool runOnFunction(Function &F) override; |
| |
| void handleStructPtrArgs(Function &); |
| void handleParam(Argument *); |
| |
| public: |
| static char ID; // Pass identification, replacement for typeid |
| NVPTXLowerStructArgs() : FunctionPass(ID) {} |
| const char *getPassName() const override { |
| return "Copy structure (byval *) arguments to stack"; |
| } |
| }; |
| |
| char NVPTXLowerStructArgs::ID = 1; |
| |
| INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args", |
| "Lower structure arguments (NVPTX)", false, false) |
| |
| void NVPTXLowerStructArgs::handleParam(Argument *Arg) { |
| Function *Func = Arg->getParent(); |
| Instruction *FirstInst = &(Func->getEntryBlock().front()); |
| PointerType *PType = dyn_cast<PointerType>(Arg->getType()); |
| |
| assert(PType && "Expecting pointer type in handleParam"); |
| |
| Type *StructType = PType->getElementType(); |
| AllocaInst *AllocA = new AllocaInst(StructType, Arg->getName(), FirstInst); |
| |
| /* Set the alignment to alignment of the byval parameter. This is because, |
| * later load/stores assume that alignment, and we are going to replace |
| * the use of the byval parameter with this alloca instruction. |
| */ |
| AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1)); |
| |
| Arg->replaceAllUsesWith(AllocA); |
| |
| // Get the cvt.gen.to.param intrinsic |
| Type *CvtTypes[] = { |
| Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_PARAM), |
| Type::getInt8PtrTy(Func->getParent()->getContext(), |
| ADDRESS_SPACE_GENERIC)}; |
| Function *CvtFunc = Intrinsic::getDeclaration( |
| Func->getParent(), Intrinsic::nvvm_ptr_gen_to_param, CvtTypes); |
| |
| Value *BitcastArgs[] = { |
| new BitCastInst(Arg, Type::getInt8PtrTy(Func->getParent()->getContext(), |
| ADDRESS_SPACE_GENERIC), |
| Arg->getName(), FirstInst)}; |
| CallInst *CallCVT = |
| CallInst::Create(CvtFunc, BitcastArgs, "cvt_to_param", FirstInst); |
| |
| BitCastInst *BitCast = new BitCastInst( |
| CallCVT, PointerType::get(StructType, ADDRESS_SPACE_PARAM), |
| Arg->getName(), FirstInst); |
| LoadInst *LI = new LoadInst(BitCast, Arg->getName(), FirstInst); |
| new StoreInst(LI, AllocA, FirstInst); |
| } |
| |
| // ============================================================================= |
| // If the function had a struct ptr arg, say foo(%struct.x *byval %d), then |
| // add the following instructions to the first basic block : |
| // |
| // %temp = alloca %struct.x, align 8 |
| // %tt1 = bitcast %struct.x * %d to i8 * |
| // %tt2 = llvm.nvvm.cvt.gen.to.param %tt2 |
| // %tempd = bitcast i8 addrspace(101) * to %struct.x addrspace(101) * |
| // %tv = load %struct.x addrspace(101) * %tempd |
| // store %struct.x %tv, %struct.x * %temp, align 8 |
| // |
| // The above code allocates some space in the stack and copies the incoming |
| // struct from param space to local space. |
| // Then replace all occurences of %d by %temp. |
| // ============================================================================= |
| void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) { |
| for (Argument &Arg : F.args()) { |
| if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) { |
| handleParam(&Arg); |
| } |
| } |
| } |
| |
| // ============================================================================= |
| // Main function for this pass. |
| // ============================================================================= |
| bool NVPTXLowerStructArgs::runOnFunction(Function &F) { |
| // Skip non-kernels. See the comments at the top of this file. |
| if (!isKernelFunction(F)) |
| return false; |
| |
| handleStructPtrArgs(F); |
| return true; |
| } |
| |
| FunctionPass *llvm::createNVPTXLowerStructArgsPass() { |
| return new NVPTXLowerStructArgs(); |
| } |