[GlobalISel][AArch64] Tail call libcalls. (#74929)
This allows libcalls to be emitted as tail calls where possible, using a
method similar to SelectionDAG: the libcall's return type is checked to make
sure it is void or matches the caller's return type and, if so, the backend
(through lowerCall) checks that the tail call is valid for all arguments.
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index f2eb32e..7522353 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -149,7 +149,8 @@
return moreElementsVector(MI, Step.TypeIdx, Step.NewType);
case Custom:
LLVM_DEBUG(dbgs() << ".. Custom legalization\n");
- return LI.legalizeCustom(*this, MI) ? Legalized : UnableToLegalize;
+ return LI.legalizeCustom(*this, MI, LocObserver) ? Legalized
+ : UnableToLegalize;
default:
LLVM_DEBUG(dbgs() << ".. Unable to legalize\n");
return UnableToLegalize;
@@ -567,7 +568,8 @@
/// True if an instruction is in tail position in its caller. Intended for
/// legalizing libcalls as tail calls when possible.
-static bool isLibCallInTailPosition(MachineInstr &MI,
+static bool isLibCallInTailPosition(const CallLowering::ArgInfo &Result,
+ MachineInstr &MI,
const TargetInstrInfo &TII,
MachineRegisterInfo &MRI) {
MachineBasicBlock &MBB = *MI.getParent();
@@ -596,17 +598,12 @@
// RET_ReallyLR implicit $x0
auto Next = next_nodbg(MI.getIterator(), MBB.instr_end());
if (Next != MBB.instr_end() && Next->isCopy()) {
- switch (MI.getOpcode()) {
- default:
- llvm_unreachable("unsupported opcode");
- case TargetOpcode::G_BZERO:
+ if (MI.getOpcode() == TargetOpcode::G_BZERO)
return false;
- case TargetOpcode::G_MEMCPY:
- case TargetOpcode::G_MEMMOVE:
- case TargetOpcode::G_MEMSET:
- break;
- }
+ // For MEMCPY/MEMMOVE/MEMSET this will be the first use (the dst), as the
+ // memcpy/etc. routines return the same parameter. For others it will be the
+ // returned value.
Register VReg = MI.getOperand(0).getReg();
if (!VReg.isVirtual() || VReg != Next->getOperand(1).getReg())
return false;
@@ -622,7 +619,7 @@
if (Ret->getNumImplicitOperands() != 1)
return false;
- if (PReg != Ret->getOperand(0).getReg())
+ if (!Ret->getOperand(0).isReg() || PReg != Ret->getOperand(0).getReg())
return false;
// Skip over the COPY that we just validated.
@@ -639,34 +636,64 @@
llvm::createLibcall(MachineIRBuilder &MIRBuilder, const char *Name,
const CallLowering::ArgInfo &Result,
ArrayRef<CallLowering::ArgInfo> Args,
- const CallingConv::ID CC) {
+ const CallingConv::ID CC, LostDebugLocObserver &LocObserver,
+ MachineInstr *MI) {
auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();
CallLowering::CallLoweringInfo Info;
Info.CallConv = CC;
Info.Callee = MachineOperand::CreateES(Name);
Info.OrigRet = Result;
+ if (MI)
+ Info.IsTailCall =
+ (Result.Ty->isVoidTy() ||
+ Result.Ty == MIRBuilder.getMF().getFunction().getReturnType()) &&
+ isLibCallInTailPosition(Result, *MI, MIRBuilder.getTII(),
+ *MIRBuilder.getMRI());
+
std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
if (!CLI.lowerCall(MIRBuilder, Info))
return LegalizerHelper::UnableToLegalize;
+ if (MI && Info.LoweredTailCall) {
+ assert(Info.IsTailCall && "Lowered tail call when it wasn't a tail call?");
+
+ // Check debug locations before removing the return.
+ LocObserver.checkpoint(true);
+
+ // We must have a return following the call (or debug insts) to get past
+ // isLibCallInTailPosition.
+ do {
+ MachineInstr *Next = MI->getNextNode();
+ assert(Next &&
+ (Next->isCopy() || Next->isReturn() || Next->isDebugInstr()) &&
+ "Expected instr following MI to be return or debug inst?");
+ // We lowered a tail call, so the call is now the return from the block.
+ // Delete the old return.
+ Next->eraseFromParent();
+ } while (MI->getNextNode());
+
+ // We expect to lose the debug location from the return.
+ LocObserver.checkpoint(false);
+ }
return LegalizerHelper::Legalized;
}
LegalizerHelper::LegalizeResult
llvm::createLibcall(MachineIRBuilder &MIRBuilder, RTLIB::Libcall Libcall,
const CallLowering::ArgInfo &Result,
- ArrayRef<CallLowering::ArgInfo> Args) {
+ ArrayRef<CallLowering::ArgInfo> Args,
+ LostDebugLocObserver &LocObserver, MachineInstr *MI) {
auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
const char *Name = TLI.getLibcallName(Libcall);
const CallingConv::ID CC = TLI.getLibcallCallingConv(Libcall);
- return createLibcall(MIRBuilder, Name, Result, Args, CC);
+ return createLibcall(MIRBuilder, Name, Result, Args, CC, LocObserver, MI);
}
// Useful for libcalls where all operands have the same type.
static LegalizerHelper::LegalizeResult
simpleLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, unsigned Size,
- Type *OpType) {
+ Type *OpType, LostDebugLocObserver &LocObserver) {
auto Libcall = getRTLibDesc(MI.getOpcode(), Size);
// FIXME: What does the original arg index mean here?
@@ -674,7 +701,8 @@
for (const MachineOperand &MO : llvm::drop_begin(MI.operands()))
Args.push_back({MO.getReg(), OpType, 0});
return createLibcall(MIRBuilder, Libcall,
- {MI.getOperand(0).getReg(), OpType, 0}, Args);
+ {MI.getOperand(0).getReg(), OpType, 0}, Args,
+ LocObserver, &MI);
}
LegalizerHelper::LegalizeResult
@@ -733,8 +761,9 @@
Info.CallConv = TLI.getLibcallCallingConv(RTLibcall);
Info.Callee = MachineOperand::CreateES(Name);
Info.OrigRet = CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0);
- Info.IsTailCall = MI.getOperand(MI.getNumOperands() - 1).getImm() &&
- isLibCallInTailPosition(MI, MIRBuilder.getTII(), MRI);
+ Info.IsTailCall =
+ MI.getOperand(MI.getNumOperands() - 1).getImm() &&
+ isLibCallInTailPosition(Info.OrigRet, MI, MIRBuilder.getTII(), MRI);
std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
if (!CLI.lowerCall(MIRBuilder, Info))
@@ -789,11 +818,11 @@
static LegalizerHelper::LegalizeResult
conversionLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, Type *ToType,
- Type *FromType) {
+ Type *FromType, LostDebugLocObserver &LocObserver) {
RTLIB::Libcall Libcall = getConvRTLibDesc(MI.getOpcode(), ToType, FromType);
- return createLibcall(MIRBuilder, Libcall,
- {MI.getOperand(0).getReg(), ToType, 0},
- {{MI.getOperand(1).getReg(), FromType, 0}});
+ return createLibcall(
+ MIRBuilder, Libcall, {MI.getOperand(0).getReg(), ToType, 0},
+ {{MI.getOperand(1).getReg(), FromType, 0}}, LocObserver, &MI);
}
static RTLIB::Libcall
@@ -829,7 +858,8 @@
//
LegalizerHelper::LegalizeResult
LegalizerHelper::createGetStateLibcall(MachineIRBuilder &MIRBuilder,
- MachineInstr &MI) {
+ MachineInstr &MI,
+ LostDebugLocObserver &LocObserver) {
const DataLayout &DL = MIRBuilder.getDataLayout();
auto &MF = MIRBuilder.getMF();
auto &MRI = *MIRBuilder.getMRI();
@@ -850,7 +880,8 @@
auto Res =
createLibcall(MIRBuilder, RTLibcall,
CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
- CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}));
+ CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}),
+ LocObserver, nullptr);
if (Res != LegalizerHelper::Legalized)
return Res;
@@ -867,7 +898,8 @@
// content of memory region.
LegalizerHelper::LegalizeResult
LegalizerHelper::createSetStateLibcall(MachineIRBuilder &MIRBuilder,
- MachineInstr &MI) {
+ MachineInstr &MI,
+ LostDebugLocObserver &LocObserver) {
const DataLayout &DL = MIRBuilder.getDataLayout();
auto &MF = MIRBuilder.getMF();
auto &MRI = *MIRBuilder.getMRI();
@@ -892,7 +924,8 @@
RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
return createLibcall(MIRBuilder, RTLibcall,
CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
- CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}));
+ CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}),
+ LocObserver, nullptr);
}
// The function is used to legalize operations that set default environment
@@ -902,7 +935,8 @@
// it is not true, the target must provide custom lowering.
LegalizerHelper::LegalizeResult
LegalizerHelper::createResetStateLibcall(MachineIRBuilder &MIRBuilder,
- MachineInstr &MI) {
+ MachineInstr &MI,
+ LostDebugLocObserver &LocObserver) {
const DataLayout &DL = MIRBuilder.getDataLayout();
auto &MF = MIRBuilder.getMF();
auto &Ctx = MF.getFunction().getContext();
@@ -919,7 +953,8 @@
RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
return createLibcall(MIRBuilder, RTLibcall,
CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
- CallLowering::ArgInfo({ Dest.getReg(), StatePtrTy, 0}));
+ CallLowering::ArgInfo({Dest.getReg(), StatePtrTy, 0}),
+ LocObserver, &MI);
}
LegalizerHelper::LegalizeResult
@@ -938,7 +973,7 @@
LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
unsigned Size = LLTy.getSizeInBits();
Type *HLTy = IntegerType::get(Ctx, Size);
- auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy);
+ auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy, LocObserver);
if (Status != Legalized)
return Status;
break;
@@ -974,7 +1009,7 @@
LLVM_DEBUG(dbgs() << "No libcall available for type " << LLTy << ".\n");
return UnableToLegalize;
}
- auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy);
+ auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy, LocObserver);
if (Status != Legalized)
return Status;
break;
@@ -985,7 +1020,8 @@
Type *ToTy = getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(0).getReg()));
if (!FromTy || !ToTy)
return UnableToLegalize;
- LegalizeResult Status = conversionLibcall(MI, MIRBuilder, ToTy, FromTy );
+ LegalizeResult Status =
+ conversionLibcall(MI, MIRBuilder, ToTy, FromTy, LocObserver);
if (Status != Legalized)
return Status;
break;
@@ -1000,7 +1036,8 @@
LegalizeResult Status = conversionLibcall(
MI, MIRBuilder,
ToSize == 32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx),
- FromSize == 64 ? Type::getDoubleTy(Ctx) : Type::getFloatTy(Ctx));
+ FromSize == 64 ? Type::getDoubleTy(Ctx) : Type::getFloatTy(Ctx),
+ LocObserver);
if (Status != Legalized)
return Status;
break;
@@ -1015,7 +1052,8 @@
LegalizeResult Status = conversionLibcall(
MI, MIRBuilder,
ToSize == 64 ? Type::getDoubleTy(Ctx) : Type::getFloatTy(Ctx),
- FromSize == 32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx));
+ FromSize == 32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx),
+ LocObserver);
if (Status != Legalized)
return Status;
break;
@@ -1032,19 +1070,20 @@
return Result;
}
case TargetOpcode::G_GET_FPMODE: {
- LegalizeResult Result = createGetStateLibcall(MIRBuilder, MI);
+ LegalizeResult Result = createGetStateLibcall(MIRBuilder, MI, LocObserver);
if (Result != Legalized)
return Result;
break;
}
case TargetOpcode::G_SET_FPMODE: {
- LegalizeResult Result = createSetStateLibcall(MIRBuilder, MI);
+ LegalizeResult Result = createSetStateLibcall(MIRBuilder, MI, LocObserver);
if (Result != Legalized)
return Result;
break;
}
case TargetOpcode::G_RESET_FPMODE: {
- LegalizeResult Result = createResetStateLibcall(MIRBuilder, MI);
+ LegalizeResult Result =
+ createResetStateLibcall(MIRBuilder, MI, LocObserver);
if (Result != Legalized)
return Result;
break;