//===-- AMDGPURegBankLegalizeHelper.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Implements actual lowering algorithms for each ID that can be used in
/// Rule.OperandMapping. Similar to the LegalizerHelper, but with register banks.
//
//===----------------------------------------------------------------------===//
#include "AMDGPURegBankLegalizeHelper.h"
#include "AMDGPUGlobalISelUtils.h"
#include "AMDGPUInstrInfo.h"
#include "AMDGPURegisterBankInfo.h"
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
#define DEBUG_TYPE "amdgpu-regbanklegalize"
using namespace llvm;
using namespace AMDGPU;
RegBankLegalizeHelper::RegBankLegalizeHelper(
MachineIRBuilder &B, const MachineUniformityInfo &MUI,
const RegisterBankInfo &RBI, const RegBankLegalizeRules &RBLRules)
: B(B), MRI(*B.getMRI()), MUI(MUI), RBI(RBI), RBLRules(RBLRules),
SgprRB(&RBI.getRegBank(AMDGPU::SGPRRegBankID)),
VgprRB(&RBI.getRegBank(AMDGPU::VGPRRegBankID)),
VccRB(&RBI.getRegBank(AMDGPU::VCCRegBankID)) {}
void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
const SetOfRulesForOpcode &RuleSet = RBLRules.getRulesForOpc(MI);
const RegBankLLTMapping &Mapping = RuleSet.findMappingForMI(MI, MRI, MUI);
SmallSet<Register, 4> WaterfallSgprs;
unsigned OpIdx = 0;
if (Mapping.DstOpMapping.size() > 0) {
B.setInsertPt(*MI.getParent(), std::next(MI.getIterator()));
applyMappingDst(MI, OpIdx, Mapping.DstOpMapping);
}
if (Mapping.SrcOpMapping.size() > 0) {
B.setInstr(MI);
applyMappingSrc(MI, OpIdx, Mapping.SrcOpMapping, WaterfallSgprs);
}
lower(MI, Mapping, WaterfallSgprs);
}
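
// Split a load into the parts listed in LLTBreakdown. Each part is loaded at
// an increasing byte offset from the base pointer and the parts are then
// merged back into the original destination register.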
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
MachineFunction &MF = B.getMF();
assert(MI.getNumMemOperands() == 1);
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
Register Dst = MI.getOperand(0).getReg();
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
Register Base = MI.getOperand(1).getReg();
LLT PtrTy = MRI.getType(Base);
const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
SmallVector<Register, 4> LoadPartRegs;
unsigned ByteOffset = 0;
for (LLT PartTy : LLTBreakdown) {
Register BasePlusOffset;
if (ByteOffset == 0) {
BasePlusOffset = Base;
} else {
auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
}
auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
LoadPartRegs.push_back(LoadPart.getReg(0));
ByteOffset += PartTy.getSizeInBytes();
}
if (!MergeTy.isValid()) {
// All load parts are the same size; concatenate or merge them directly.
B.buildMergeLikeInstr(Dst, LoadPartRegs);
} else {
// Load parts are not all the same size; unmerge them into smaller pieces of
// MergeTy, then merge those pieces into Dst.
SmallVector<Register, 4> MergeTyParts;
for (Register Reg : LoadPartRegs) {
if (MRI.getType(Reg) == MergeTy) {
MergeTyParts.push_back(Reg);
} else {
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
MergeTyParts.push_back(Unmerge.getReg(i));
}
}
B.buildMergeLikeInstr(Dst, MergeTyParts);
}
MI.eraseFromParent();
}
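
// Load WideTy instead of the original type, then truncate (scalar) or
// unmerge-and-remerge (vector) the result into the original destination.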
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
LLT MergeTy) {
MachineFunction &MF = B.getMF();
assert(MI.getNumMemOperands() == 1);
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
Register Dst = MI.getOperand(0).getReg();
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
Register Base = MI.getOperand(1).getReg();
MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);
if (WideTy.isScalar()) {
B.buildTrunc(Dst, WideLoad);
} else {
SmallVector<Register, 4> MergeTyParts;
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);
LLT DstTy = MRI.getType(Dst);
unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
for (unsigned i = 0; i < NumElts; ++i) {
MergeTyParts.push_back(Unmerge.getReg(i));
}
B.buildMergeLikeInstr(Dst, MergeTyParts);
}
MI.eraseFromParent();
}
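
// Apply the lowering method selected by the mapping. DoNotLower leaves the
// instruction unchanged apart from the operand rewrites already applied.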
void RegBankLegalizeHelper::lower(MachineInstr &MI,
const RegBankLLTMapping &Mapping,
SmallSet<Register, 4> &WaterfallSgprs) {
switch (Mapping.LoweringMethod) {
case DoNotLower:
return;
case VccExtToSel: {
LLT Ty = MRI.getType(MI.getOperand(0).getReg());
Register Src = MI.getOperand(1).getReg();
unsigned Opc = MI.getOpcode();
if (Ty == S32 || Ty == S16) {
auto True = B.buildConstant({VgprRB, Ty}, Opc == G_SEXT ? -1 : 1);
auto False = B.buildConstant({VgprRB, Ty}, 0);
B.buildSelect(MI.getOperand(0).getReg(), Src, True, False);
}
if (Ty == S64) {
auto True = B.buildConstant({VgprRB, S32}, Opc == G_SEXT ? -1 : 1);
auto False = B.buildConstant({VgprRB, S32}, 0);
auto Sel = B.buildSelect({VgprRB, S32}, Src, True, False);
B.buildMergeValues(
MI.getOperand(0).getReg(),
{Sel.getReg(0), Opc == G_SEXT ? Sel.getReg(0) : False.getReg(0)});
}
MI.eraseFromParent();
return;
}
case UniExtToSel: {
LLT Ty = MRI.getType(MI.getOperand(0).getReg());
auto True = B.buildConstant({SgprRB, Ty},
MI.getOpcode() == AMDGPU::G_SEXT ? -1 : 1);
auto False = B.buildConstant({SgprRB, Ty}, 0);
// The input to G_{Z|S}EXT is a 'Legalizer legal' S1, most commonly a compare.
// Build a select here; the S1 condition was already any-extended to S32 and
// ANDed with 1 to clear the high bits by Sgpr32AExtBoolInReg.
B.buildSelect(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(), True,
False);
MI.eraseFromParent();
return;
}
case Ext32To64: {
const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
MachineInstrBuilder Hi;
if (MI.getOpcode() == AMDGPU::G_ZEXT) {
Hi = B.buildConstant({RB, S32}, 0);
} else {
// Replicate sign bit from 32-bit extended part.
auto ShiftAmt = B.buildConstant({RB, S32}, 31);
Hi = B.buildAShr({RB, S32}, MI.getOperand(1).getReg(), ShiftAmt);
}
B.buildMergeLikeInstr(MI.getOperand(0).getReg(),
{MI.getOperand(1).getReg(), Hi});
MI.eraseFromParent();
return;
}
case UniCstExt: {
uint64_t ConstVal = MI.getOperand(1).getCImm()->getZExtValue();
B.buildConstant(MI.getOperand(0).getReg(), ConstVal);
MI.eraseFromParent();
return;
}
case VgprToVccCopy: {
Register Src = MI.getOperand(1).getReg();
LLT Ty = MRI.getType(Src);
// Take the lowest bit from each lane and put it in a lane mask.
// Lower via compare, but clear the high bits first since the compare looks at
// all bits in the register.
Register BoolSrc = MRI.createVirtualRegister({VgprRB, Ty});
if (Ty == S64) {
auto Src64 = B.buildUnmerge(VgprRB_S32, Src);
auto One = B.buildConstant(VgprRB_S32, 1);
auto AndLo = B.buildAnd(VgprRB_S32, Src64.getReg(0), One);
auto Zero = B.buildConstant(VgprRB_S32, 0);
auto AndHi = B.buildAnd(VgprRB_S32, Src64.getReg(1), Zero);
B.buildMergeLikeInstr(BoolSrc, {AndLo, AndHi});
} else {
assert(Ty == S32 || Ty == S16);
auto One = B.buildConstant({VgprRB, Ty}, 1);
B.buildAnd(BoolSrc, Src, One);
}
auto Zero = B.buildConstant({VgprRB, Ty}, 0);
B.buildICmp(CmpInst::ICMP_NE, MI.getOperand(0).getReg(), BoolSrc, Zero);
MI.eraseFromParent();
return;
}
case SplitTo32: {
auto Op1 = B.buildUnmerge(VgprRB_S32, MI.getOperand(1).getReg());
auto Op2 = B.buildUnmerge(VgprRB_S32, MI.getOperand(2).getReg());
unsigned Opc = MI.getOpcode();
auto Lo = B.buildInstr(Opc, {VgprRB_S32}, {Op1.getReg(0), Op2.getReg(0)});
auto Hi = B.buildInstr(Opc, {VgprRB_S32}, {Op1.getReg(1), Op2.getReg(1)});
B.buildMergeLikeInstr(MI.getOperand(0).getReg(), {Lo, Hi});
MI.eraseFromParent();
break;
}
case SplitLoad: {
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
unsigned Size = DstTy.getSizeInBits();
// Split evenly into 128-bit loads.
if (Size > 128) {
LLT B128;
if (DstTy.isVector()) {
LLT EltTy = DstTy.getElementType();
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
} else {
B128 = LLT::scalar(128);
}
if (Size / 128 == 2)
splitLoad(MI, {B128, B128});
else if (Size / 128 == 4)
splitLoad(MI, {B128, B128, B128, B128});
else {
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
llvm_unreachable("SplitLoad type not supported for MI");
}
}
// Split into 64-bit and 32-bit parts.
else if (DstTy == S96)
splitLoad(MI, {S64, S32}, S32);
else if (DstTy == V3S32)
splitLoad(MI, {V2S32, S32}, S32);
else if (DstTy == V6S16)
splitLoad(MI, {V4S16, V2S16}, V2S16);
else {
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
llvm_unreachable("SplitLoad type not supported for MI");
}
break;
}
case WidenLoad: {
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
if (DstTy == S96)
widenLoad(MI, S128);
else if (DstTy == V3S32)
widenLoad(MI, V4S32, S32);
else if (DstTy == V6S16)
widenLoad(MI, V8S16, V2S16);
else {
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
llvm_unreachable("WidenLoad type not supported for MI");
}
break;
}
}
// TODO: executeInWaterfallLoop(... WaterfallSgprs)
}
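
// Return the LLT implied by a mapping ID, or an invalid LLT for IDs that do
// not fix a single type.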
LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
switch (ID) {
case Vcc:
case UniInVcc:
return LLT::scalar(1);
case Sgpr16:
return LLT::scalar(16);
case Sgpr32:
case Sgpr32Trunc:
case Sgpr32AExt:
case Sgpr32AExtBoolInReg:
case Sgpr32SExt:
case UniInVgprS32:
case Vgpr32:
return LLT::scalar(32);
case Sgpr64:
case Vgpr64:
return LLT::scalar(64);
case VgprP0:
return LLT::pointer(0, 64);
case SgprP1:
case VgprP1:
return LLT::pointer(1, 64);
case SgprP3:
case VgprP3:
return LLT::pointer(3, 32);
case SgprP4:
case VgprP4:
return LLT::pointer(4, 64);
case SgprP5:
case VgprP5:
return LLT::pointer(5, 32);
case SgprV4S32:
case VgprV4S32:
case UniInVgprV4S32:
return LLT::fixed_vector(4, 32);
default:
return LLT();
}
}
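
// B-type IDs only fix the size in bits. Return Ty if it is one of the types
// allowed for the given ID, otherwise return an invalid LLT.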
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
switch (ID) {
case SgprB32:
case VgprB32:
case UniInVgprB32:
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
Ty == LLT::pointer(6, 32))
return Ty;
return LLT();
case SgprB64:
case VgprB64:
case UniInVgprB64:
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
return Ty;
return LLT();
case SgprB96:
case VgprB96:
case UniInVgprB96:
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
Ty == LLT::fixed_vector(6, 16))
return Ty;
return LLT();
case SgprB128:
case VgprB128:
case UniInVgprB128:
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
Ty == LLT::fixed_vector(2, 64))
return Ty;
return LLT();
case SgprB256:
case VgprB256:
case UniInVgprB256:
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
return Ty;
return LLT();
case SgprB512:
case VgprB512:
case UniInVgprB512:
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
Ty == LLT::fixed_vector(8, 64))
return Ty;
return LLT();
default:
return LLT();
}
}
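
// Return the register bank implied by a mapping ID, or nullptr for IDs that do
// not map to a single bank.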
const RegisterBank *
RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
switch (ID) {
case Vcc:
return VccRB;
case Sgpr16:
case Sgpr32:
case Sgpr64:
case SgprP1:
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV4S32:
case SgprB32:
case SgprB64:
case SgprB96:
case SgprB128:
case SgprB256:
case SgprB512:
case UniInVcc:
case UniInVgprS32:
case UniInVgprV4S32:
case UniInVgprB32:
case UniInVgprB64:
case UniInVgprB96:
case UniInVgprB128:
case UniInVgprB256:
case UniInVgprB512:
case Sgpr32Trunc:
case Sgpr32AExt:
case Sgpr32AExtBoolInReg:
case Sgpr32SExt:
return SgprRB;
case Vgpr32:
case Vgpr64:
case VgprP0:
case VgprP1:
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV4S32:
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512:
return VgprRB;
default:
return nullptr;
}
}
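
// Rewrite def operands to the mapped bank and type. Where a new def is
// created, instructions that write the result back to the original register
// are inserted after MI.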
void RegBankLegalizeHelper::applyMappingDst(
MachineInstr &MI, unsigned &OpIdx,
const SmallVectorImpl<RegBankLLTMappingApplyID> &MethodIDs) {
// Defs start from operand 0
for (; OpIdx < MethodIDs.size(); ++OpIdx) {
if (MethodIDs[OpIdx] == None)
continue;
MachineOperand &Op = MI.getOperand(OpIdx);
Register Reg = Op.getReg();
LLT Ty = MRI.getType(Reg);
[[maybe_unused]] const RegisterBank *RB = MRI.getRegBank(Reg);
switch (MethodIDs[OpIdx]) {
// vcc, sgpr and vgpr scalars, pointers and vectors
case Vcc:
case Sgpr16:
case Sgpr32:
case Sgpr64:
case SgprP1:
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV4S32:
case Vgpr32:
case Vgpr64:
case VgprP0:
case VgprP1:
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
break;
}
// sgpr and vgpr B-types
case SgprB32:
case SgprB64:
case SgprB96:
case SgprB128:
case SgprB256:
case SgprB512:
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512: {
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
break;
}
// uniform in vcc/vgpr: scalars, vectors and B-types
case UniInVcc: {
assert(Ty == S1);
assert(RB == SgprRB);
Register NewDst = MRI.createVirtualRegister(VccRB_S1);
Op.setReg(NewDst);
auto CopyS32_Vcc =
B.buildInstr(AMDGPU::G_AMDGPU_COPY_SCC_VCC, {SgprRB_S32}, {NewDst});
B.buildTrunc(Reg, CopyS32_Vcc);
break;
}
case UniInVgprS32:
case UniInVgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
assert(RB == SgprRB);
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
Op.setReg(NewVgprDst);
buildReadAnyLane(B, Reg, NewVgprDst, RBI);
break;
}
case UniInVgprB32:
case UniInVgprB64:
case UniInVgprB96:
case UniInVgprB128:
case UniInVgprB256:
case UniInVgprB512: {
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
assert(RB == SgprRB);
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
Op.setReg(NewVgprDst);
AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
break;
}
// sgpr trunc
case Sgpr32Trunc: {
assert(Ty.getSizeInBits() < 32);
assert(RB == SgprRB);
Register NewDst = MRI.createVirtualRegister(SgprRB_S32);
Op.setReg(NewDst);
B.buildTrunc(Reg, NewDst);
break;
}
case InvalidMapping: {
LLVM_DEBUG(dbgs() << "Instruction with Invalid mapping: "; MI.dump(););
llvm_unreachable("missing fast rule for MI");
}
default:
llvm_unreachable("ID not supported");
}
}
}
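
// Rewrite use operands to the mapped bank and type, inserting copies and
// extends in front of MI where needed.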
void RegBankLegalizeHelper::applyMappingSrc(
MachineInstr &MI, unsigned &OpIdx,
const SmallVectorImpl<RegBankLLTMappingApplyID> &MethodIDs,
SmallSet<Register, 4> &SgprWaterfallOperandRegs) {
for (unsigned i = 0; i < MethodIDs.size(); ++OpIdx, ++i) {
if (MethodIDs[i] == None || MethodIDs[i] == IntrId || MethodIDs[i] == Imm)
continue;
MachineOperand &Op = MI.getOperand(OpIdx);
Register Reg = Op.getReg();
LLT Ty = MRI.getType(Reg);
const RegisterBank *RB = MRI.getRegBank(Reg);
switch (MethodIDs[i]) {
case Vcc: {
assert(Ty == S1);
assert(RB == VccRB || RB == SgprRB);
if (RB == SgprRB) {
auto Aext = B.buildAnyExt(SgprRB_S32, Reg);
auto CopyVcc_Scc =
B.buildInstr(AMDGPU::G_AMDGPU_COPY_VCC_SCC, {VccRB_S1}, {Aext});
Op.setReg(CopyVcc_Scc.getReg(0));
}
break;
}
// sgpr scalars, pointers and vectors
case Sgpr16:
case Sgpr32:
case Sgpr64:
case SgprP1:
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
assert(RB == getRegBankFromID(MethodIDs[i]));
break;
}
// sgpr B-types
case SgprB32:
case SgprB64:
case SgprB96:
case SgprB128:
case SgprB256:
case SgprB512: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
assert(RB == getRegBankFromID(MethodIDs[i]));
break;
}
// vgpr scalars, pointers and vectors
case Vgpr32:
case Vgpr64:
case VgprP0:
case VgprP1:
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
if (RB != VgprRB) {
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
Op.setReg(CopyToVgpr.getReg(0));
}
break;
}
// vgpr B-types
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
if (RB != VgprRB) {
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
Op.setReg(CopyToVgpr.getReg(0));
}
break;
}
// sgpr and vgpr scalars with extend
case Sgpr32AExt: {
// Note: this ext allows S1, and it is meant to be combined away.
assert(Ty.getSizeInBits() < 32);
assert(RB == SgprRB);
auto Aext = B.buildAnyExt(SgprRB_S32, Reg);
Op.setReg(Aext.getReg(0));
break;
}
case Sgpr32AExtBoolInReg: {
// Note: this ext allows S1, and it is meant to be combined away.
assert(Ty.getSizeInBits() == 1);
assert(RB == SgprRB);
auto Aext = B.buildAnyExt(SgprRB_S32, Reg);
// Zext of an sgpr S1 is not legal; any-extend and AND with 1 to clear the
// high bits instead. The AND is usually combined away in the RB combiner.
auto Cst1 = B.buildConstant(SgprRB_S32, 1);
auto BoolInReg = B.buildAnd(SgprRB_S32, Aext, Cst1);
Op.setReg(BoolInReg.getReg(0));
break;
}
case Sgpr32SExt: {
assert(1 < Ty.getSizeInBits() && Ty.getSizeInBits() < 32);
assert(RB == SgprRB);
auto Sext = B.buildSExt(SgprRB_S32, Reg);
Op.setReg(Sext.getReg(0));
break;
}
default:
llvm_unreachable("ID not supported");
}
}
}
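
// G_PHI is handled separately from the generic mappings: uniform S1 phis are
// widened to S32 on the sgpr bank, other supported types are left unchanged.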
void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
Register Dst = MI.getOperand(0).getReg();
LLT Ty = MRI.getType(Dst);
if (Ty == LLT::scalar(1) && MUI.isUniform(Dst)) {
B.setInsertPt(*MI.getParent(), MI.getParent()->getFirstNonPHI());
Register NewDst = MRI.createVirtualRegister(SgprRB_S32);
MI.getOperand(0).setReg(NewDst);
B.buildTrunc(Dst, NewDst);
for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
Register UseReg = MI.getOperand(i).getReg();
auto DefMI = MRI.getVRegDef(UseReg)->getIterator();
MachineBasicBlock *DefMBB = DefMI->getParent();
B.setInsertPt(*DefMBB, DefMBB->SkipPHIsAndLabels(std::next(DefMI)));
auto NewUse = B.buildAnyExt(SgprRB_S32, UseReg);
MI.getOperand(i).setReg(NewUse.getReg(0));
}
return;
}
// All divergent i1 phis should already be lowered and inst-selected into PHIs
// with an sgpr register class and S1 LLT.
// Note: this includes divergent phis that don't require lowering.
if (Ty == LLT::scalar(1) && MUI.isDivergent(Dst)) {
LLVM_DEBUG(dbgs() << "Divergent S1 G_PHI: "; MI.dump(););
llvm_unreachable("Make sure to run AMDGPUGlobalISelDivergenceLowering "
"before RegBankLegalize to lower lane mask(vcc) phis");
}
// We accept all types that can fit in some register class.
// Uniform G_PHIs have all sgpr registers.
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(1, 64) ||
Ty == LLT::pointer(4, 64)) {
return;
}
LLVM_DEBUG(dbgs() << "G_PHI not handled: "; MI.dump(););
llvm_unreachable("type not supported");
}
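
// Return true if all register operands in [StartOpIdx, EndOpIdx] are assigned
// to the register bank RB.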
[[maybe_unused]] static bool verifyRegBankOnOperands(MachineInstr &MI,
const RegisterBank *RB,
MachineRegisterInfo &MRI,
unsigned StartOpIdx,
unsigned EndOpIdx) {
for (unsigned i = StartOpIdx; i <= EndOpIdx; ++i) {
if (MRI.getRegBankOrNull(MI.getOperand(i).getReg()) != RB)
return false;
}
return true;
}
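
// Trivial mapping: every register keeps the bank of the first def. SGPR
// instructions require all sources to be sgpr already; for VGPR instructions
// any non-vgpr sources are copied to vgpr.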
void RegBankLegalizeHelper::applyMappingTrivial(MachineInstr &MI) {
const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
// All defs are already on RB. SGPR instructions also require all sources on
// RB; for VGPR instructions, copy any non-VGPR sources to VGPR.
unsigned NumDefs = MI.getNumDefs();
unsigned NumOperands = MI.getNumOperands();
assert(verifyRegBankOnOperands(MI, RB, MRI, 0, NumDefs - 1));
if (RB == SgprRB)
assert(verifyRegBankOnOperands(MI, RB, MRI, NumDefs, NumOperands - 1));
if (RB == VgprRB) {
B.setInstr(MI);
for (unsigned i = NumDefs; i < NumOperands; ++i) {
Register Reg = MI.getOperand(i).getReg();
if (MRI.getRegBank(Reg) != RB) {
auto Copy = B.buildCopy({VgprRB, MRI.getType(Reg)}, Reg);
MI.getOperand(i).setReg(Copy.getReg(0));
}
}
}
}