| //===-- AMDGPURegBankLegalizeHelper.cpp -----------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| /// Implements actual lowering algorithms for each ID that can be used in |
| /// Rule.OperandMapping. Similar to legalizer helper but with register banks. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "AMDGPURegBankLegalizeHelper.h" |
| #include "AMDGPUGlobalISelUtils.h" |
| #include "AMDGPUInstrInfo.h" |
| #include "AMDGPURegisterBankInfo.h" |
| #include "MCTargetDesc/AMDGPUMCTargetDesc.h" |
| #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" |
| #include "llvm/CodeGen/MachineUniformityAnalysis.h" |
| |
| #define DEBUG_TYPE "amdgpu-regbanklegalize" |
| |
| using namespace llvm; |
| using namespace AMDGPU; |
| |
| RegBankLegalizeHelper::RegBankLegalizeHelper( |
| MachineIRBuilder &B, const MachineUniformityInfo &MUI, |
| const RegisterBankInfo &RBI, const RegBankLegalizeRules &RBLRules) |
| : B(B), MRI(*B.getMRI()), MUI(MUI), RBI(RBI), RBLRules(RBLRules), |
| SgprRB(&RBI.getRegBank(AMDGPU::SGPRRegBankID)), |
| VgprRB(&RBI.getRegBank(AMDGPU::VGPRRegBankID)), |
| VccRB(&RBI.getRegBank(AMDGPU::VCCRegBankID)) {} |
| |
| void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) { |
| const SetOfRulesForOpcode &RuleSet = RBLRules.getRulesForOpc(MI); |
| const RegBankLLTMapping &Mapping = RuleSet.findMappingForMI(MI, MRI, MUI); |
| |
| SmallSet<Register, 4> WaterfallSgprs; |
| unsigned OpIdx = 0; |
| if (Mapping.DstOpMapping.size() > 0) { |
| B.setInsertPt(*MI.getParent(), std::next(MI.getIterator())); |
| applyMappingDst(MI, OpIdx, Mapping.DstOpMapping); |
| } |
| if (Mapping.SrcOpMapping.size() > 0) { |
| B.setInstr(MI); |
| applyMappingSrc(MI, OpIdx, Mapping.SrcOpMapping, WaterfallSgprs); |
| } |
| |
| lower(MI, Mapping, WaterfallSgprs); |
| } |
| |
| void RegBankLegalizeHelper::splitLoad(MachineInstr &MI, |
| ArrayRef<LLT> LLTBreakdown, LLT MergeTy) { |
| MachineFunction &MF = B.getMF(); |
| assert(MI.getNumMemOperands() == 1); |
| MachineMemOperand &BaseMMO = **MI.memoperands_begin(); |
| Register Dst = MI.getOperand(0).getReg(); |
| const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst); |
| Register Base = MI.getOperand(1).getReg(); |
| LLT PtrTy = MRI.getType(Base); |
| const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base); |
| LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits()); |
| SmallVector<Register, 4> LoadPartRegs; |
| |
| unsigned ByteOffset = 0; |
| for (LLT PartTy : LLTBreakdown) { |
| Register BasePlusOffset; |
| if (ByteOffset == 0) { |
| BasePlusOffset = Base; |
| } else { |
| auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset); |
| BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0); |
| } |
| auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy); |
| auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO); |
| LoadPartRegs.push_back(LoadPart.getReg(0)); |
| ByteOffset += PartTy.getSizeInBytes(); |
| } |
| |
| if (!MergeTy.isValid()) { |
| // Loads are of same size, concat or merge them together. |
| B.buildMergeLikeInstr(Dst, LoadPartRegs); |
| } else { |
| // Loads are not all of same size, need to unmerge them to smaller pieces |
| // of MergeTy type, then merge pieces to Dst. |
| SmallVector<Register, 4> MergeTyParts; |
| for (Register Reg : LoadPartRegs) { |
| if (MRI.getType(Reg) == MergeTy) { |
| MergeTyParts.push_back(Reg); |
| } else { |
| auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg); |
| for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) |
| MergeTyParts.push_back(Unmerge.getReg(i)); |
| } |
| } |
| B.buildMergeLikeInstr(Dst, MergeTyParts); |
| } |
| MI.eraseFromParent(); |
| } |
| |
| void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy, |
| LLT MergeTy) { |
| MachineFunction &MF = B.getMF(); |
| assert(MI.getNumMemOperands() == 1); |
| MachineMemOperand &BaseMMO = **MI.memoperands_begin(); |
| Register Dst = MI.getOperand(0).getReg(); |
| const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst); |
| Register Base = MI.getOperand(1).getReg(); |
| |
| MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy); |
| auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO); |
| |
| if (WideTy.isScalar()) { |
| B.buildTrunc(Dst, WideLoad); |
| } else { |
| SmallVector<Register, 4> MergeTyParts; |
| auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad); |
| |
| LLT DstTy = MRI.getType(Dst); |
| unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits(); |
| for (unsigned i = 0; i < NumElts; ++i) { |
| MergeTyParts.push_back(Unmerge.getReg(i)); |
| } |
| B.buildMergeLikeInstr(Dst, MergeTyParts); |
| } |
| MI.eraseFromParent(); |
| } |
| |
| void RegBankLegalizeHelper::lower(MachineInstr &MI, |
| const RegBankLLTMapping &Mapping, |
| SmallSet<Register, 4> &WaterfallSgprs) { |
| |
| switch (Mapping.LoweringMethod) { |
| case DoNotLower: |
| return; |
| case VccExtToSel: { |
| LLT Ty = MRI.getType(MI.getOperand(0).getReg()); |
| Register Src = MI.getOperand(1).getReg(); |
| unsigned Opc = MI.getOpcode(); |
| if (Ty == S32 || Ty == S16) { |
| auto True = B.buildConstant({VgprRB, Ty}, Opc == G_SEXT ? -1 : 1); |
| auto False = B.buildConstant({VgprRB, Ty}, 0); |
| B.buildSelect(MI.getOperand(0).getReg(), Src, True, False); |
| } |
| if (Ty == S64) { |
| auto True = B.buildConstant({VgprRB, S32}, Opc == G_SEXT ? -1 : 1); |
| auto False = B.buildConstant({VgprRB, S32}, 0); |
| auto Sel = B.buildSelect({VgprRB, S32}, Src, True, False); |
| B.buildMergeValues( |
| MI.getOperand(0).getReg(), |
| {Sel.getReg(0), Opc == G_SEXT ? Sel.getReg(0) : False.getReg(0)}); |
| } |
| MI.eraseFromParent(); |
| return; |
| } |
| case UniExtToSel: { |
| LLT Ty = MRI.getType(MI.getOperand(0).getReg()); |
| auto True = B.buildConstant({SgprRB, Ty}, |
| MI.getOpcode() == AMDGPU::G_SEXT ? -1 : 1); |
| auto False = B.buildConstant({SgprRB, Ty}, 0); |
| // Input to G_{Z|S}EXT is 'Legalizer legal' S1. Most common case is compare. |
| // We are making select here. S1 cond was already 'any-extended to S32' + |
| // 'AND with 1 to clean high bits' by Sgpr32AExtBoolInReg. |
| B.buildSelect(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(), True, |
| False); |
| MI.eraseFromParent(); |
| return; |
| } |
| case Ext32To64: { |
| const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg()); |
| MachineInstrBuilder Hi; |
| |
| if (MI.getOpcode() == AMDGPU::G_ZEXT) { |
| Hi = B.buildConstant({RB, S32}, 0); |
| } else { |
| // Replicate sign bit from 32-bit extended part. |
| auto ShiftAmt = B.buildConstant({RB, S32}, 31); |
| Hi = B.buildAShr({RB, S32}, MI.getOperand(1).getReg(), ShiftAmt); |
| } |
| |
| B.buildMergeLikeInstr(MI.getOperand(0).getReg(), |
| {MI.getOperand(1).getReg(), Hi}); |
| MI.eraseFromParent(); |
| return; |
| } |
| case UniCstExt: { |
| uint64_t ConstVal = MI.getOperand(1).getCImm()->getZExtValue(); |
| B.buildConstant(MI.getOperand(0).getReg(), ConstVal); |
| |
| MI.eraseFromParent(); |
| return; |
| } |
| case VgprToVccCopy: { |
| Register Src = MI.getOperand(1).getReg(); |
| LLT Ty = MRI.getType(Src); |
| // Take lowest bit from each lane and put it in lane mask. |
| // Lowering via compare, but we need to clean high bits first as compare |
| // compares all bits in register. |
| Register BoolSrc = MRI.createVirtualRegister({VgprRB, Ty}); |
| if (Ty == S64) { |
| auto Src64 = B.buildUnmerge({VgprRB, Ty}, Src); |
| auto One = B.buildConstant(VgprRB_S32, 1); |
| auto AndLo = B.buildAnd(VgprRB_S32, Src64.getReg(0), One); |
| auto Zero = B.buildConstant(VgprRB_S32, 0); |
| auto AndHi = B.buildAnd(VgprRB_S32, Src64.getReg(1), Zero); |
| B.buildMergeLikeInstr(BoolSrc, {AndLo, AndHi}); |
| } else { |
| assert(Ty == S32 || Ty == S16); |
| auto One = B.buildConstant({VgprRB, Ty}, 1); |
| B.buildAnd(BoolSrc, Src, One); |
| } |
| auto Zero = B.buildConstant({VgprRB, Ty}, 0); |
| B.buildICmp(CmpInst::ICMP_NE, MI.getOperand(0).getReg(), BoolSrc, Zero); |
| MI.eraseFromParent(); |
| return; |
| } |
| case SplitTo32: { |
| auto Op1 = B.buildUnmerge(VgprRB_S32, MI.getOperand(1).getReg()); |
| auto Op2 = B.buildUnmerge(VgprRB_S32, MI.getOperand(2).getReg()); |
| unsigned Opc = MI.getOpcode(); |
| auto Lo = B.buildInstr(Opc, {VgprRB_S32}, {Op1.getReg(0), Op2.getReg(0)}); |
| auto Hi = B.buildInstr(Opc, {VgprRB_S32}, {Op1.getReg(1), Op2.getReg(1)}); |
| B.buildMergeLikeInstr(MI.getOperand(0).getReg(), {Lo, Hi}); |
| MI.eraseFromParent(); |
| break; |
| } |
| case SplitLoad: { |
| LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); |
| unsigned Size = DstTy.getSizeInBits(); |
| // Even split to 128-bit loads |
| if (Size > 128) { |
| LLT B128; |
| if (DstTy.isVector()) { |
| LLT EltTy = DstTy.getElementType(); |
| B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy); |
| } else { |
| B128 = LLT::scalar(128); |
| } |
| if (Size / 128 == 2) |
| splitLoad(MI, {B128, B128}); |
| else if (Size / 128 == 4) |
| splitLoad(MI, {B128, B128, B128, B128}); |
| else { |
| LLVM_DEBUG(dbgs() << "MI: "; MI.dump();); |
| llvm_unreachable("SplitLoad type not supported for MI"); |
| } |
| } |
| // 64 and 32 bit load |
| else if (DstTy == S96) |
| splitLoad(MI, {S64, S32}, S32); |
| else if (DstTy == V3S32) |
| splitLoad(MI, {V2S32, S32}, S32); |
| else if (DstTy == V6S16) |
| splitLoad(MI, {V4S16, V2S16}, V2S16); |
| else { |
| LLVM_DEBUG(dbgs() << "MI: "; MI.dump();); |
| llvm_unreachable("SplitLoad type not supported for MI"); |
| } |
| break; |
| } |
| case WidenLoad: { |
| LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); |
| if (DstTy == S96) |
| widenLoad(MI, S128); |
| else if (DstTy == V3S32) |
| widenLoad(MI, V4S32, S32); |
| else if (DstTy == V6S16) |
| widenLoad(MI, V8S16, V2S16); |
| else { |
| LLVM_DEBUG(dbgs() << "MI: "; MI.dump();); |
| llvm_unreachable("WidenLoad type not supported for MI"); |
| } |
| break; |
| } |
| } |
| |
| // TODO: executeInWaterfallLoop(... WaterfallSgprs) |
| } |
| |
| LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) { |
| switch (ID) { |
| case Vcc: |
| case UniInVcc: |
| return LLT::scalar(1); |
| case Sgpr16: |
| return LLT::scalar(16); |
| case Sgpr32: |
| case Sgpr32Trunc: |
| case Sgpr32AExt: |
| case Sgpr32AExtBoolInReg: |
| case Sgpr32SExt: |
| case UniInVgprS32: |
| case Vgpr32: |
| return LLT::scalar(32); |
| case Sgpr64: |
| case Vgpr64: |
| return LLT::scalar(64); |
| case VgprP0: |
| return LLT::pointer(0, 64); |
| case SgprP1: |
| case VgprP1: |
| return LLT::pointer(1, 64); |
| case SgprP3: |
| case VgprP3: |
| return LLT::pointer(3, 32); |
| case SgprP4: |
| case VgprP4: |
| return LLT::pointer(4, 64); |
| case SgprP5: |
| case VgprP5: |
| return LLT::pointer(5, 32); |
| case SgprV4S32: |
| case VgprV4S32: |
| case UniInVgprV4S32: |
| return LLT::fixed_vector(4, 32); |
| default: |
| return LLT(); |
| } |
| } |
| |
| LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) { |
| switch (ID) { |
| case SgprB32: |
| case VgprB32: |
| case UniInVgprB32: |
| if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) || |
| Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) || |
| Ty == LLT::pointer(6, 32)) |
| return Ty; |
| return LLT(); |
| case SgprB64: |
| case VgprB64: |
| case UniInVgprB64: |
| if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) || |
| Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) || |
| Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64)) |
| return Ty; |
| return LLT(); |
| case SgprB96: |
| case VgprB96: |
| case UniInVgprB96: |
| if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) || |
| Ty == LLT::fixed_vector(6, 16)) |
| return Ty; |
| return LLT(); |
| case SgprB128: |
| case VgprB128: |
| case UniInVgprB128: |
| if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) || |
| Ty == LLT::fixed_vector(2, 64)) |
| return Ty; |
| return LLT(); |
| case SgprB256: |
| case VgprB256: |
| case UniInVgprB256: |
| if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) || |
| Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16)) |
| return Ty; |
| return LLT(); |
| case SgprB512: |
| case VgprB512: |
| case UniInVgprB512: |
| if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) || |
| Ty == LLT::fixed_vector(8, 64)) |
| return Ty; |
| return LLT(); |
| default: |
| return LLT(); |
| } |
| } |
| |
| const RegisterBank * |
| RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) { |
| switch (ID) { |
| case Vcc: |
| return VccRB; |
| case Sgpr16: |
| case Sgpr32: |
| case Sgpr64: |
| case SgprP1: |
| case SgprP3: |
| case SgprP4: |
| case SgprP5: |
| case SgprV4S32: |
| case SgprB32: |
| case SgprB64: |
| case SgprB96: |
| case SgprB128: |
| case SgprB256: |
| case SgprB512: |
| case UniInVcc: |
| case UniInVgprS32: |
| case UniInVgprV4S32: |
| case UniInVgprB32: |
| case UniInVgprB64: |
| case UniInVgprB96: |
| case UniInVgprB128: |
| case UniInVgprB256: |
| case UniInVgprB512: |
| case Sgpr32Trunc: |
| case Sgpr32AExt: |
| case Sgpr32AExtBoolInReg: |
| case Sgpr32SExt: |
| return SgprRB; |
| case Vgpr32: |
| case Vgpr64: |
| case VgprP0: |
| case VgprP1: |
| case VgprP3: |
| case VgprP4: |
| case VgprP5: |
| case VgprV4S32: |
| case VgprB32: |
| case VgprB64: |
| case VgprB96: |
| case VgprB128: |
| case VgprB256: |
| case VgprB512: |
| return VgprRB; |
| default: |
| return nullptr; |
| } |
| } |
| |
| void RegBankLegalizeHelper::applyMappingDst( |
| MachineInstr &MI, unsigned &OpIdx, |
| const SmallVectorImpl<RegBankLLTMappingApplyID> &MethodIDs) { |
| // Defs start from operand 0 |
| for (; OpIdx < MethodIDs.size(); ++OpIdx) { |
| if (MethodIDs[OpIdx] == None) |
| continue; |
| MachineOperand &Op = MI.getOperand(OpIdx); |
| Register Reg = Op.getReg(); |
| LLT Ty = MRI.getType(Reg); |
| [[maybe_unused]] const RegisterBank *RB = MRI.getRegBank(Reg); |
| |
| switch (MethodIDs[OpIdx]) { |
| // vcc, sgpr and vgpr scalars, pointers and vectors |
| case Vcc: |
| case Sgpr16: |
| case Sgpr32: |
| case Sgpr64: |
| case SgprP1: |
| case SgprP3: |
| case SgprP4: |
| case SgprP5: |
| case SgprV4S32: |
| case Vgpr32: |
| case Vgpr64: |
| case VgprP0: |
| case VgprP1: |
| case VgprP3: |
| case VgprP4: |
| case VgprP5: |
| case VgprV4S32: { |
| assert(Ty == getTyFromID(MethodIDs[OpIdx])); |
| assert(RB == getRegBankFromID(MethodIDs[OpIdx])); |
| break; |
| } |
| // sgpr and vgpr B-types |
| case SgprB32: |
| case SgprB64: |
| case SgprB96: |
| case SgprB128: |
| case SgprB256: |
| case SgprB512: |
| case VgprB32: |
| case VgprB64: |
| case VgprB96: |
| case VgprB128: |
| case VgprB256: |
| case VgprB512: { |
| assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty)); |
| assert(RB == getRegBankFromID(MethodIDs[OpIdx])); |
| break; |
| } |
| // uniform in vcc/vgpr: scalars, vectors and B-types |
| case UniInVcc: { |
| assert(Ty == S1); |
| assert(RB == SgprRB); |
| Register NewDst = MRI.createVirtualRegister(VccRB_S1); |
| Op.setReg(NewDst); |
| auto CopyS32_Vcc = |
| B.buildInstr(AMDGPU::G_AMDGPU_COPY_SCC_VCC, {SgprRB_S32}, {NewDst}); |
| B.buildTrunc(Reg, CopyS32_Vcc); |
| break; |
| } |
| case UniInVgprS32: |
| case UniInVgprV4S32: { |
| assert(Ty == getTyFromID(MethodIDs[OpIdx])); |
| assert(RB == SgprRB); |
| Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty}); |
| Op.setReg(NewVgprDst); |
| buildReadAnyLane(B, Reg, NewVgprDst, RBI); |
| break; |
| } |
| case UniInVgprB32: |
| case UniInVgprB64: |
| case UniInVgprB96: |
| case UniInVgprB128: |
| case UniInVgprB256: |
| case UniInVgprB512: { |
| assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty)); |
| assert(RB == SgprRB); |
| Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty}); |
| Op.setReg(NewVgprDst); |
| AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI); |
| break; |
| } |
| // sgpr trunc |
| case Sgpr32Trunc: { |
| assert(Ty.getSizeInBits() < 32); |
| assert(RB == SgprRB); |
| Register NewDst = MRI.createVirtualRegister(SgprRB_S32); |
| Op.setReg(NewDst); |
| B.buildTrunc(Reg, NewDst); |
| break; |
| } |
| case InvalidMapping: { |
| LLVM_DEBUG(dbgs() << "Instruction with Invalid mapping: "; MI.dump();); |
| llvm_unreachable("missing fast rule for MI"); |
| } |
| default: |
| llvm_unreachable("ID not supported"); |
| } |
| } |
| } |
| |
| void RegBankLegalizeHelper::applyMappingSrc( |
| MachineInstr &MI, unsigned &OpIdx, |
| const SmallVectorImpl<RegBankLLTMappingApplyID> &MethodIDs, |
| SmallSet<Register, 4> &SgprWaterfallOperandRegs) { |
| for (unsigned i = 0; i < MethodIDs.size(); ++OpIdx, ++i) { |
| if (MethodIDs[i] == None || MethodIDs[i] == IntrId || MethodIDs[i] == Imm) |
| continue; |
| |
| MachineOperand &Op = MI.getOperand(OpIdx); |
| Register Reg = Op.getReg(); |
| LLT Ty = MRI.getType(Reg); |
| const RegisterBank *RB = MRI.getRegBank(Reg); |
| |
| switch (MethodIDs[i]) { |
| case Vcc: { |
| assert(Ty == S1); |
| assert(RB == VccRB || RB == SgprRB); |
| if (RB == SgprRB) { |
| auto Aext = B.buildAnyExt(SgprRB_S32, Reg); |
| auto CopyVcc_Scc = |
| B.buildInstr(AMDGPU::G_AMDGPU_COPY_VCC_SCC, {VccRB_S1}, {Aext}); |
| Op.setReg(CopyVcc_Scc.getReg(0)); |
| } |
| break; |
| } |
| // sgpr scalars, pointers and vectors |
| case Sgpr16: |
| case Sgpr32: |
| case Sgpr64: |
| case SgprP1: |
| case SgprP3: |
| case SgprP4: |
| case SgprP5: |
| case SgprV4S32: { |
| assert(Ty == getTyFromID(MethodIDs[i])); |
| assert(RB == getRegBankFromID(MethodIDs[i])); |
| break; |
| } |
| // sgpr B-types |
| case SgprB32: |
| case SgprB64: |
| case SgprB96: |
| case SgprB128: |
| case SgprB256: |
| case SgprB512: { |
| assert(Ty == getBTyFromID(MethodIDs[i], Ty)); |
| assert(RB == getRegBankFromID(MethodIDs[i])); |
| break; |
| } |
| // vgpr scalars, pointers and vectors |
| case Vgpr32: |
| case Vgpr64: |
| case VgprP0: |
| case VgprP1: |
| case VgprP3: |
| case VgprP4: |
| case VgprP5: |
| case VgprV4S32: { |
| assert(Ty == getTyFromID(MethodIDs[i])); |
| if (RB != VgprRB) { |
| auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg); |
| Op.setReg(CopyToVgpr.getReg(0)); |
| } |
| break; |
| } |
| // vgpr B-types |
| case VgprB32: |
| case VgprB64: |
| case VgprB96: |
| case VgprB128: |
| case VgprB256: |
| case VgprB512: { |
| assert(Ty == getBTyFromID(MethodIDs[i], Ty)); |
| if (RB != VgprRB) { |
| auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg); |
| Op.setReg(CopyToVgpr.getReg(0)); |
| } |
| break; |
| } |
| // sgpr and vgpr scalars with extend |
| case Sgpr32AExt: { |
| // Note: this ext allows S1, and it is meant to be combined away. |
| assert(Ty.getSizeInBits() < 32); |
| assert(RB == SgprRB); |
| auto Aext = B.buildAnyExt(SgprRB_S32, Reg); |
| Op.setReg(Aext.getReg(0)); |
| break; |
| } |
| case Sgpr32AExtBoolInReg: { |
| // Note: this ext allows S1, and it is meant to be combined away. |
| assert(Ty.getSizeInBits() == 1); |
| assert(RB == SgprRB); |
| auto Aext = B.buildAnyExt(SgprRB_S32, Reg); |
| // Zext SgprS1 is not legal, this instruction is most of times meant to be |
| // combined away in RB combiner, so do not make AND with 1. |
| auto Cst1 = B.buildConstant(SgprRB_S32, 1); |
| auto BoolInReg = B.buildAnd(SgprRB_S32, Aext, Cst1); |
| Op.setReg(BoolInReg.getReg(0)); |
| break; |
| } |
| case Sgpr32SExt: { |
| assert(1 < Ty.getSizeInBits() && Ty.getSizeInBits() < 32); |
| assert(RB == SgprRB); |
| auto Sext = B.buildSExt(SgprRB_S32, Reg); |
| Op.setReg(Sext.getReg(0)); |
| break; |
| } |
| default: |
| llvm_unreachable("ID not supported"); |
| } |
| } |
| } |
| |
| void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) { |
| Register Dst = MI.getOperand(0).getReg(); |
| LLT Ty = MRI.getType(Dst); |
| |
| if (Ty == LLT::scalar(1) && MUI.isUniform(Dst)) { |
| B.setInsertPt(*MI.getParent(), MI.getParent()->getFirstNonPHI()); |
| |
| Register NewDst = MRI.createVirtualRegister(SgprRB_S32); |
| MI.getOperand(0).setReg(NewDst); |
| B.buildTrunc(Dst, NewDst); |
| |
| for (unsigned i = 1; i < MI.getNumOperands(); i += 2) { |
| Register UseReg = MI.getOperand(i).getReg(); |
| |
| auto DefMI = MRI.getVRegDef(UseReg)->getIterator(); |
| MachineBasicBlock *DefMBB = DefMI->getParent(); |
| |
| B.setInsertPt(*DefMBB, DefMBB->SkipPHIsAndLabels(std::next(DefMI))); |
| |
| auto NewUse = B.buildAnyExt(SgprRB_S32, UseReg); |
| MI.getOperand(i).setReg(NewUse.getReg(0)); |
| } |
| |
| return; |
| } |
| |
| // ALL divergent i1 phis should be already lowered and inst-selected into PHI |
| // with sgpr reg class and S1 LLT. |
| // Note: this includes divergent phis that don't require lowering. |
| if (Ty == LLT::scalar(1) && MUI.isDivergent(Dst)) { |
| LLVM_DEBUG(dbgs() << "Divergent S1 G_PHI: "; MI.dump();); |
| llvm_unreachable("Make sure to run AMDGPUGlobalISelDivergenceLowering " |
| "before RegBankLegalize to lower lane mask(vcc) phis"); |
| } |
| |
| // We accept all types that can fit in some register class. |
| // Uniform G_PHIs have all sgpr registers. |
| // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr. |
| if (Ty == LLT::scalar(32) || Ty == LLT::pointer(1, 64) || |
| Ty == LLT::pointer(4, 64)) { |
| return; |
| } |
| |
| LLVM_DEBUG(dbgs() << "G_PHI not handled: "; MI.dump();); |
| llvm_unreachable("type not supported"); |
| } |
| |
| [[maybe_unused]] static bool verifyRegBankOnOperands(MachineInstr &MI, |
| const RegisterBank *RB, |
| MachineRegisterInfo &MRI, |
| unsigned StartOpIdx, |
| unsigned EndOpIdx) { |
| for (unsigned i = StartOpIdx; i <= EndOpIdx; ++i) { |
| if (MRI.getRegBankOrNull(MI.getOperand(i).getReg()) != RB) |
| return false; |
| } |
| return true; |
| } |
| |
| void RegBankLegalizeHelper::applyMappingTrivial(MachineInstr &MI) { |
| const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg()); |
| // Put RB on all registers |
| unsigned NumDefs = MI.getNumDefs(); |
| unsigned NumOperands = MI.getNumOperands(); |
| |
| assert(verifyRegBankOnOperands(MI, RB, MRI, 0, NumDefs - 1)); |
| if (RB == SgprRB) |
| assert(verifyRegBankOnOperands(MI, RB, MRI, NumDefs, NumOperands - 1)); |
| |
| if (RB == VgprRB) { |
| B.setInstr(MI); |
| for (unsigned i = NumDefs; i < NumOperands; ++i) { |
| Register Reg = MI.getOperand(i).getReg(); |
| if (MRI.getRegBank(Reg) != RB) { |
| auto Copy = B.buildCopy({VgprRB, MRI.getType(Reg)}, Reg); |
| MI.getOperand(i).setReg(Copy.getReg(0)); |
| } |
| } |
| } |
| } |