blob: 199fda0b1125d37fdca09d7f3c434cb4e8a8749e [file] [log] [blame] [edit]
//===-- SPIRVPostLegalizer.cpp - amend info after legalization -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The pass partially applies pre-legalization logic to new instructions
// inserted as a result of legalization:
// - assigns SPIR-V types to registers for new instructions.
// - inserts ASSIGN_TYPE pseudo-instructions required for type folding.
//
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
#include <stack>
#define DEBUG_TYPE "spirv-postlegalizer"
using namespace llvm;
namespace {
/// Machine-function pass run after legalization. It assigns SPIR-V types to
/// registers defined by instructions the legalizer introduced, and inserts the
/// ASSIGN_TYPE pseudos that type folding expects.
struct SPIRVPostLegalizer : MachineFunctionPass {
  static char ID;

  SPIRVPostLegalizer() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
};
} // namespace
namespace llvm {
// Helpers shared with the pre-legalizer; both are defined in
// SPIRVPreLegalizer.cpp. Function declarations have external linkage by
// default, so no `extern` is needed.

// Records a SPIR-V type for Reg. Only the SpirvTy form (Ty == nullptr) is used
// in this file.
void updateRegType(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                   SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
                   MachineRegisterInfo &MRI);

// NOTE(review): not referenced in the visible part of this file; presumably
// kept for parity with the pre-legalizer. Confirm before removing.
void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                  MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR,
                  SPIRVType *KnownResType);
} // namespace llvm
/// Builds the integer SPIR-V type matching the LLT of \p ResVReg.
///
/// The integer width is the scalar (element) width of the LLT. When the LLT is
/// a vector, a vector of that integer type with the same element count is
/// returned. Without this, vector results (e.g. a vector G_SEXT/G_ZEXT/
/// G_ANYEXT) would be given a scalar type, unlike the vector handling done by
/// deduceTypeFromSingleOperand and the default-type fallback.
static SPIRVType *deduceIntTypeFromResult(Register ResVReg,
                                          MachineIRBuilder &MIB,
                                          SPIRVGlobalRegistry *GR) {
  const LLT &Ty = MIB.getMRI()->getType(ResVReg);
  SPIRVType *IntType =
      GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
  if (!Ty.isVector())
    return IntType;
  return GR->getOrCreateSPIRVVectorType(IntType, Ty.getNumElements(), MIB,
                                        false);
}
/// Derives a result type for \p I from the SPIR-V type of operand \p OpIdx.
///
/// The operand's scalar/vector component type is taken and, if the result
/// register (operand 0) has a vector LLT, widened to a vector with the result's
/// element count. Returns nullptr when the operand is untyped or its component
/// type cannot be determined.
static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I,
                                              MachineIRBuilder &MIB,
                                              SPIRVGlobalRegistry *GR,
                                              unsigned OpIdx) {
  SPIRVType *SrcType = GR->getSPIRVTypeForVReg(I->getOperand(OpIdx).getReg());
  if (!SrcType)
    return nullptr;

  SPIRVType *ElemType = GR->getScalarOrVectorComponentType(SrcType);
  if (!ElemType)
    return nullptr;

  const LLT DstLLT = MIB.getMRI()->getType(I->getOperand(0).getReg());
  if (!DstLLT.isVector())
    return ElemType;
  return GR->getOrCreateSPIRVVectorType(ElemType, DstLLT.getNumElements(), MIB,
                                        false);
}
/// Derives a result type for \p I from the operands in [\p StartOp, \p EndOp).
///
/// Each register operand is tried with deduceTypeFromSingleOperand. In normal
/// builds the first successful deduction is returned. With EXPENSIVE_CHECKS all
/// operands are inspected and asserted to agree. Non-register operands are
/// skipped, because deduceTypeFromSingleOperand requires a register. Returns
/// nullptr when no operand yields a type.
static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
                                             MachineIRBuilder &MIB,
                                             SPIRVGlobalRegistry *GR,
                                             unsigned StartOp, unsigned EndOp) {
  SPIRVType *ResType = nullptr;
  for (unsigned i = StartOp; i < EndOp; ++i) {
    if (!I->getOperand(i).isReg())
      continue;
    if (SPIRVType *Type = deduceTypeFromSingleOperand(I, MIB, GR, i)) {
#ifdef EXPENSIVE_CHECKS
      // Parenthesized: `&&` binds tighter than `||`, so without the parens the
      // message string was attached to only half of the condition.
      assert((!ResType || Type == ResType) &&
             "Conflicting type from operands.");
      ResType = Type;
#else
      return Type;
#endif
    }
  }
  return ResType;
}
/// Deduces the type of \p UseRegister, an operand of \p Use, from the SPIR-V
/// types already recorded for \p Use's definitions.
///
/// Definitions are scanned in order. The first one that is typed and has a
/// determinable scalar/vector component type wins. That component type is
/// returned as-is for a scalar \p UseRegister, or widened to a vector with
/// \p UseRegister's element count. Returns nullptr if no definition helps.
static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
                                               Register UseRegister,
                                               SPIRVGlobalRegistry *GR,
                                               MachineIRBuilder &MIB) {
  for (const MachineOperand &Def : Use->defs()) {
    if (!Def.isReg())
      continue;

    SPIRVType *DefType = GR->getSPIRVTypeForVReg(Def.getReg());
    if (!DefType)
      continue;

    SPIRVType *ElemType = GR->getScalarOrVectorComponentType(DefType);
    if (!ElemType)
      continue;

    const LLT UseLLT = MIB.getMRI()->getType(UseRegister);
    return UseLLT.isVector()
               ? GR->getOrCreateSPIRVVectorType(
                     ElemType, UseLLT.getNumElements(), MIB, false)
               : ElemType;
  }
  return nullptr;
}
/// For the address operand of a G_LOAD/G_STORE, builds a pointer type whose
/// pointee is the SPIR-V type of the instruction's value operand (operand 0).
/// Returns nullptr while that value is still untyped. \p UseRegister is unused.
///
/// NOTE(review): the storage class is always Function, independent of where
/// the pointer actually points. Confirm this is acceptable for non-private
/// memory.
static SPIRVType *deducePointerTypeFromResultRegister(MachineInstr *Use,
                                                      Register UseRegister,
                                                      SPIRVGlobalRegistry *GR,
                                                      MachineIRBuilder &MIB) {
  assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
         Use->getOpcode() == TargetOpcode::G_STORE);
  SPIRVType *PointeeType =
      GR->getSPIRVTypeForVReg(Use->getOperand(0).getReg());
  if (PointeeType == nullptr)
    return nullptr;
  return GR->getOrCreateSPIRVPointerType(PointeeType, MIB,
                                         SPIRV::StorageClass::Function);
}
/// For the value operand of a G_LOAD/G_STORE, returns the pointee type of the
/// instruction's address operand (operand 1). Returns nullptr while that
/// address is still untyped. \p UseRegister is unused.
static SPIRVType *deduceTypeFromPointerOperand(MachineInstr *Use,
                                               Register UseRegister,
                                               SPIRVGlobalRegistry *GR,
                                               MachineIRBuilder &MIB) {
  assert(Use->getOpcode() == TargetOpcode::G_LOAD ||
         Use->getOpcode() == TargetOpcode::G_STORE);
  SPIRVType *AddrType = GR->getSPIRVTypeForVReg(Use->getOperand(1).getReg());
  return AddrType ? GR->getPointeeType(AddrType) : nullptr;
}
/// Deduces a SPIR-V type for the untyped register \p Reg from the instructions
/// that use it.
///
/// Non-debug uses are visited in use-list order, and the first use that yields
/// a type wins. Returns nullptr when no use is informative.
///
/// Per use opcode:
///  - element-wise arithmetic, COPY, G_BUILD_VECTOR, G_UNMERGE_VALUES:
///    \p Reg takes the component type of the use's (already typed) result,
///    shaped by \p Reg's own LLT.
///  - G_EXTRACT_VECTOR_ELT: same, but only when \p Reg is the vector operand
///    (operand 1). The index operand (operand 2) is an integer unrelated to the
///    element type and must not inherit it.
///  - spv_insertelt / spv_extractelt intrinsics: same, but only when \p Reg is
///    operand 2.
///  - G_LOAD / G_STORE: a pointer type when \p Reg is the address operand
///    (operand 1); otherwise the pointee type of the address operand.
static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
                                     SPIRVGlobalRegistry *GR,
                                     MachineIRBuilder &MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
    SPIRVType *ResType = nullptr;
    LLVM_DEBUG(dbgs() << "Looking at use " << Use);
    switch (Use.getOpcode()) {
    case TargetOpcode::G_EXTRACT_VECTOR_ELT:
      // Only the vector operand relates to the result's element type; an
      // index register must not be typed as a vector element.
      if (Reg == Use.getOperand(1).getReg())
        ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
      break;
    case TargetOpcode::G_BUILD_VECTOR:
    case TargetOpcode::G_UNMERGE_VALUES:
    case TargetOpcode::G_ADD:
    case TargetOpcode::G_SUB:
    case TargetOpcode::G_MUL:
    case TargetOpcode::G_SDIV:
    case TargetOpcode::G_UDIV:
    case TargetOpcode::G_SREM:
    case TargetOpcode::G_UREM:
    case TargetOpcode::G_FADD:
    case TargetOpcode::G_FSUB:
    case TargetOpcode::G_FMUL:
    case TargetOpcode::G_FDIV:
    case TargetOpcode::G_FREM:
    case TargetOpcode::G_FMA:
    case TargetOpcode::COPY:
    case TargetOpcode::G_STRICT_FMA:
      ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
      break;
    case TargetOpcode::G_LOAD:
    case TargetOpcode::G_STORE:
      if (Reg == Use.getOperand(1).getReg())
        ResType = deducePointerTypeFromResultRegister(&Use, Reg, GR, MIB);
      else
        ResType = deduceTypeFromPointerOperand(&Use, Reg, GR, MIB);
      break;
    case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
    case TargetOpcode::G_INTRINSIC: {
      auto IntrinsicID = cast<GIntrinsic>(Use).getIntrinsicID();
      if (IntrinsicID == Intrinsic::spv_insertelt) {
        if (Reg == Use.getOperand(2).getReg())
          ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
      } else if (IntrinsicID == Intrinsic::spv_extractelt) {
        if (Reg == Use.getOperand(2).getReg())
          ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
      }
      break;
    }
    }
    if (ResType) {
      LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
      return ResType;
    }
  }
  return nullptr;
}
/// Attempts to derive the type of \p I's result (operand 0) from the
/// instruction itself, without looking at uses.
///
/// - G_CONSTANT / G_ANYEXT / G_SEXT / G_ZEXT: integer type sized from the
///   result LLT.
/// - G_BUILD_VECTOR: from any of the element operands.
/// - G_SHUFFLE_VECTOR: from operand 1 or 2 only.
/// - Anything else with one definition and a register operand 1: from that
///   operand's component type.
/// Returns nullptr when none of these applies or succeeds.
static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I,
                                               SPIRVGlobalRegistry *GR,
                                               MachineIRBuilder &MIB) {
  switch (I->getOpcode()) {
  case TargetOpcode::G_CONSTANT:
  case TargetOpcode::G_ANYEXT:
  case TargetOpcode::G_SEXT:
  case TargetOpcode::G_ZEXT:
    return deduceIntTypeFromResult(I->getOperand(0).getReg(), MIB, GR);
  case TargetOpcode::G_BUILD_VECTOR:
    return deduceTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands());
  case TargetOpcode::G_SHUFFLE_VECTOR:
    return deduceTypeFromOperandRange(I, MIB, GR, 1, 3);
  default:
    break;
  }

  const bool HasSingleDefAndRegSource = I->getNumDefs() == 1 &&
                                        I->getNumOperands() > 1 &&
                                        I->getOperand(1).isReg();
  return HasSingleDefAndRegSource ? deduceTypeFromSingleOperand(I, MIB, GR, 1)
                                  : nullptr;
}
/// Assigns SPIR-V types to the registers of a G_UNMERGE_VALUES.
///
/// The element type comes from the source vector (the last operand) when that
/// source is already typed. Otherwise it comes from the uses of the results:
/// results are tried in order, and the first one whose uses yield a type
/// decides (its component type is taken, even if that is null). Every operand
/// that is still untyped, results and source alike, then gets the element type,
/// or a vector of it when its LLT is a vector. Returns false if no element type
/// could be found.
static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
                                           SPIRVGlobalRegistry *GR,
                                           MachineIRBuilder &MIB) {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  const Register Source = I->getOperand(I->getNumOperands() - 1).getReg();

  SPIRVType *ElemType = nullptr;
  if (SPIRVType *SourceType = GR->getSPIRVTypeForVReg(Source)) {
    assert(SourceType->getOpcode() == SPIRV::OpTypeVector);
    ElemType = GR->getSPIRVTypeForVReg(SourceType->getOperand(1).getReg());
  }

  if (ElemType == nullptr) {
    // Source is untyped: fall back to what the results' uses imply.
    for (unsigned DefIdx = 0; DefIdx < I->getNumDefs(); ++DefIdx) {
      SPIRVType *FromUses =
          deduceTypeFromUses(I->getOperand(DefIdx).getReg(), MF, GR, MIB);
      if (FromUses != nullptr) {
        ElemType = GR->getScalarOrVectorComponentType(FromUses);
        break;
      }
    }
  }

  if (ElemType == nullptr)
    return false;

  for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) {
    const Register Reg = I->getOperand(OpIdx).getReg();
    if (GR->getSPIRVTypeForVReg(Reg))
      continue;
    const LLT RegLLT = MRI.getType(Reg);
    SPIRVType *RegType = ElemType;
    if (RegLLT.isVector())
      RegType = GR->getOrCreateSPIRVVectorType(
          ElemType, RegLLT.getNumElements(), *I,
          *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo());
    setRegClassType(Reg, RegType, GR, &MRI, MF);
  }
  return true;
}
/// Tries to find and record a SPIR-V type for the result of \p I.
///
/// G_UNMERGE_VALUES is delegated to deduceAndAssignTypeForGUnmerge because it
/// has several definitions. For every other instruction the operands are tried
/// first, then the uses of the result. On success the type is recorded for the
/// result register, and a register class is set if the register has none.
/// Returns whether a type was assigned.
static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
                                     SPIRVGlobalRegistry *GR,
                                     MachineIRBuilder &MIB) {
  LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I);
  if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
    return deduceAndAssignTypeForGUnmerge(I, MF, GR, MIB);

  MachineRegisterInfo &MRI = MF.getRegInfo();
  const Register Result = I->getOperand(0).getReg();

  LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
  SPIRVType *Deduced = deduceResultTypeFromOperands(I, GR, MIB);
  if (Deduced == nullptr) {
    LLVM_DEBUG(dbgs() << "Inferring type from uses\n");
    Deduced = deduceTypeFromUses(Result, MF, GR, MIB);
    if (Deduced == nullptr)
      return false;
  }

  LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *Deduced);
  GR->assignSPIRVTypeToVReg(Deduced, Result, MF);
  if (MRI.getRegClassOrNull(Result) == nullptr) {
    LLVM_DEBUG(dbgs() << "Updating the register class.\n");
    setRegClassType(Result, Deduced, GR, &MRI, *GR->CurMF, true);
  }
  return true;
}
/// Returns true when \p I is a generic (pre-ISel) instruction with at least one
/// definition whose first result has no SPIR-V type yet.
///
/// Side effect: if the first result is already typed but has no register
/// class, the class is set from the recorded type.
static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
                              MachineRegisterInfo &MRI) {
  LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
                    << I;);
  if (!I.getNumDefs()) {
    LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
    return false;
  }
  if (!I.isPreISelOpcode()) {
    LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n");
    return false;
  }

  const Register FirstDef = I.defs().begin()->getReg();
  SPIRVType *Known = GR->getSPIRVTypeForVReg(FirstDef);
  if (!Known)
    return true;

  LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
  if (!MRI.getRegClassOrNull(FirstDef)) {
    LLVM_DEBUG(dbgs() << "Updating the register class.\n");
    setRegClassType(FirstDef, Known, GR, &MRI, *GR->CurMF, true);
  }
  return false;
}
/// Gives a SPIR-V type to every generic instruction result that lacks one.
///
/// 1. Collect instructions for which requiresSpirvType() is true.
/// 2. Repeatedly run deduceAndAssignSpirvType over the pending ones until a
///    full sweep makes no progress. Types found in one sweep can unlock
///    deductions in the next.
/// 3. Any result still untyped gets an integer type (or an integer vector for a
///    vector LLT) of the LLT's bit width.
static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
                                                SPIRVGlobalRegistry *GR) {
  MachineRegisterInfo &MRI = MF.getRegInfo();

  SmallVector<MachineInstr *, 8> Pending;
  for (MachineBasicBlock &MBB : MF)
    for (MachineInstr &MI : MBB)
      if (requiresSpirvType(MI, GR, MRI))
        Pending.push_back(&MI);

  if (Pending.empty()) {
    LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
    return;
  }
  LLVM_DEBUG(dbgs() << "Initial worklist:\n";
             for (auto *MI : Pending) { MI->dump(); });

  // Fixed-point iteration: stop after the first sweep that assigns nothing.
  for (bool Progress = true; Progress;) {
    Progress = false;
    SmallVector<MachineInstr *, 8> Unresolved;
    for (MachineInstr *MI : Pending) {
      MachineIRBuilder MIB(*MI);
      if (deduceAndAssignSpirvType(MI, MF, GR, MIB))
        Progress = true;
      else
        Unresolved.push_back(MI);
    }
    Pending = std::move(Unresolved);
    LLVM_DEBUG(dbgs() << "Worklist size: " << Pending.size() << "\n");
  }

  // Fallback for anything the deduction could not resolve.
  for (MachineInstr *MI : Pending) {
    MachineIRBuilder MIB(*MI);
    LLVM_DEBUG(dbgs() << "Assigning default type to results in " << *MI);
    for (unsigned DefIdx = 0; DefIdx < MI->getNumDefs(); ++DefIdx) {
      const Register Reg = MI->getOperand(DefIdx).getReg();
      if (GR->getSPIRVTypeForVReg(Reg))
        continue;
      const LLT RegLLT = MRI.getType(Reg);
      SPIRVType *Fallback = nullptr;
      if (!RegLLT.isVector()) {
        Fallback =
            GR->getOrCreateSPIRVIntegerType(RegLLT.getSizeInBits(), MIB);
      } else {
        SPIRVType *ElemType = GR->getOrCreateSPIRVIntegerType(
            RegLLT.getElementType().getSizeInBits(), MIB);
        Fallback = GR->getOrCreateSPIRVVectorType(
            ElemType, RegLLT.getNumElements(), MIB, false);
      }
      setRegClassType(Reg, Fallback, GR, &MRI, MF, true);
    }
  }
}
/// Returns true if some non-debug user of \p Reg is an ASSIGN_TYPE pseudo.
static bool hasAssignType(Register Reg, MachineRegisterInfo &MRI) {
  for (const MachineInstr &User : MRI.use_nodbg_instructions(Reg))
    if (User.getOpcode() == SPIRV::ASSIGN_TYPE)
      return true;
  return false;
}
// Inserts an ASSIGN_TYPE pseudo for ResultRegister, a definition of MI, so the
// type-folding TableGen patterns can take a type reference from it.
//
// Rewiring performed below:
//  - a fresh vreg NewReg with ResultRegister's LLT is created;
//  - both registers get ResultType's register class and SPIR-V type;
//  - `ResultRegister = ASSIGN_TYPE NewReg, <type id of ResultType>` is built,
//    carrying MI's MI flags;
//  - MI's def operand that was ResultRegister is retargeted to NewReg.
//
// NOTE(review): MIB is created with its insertion point at MI, yet the
// ASSIGN_TYPE uses NewReg, which MI defines, so it must land after MI. This
// presumably relies on updateRegType() moving MIB's insertion point past the
// def (its body is in SPIRVPreLegalizer.cpp). Confirm before reordering any
// statement here.
static void generateAssignType(MachineInstr &MI, Register ResultRegister,
SPIRVType *ResultType, SPIRVGlobalRegistry *GR,
MachineRegisterInfo &MRI) {
LLVM_DEBUG(dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
<< printReg(ResultRegister, MRI.getTargetRegisterInfo())
<< " with type: " << *ResultType);
MachineIRBuilder MIB(MI);
updateRegType(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
// Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is
// present after each auto-folded instruction to take a type reference
// from.
Register NewReg =
MRI.createGenericVirtualRegister(MRI.getType(ResultRegister));
const auto *RegClass = GR->getRegClass(ResultType);
MRI.setRegClass(NewReg, RegClass);
MRI.setRegClass(ResultRegister, RegClass);
GR->assignSPIRVTypeToVReg(ResultType, ResultRegister, MIB.getMF());
// This is to make it convenient for Legalizer to get the SPIRVType
// when processing the actual MI (i.e. not pseudo one).
GR->assignSPIRVTypeToVReg(ResultType, NewReg, MIB.getMF());
// Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to
// keep the flags after instruction selection.
const uint32_t Flags = MI.getFlags();
MIB.buildInstr(SPIRV::ASSIGN_TYPE)
.addDef(ResultRegister)
.addUse(NewReg)
.addUse(GR->getSPIRVTypeID(ResultType))
.setMIFlags(Flags);
// Point MI's matching def at NewReg so ResultRegister is defined only by the
// ASSIGN_TYPE built above. Only the first matching def is rewritten.
for (unsigned I = 0, E = MI.getNumDefs(); I != E; ++I) {
MachineOperand &MO = MI.getOperand(I);
if (MO.getReg() == ResultRegister) {
MO.setReg(NewReg);
break;
}
}
}
/// Makes sure every instruction whose opcode supports type folding has an
/// ASSIGN_TYPE consuming its first result. One is generated where missing.
/// Every such result is expected (asserted) to already carry a SPIR-V type.
static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
                                           SPIRVGlobalRegistry *GR) {
  LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
                    << MF.getName() << "\n");
  MachineRegisterInfo &MRI = MF.getRegInfo();
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &Inst : MBB) {
      if (!isTypeFoldingSupported(Inst.getOpcode()))
        continue;
      LLVM_DEBUG(dbgs() << "Processing instruction: " << Inst);

      const Register Res = Inst.defs().begin()->getReg();
      if (!hasAssignType(Res, MRI)) {
        SPIRVType *ResTy = GR->getSPIRVTypeForVReg(Res);
        assert(ResTy);
        generateAssignType(Inst, Res, ResTy, GR, MRI);
      } else {
        LLVM_DEBUG(dbgs() << " Instruction already has ASSIGN_TYPE\n");
      }
    }
  }
}
// Traverses the CFG from the basic block |Start|, calling |op| exactly once on
// every block reachable from it. Uses a LIFO worklist in which a block is
// marked as seen when it is pushed. |MF| is not used.
void visit(MachineFunction &MF, MachineBasicBlock &Start,
           std::function<void(MachineBasicBlock *)> op) {
  SmallVector<MachineBasicBlock *, 8> Worklist;
  SmallPtrSet<MachineBasicBlock *, 8> Seen;
  Worklist.push_back(&Start);
  Seen.insert(&Start);
  while (!Worklist.empty()) {
    MachineBasicBlock *Block = Worklist.pop_back_val();
    op(Block);
    for (MachineBasicBlock *Next : Block->successors())
      if (Seen.insert(Next).second)
        Worklist.push_back(Next);
  }
}
// Traverses the CFG from the entry block of |MF|, calling |op| once on every
// reachable basic block. Does nothing for a function with no basic blocks.
void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
  // For an empty function MF.begin() is the end sentinel; dereferencing it
  // would be undefined behavior.
  if (MF.empty())
    return;
  visit(MF, MF.front(), std::move(op));
}
bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
  // Point the subtarget's global type registry at this function before any
  // type lookups or assignments.
  SPIRVGlobalRegistry *GR =
      MF.getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
  GR->setCurrentFunc(MF);

  // Type the registers of new instructions first; ASSIGN_TYPE generation then
  // relies on those types being present.
  registerSpirvTypeForNewInstructions(MF, GR);
  ensureAssignTypeForTypeFolding(MF, GR);

  // Always report the function as modified.
  return true;
}
INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
false)
char SPIRVPostLegalizer::ID = 0;
FunctionPass *llvm::createSPIRVPostLegalizerPass() {
return new SPIRVPostLegalizer();
}