// blob: 4ef2f50a683711b49af3a7a4566afe91cdf6901b [file] [log] [blame]
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Try to reduce a function by inserting new return instructions. Try to insert
// an early return for each instruction value at that point. This requires
// mutating the return type, or finding instructions with a compatible type.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "llvm-reduce"
#include "ReduceValuesToReturn.h"
#include "Delta.h"
#include "Utils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
/// Decide whether \p F may be re-emitted with a non-void return type.
///
/// \returns false when \p F carries a struct-return (sret) attribute or when
/// its calling convention does not support non-void return types; true
/// otherwise.
static bool canUseNonVoidReturnType(const Function &F) {
  // A function with an sret argument has to keep returning void.
  if (F.hasStructRetAttr())
    return false;

  return CallingConv::supportsNonVoidReturnType(F.getCallingConv());
}
/// Decide whether \p Ty can be used as the replacement return type of a
/// function.
///
/// \returns true only if \p Ty is accepted by FunctionType::isValidReturnType,
/// is not the token type, and is a first class type.
static bool isReallyValidReturnType(Type *Ty) {
  if (!FunctionType::isValidReturnType(Ty))
    return false;

  // Token types are rejected explicitly, in addition to the generic check.
  if (Ty->isTokenTy())
    return false;

  return Ty->isFirstClassType();
}
/// Rewrite \p OldF so that it returns \p NewRetValue, switching the function's
/// return type to the type of \p NewRetValue.
///
/// A ret of \p NewRetValue is appended to the block that holds it (the entry
/// block when \p NewRetValue is not an instruction) after the instructions
/// that follow it in that block have been erased. A new function with the
/// updated function type is then created right after \p OldF; it takes over
/// the name, attributes, body and uses of \p OldF.
///
/// \param OldF The function to rewrite. It is erased from its module before
/// this function returns, so the reference must not be used afterwards.
/// \param NewRetValue The value to return; either an instruction inside
/// \p OldF or some other value (callers in this file also pass arguments).
static void rewriteFuncWithReturnType(Function &OldF, Value *NewRetValue) {
Type *NewRetTy = NewRetValue->getType();
FunctionType *OldFuncTy = OldF.getFunctionType();
// Same parameters and variadic-ness as before; only the return type changes.
FunctionType *NewFuncTy =
FunctionType::get(NewRetTy, OldFuncTy->params(), OldFuncTy->isVarArg());
LLVMContext &Ctx = OldF.getContext();
BasicBlock &EntryBB = OldF.getEntryBlock();
// A value that is not an instruction has no parent block, so the new return
// is placed in the entry block instead.
Instruction *NewRetI = dyn_cast<Instruction>(NewRetValue);
BasicBlock *NewRetBlock = NewRetI ? NewRetI->getParent() : &EntryBB;
// Marks where the reverse erasure loop below stops.
// NOTE(review): with EntryBB.end() the loop presumably erases the whole entry
// block; confirm against the getReverse() semantics of the list iterator.
BasicBlock::iterator NewValIt =
NewRetI ? NewRetI->getIterator() : EntryBB.end();
Type *OldRetTy = OldFuncTy->getReturnType();
// Hack up any return values in other blocks, we can't leave them as returning OldRetTy.
if (OldRetTy != NewRetTy) {
for (BasicBlock &OtherRetBB : OldF) {
if (&OtherRetBB != NewRetBlock) {
auto *OrigRI = dyn_cast<ReturnInst>(OtherRetBB.getTerminator());
if (!OrigRI)
continue;
// Swap the old return for one that returns getDefaultValue(NewRetTy).
OrigRI->eraseFromParent();
ReturnInst::Create(Ctx, getDefaultValue(NewRetTy), &OtherRetBB);
}
}
}
// Now prune any CFG edges we have to deal with.
//
// Use KeepOneInputPHIs in case the instruction we are using for the return is
// that phi.
// TODO: Could avoid this with fancier iterator management.
for (BasicBlock *Succ : successors(NewRetBlock))
Succ->removePredecessor(NewRetBlock, /*KeepOneInputPHIs=*/true);
// Now delete the tail of this block, in reverse to delete uses before defs.
for (Instruction &I : make_early_inc_range(
make_range(NewRetBlock->rbegin(), NewValIt.getReverse()))) {
// Any remaining users of the erased instruction are pointed at
// getDefaultValue of its type before it is removed.
Value *Replacement = getDefaultValue(I.getType());
I.replaceAllUsesWith(Replacement);
I.eraseFromParent();
}
// Terminate the truncated block with the new return.
ReturnInst::Create(Ctx, NewRetValue, NewRetBlock);
// TODO: We may be eliminating blocks that were originally unreachable. We
// probably ought to only be pruning blocks that became dead directly as a
// result of our pruning here.
EliminateUnreachableBlocks(OldF);
// Drop the incompatible attributes before we copy over to the new function.
if (OldRetTy != NewRetTy) {
AttributeList AL = OldF.getAttributes();
AttributeMask IncompatibleAttrs =
AttributeFuncs::typeIncompatible(NewRetTy, AL.getRetAttrs());
OldF.removeRetAttrs(IncompatibleAttrs);
}
// Now we need to remove any returned attributes from parameters.
for (Argument &A : OldF.args())
OldF.removeParamAttr(A.getArgNo(), Attribute::Returned);
// Create the replacement function with the new type, same linkage and
// address space, and an empty name for now.
Function *NewF =
Function::Create(NewFuncTy, OldF.getLinkage(), OldF.getAddressSpace(), "",
OldF.getParent());
// Reposition the new function so it sits directly after OldF in the module's
// function list, then hand over the name and (already cleaned) attributes.
NewF->removeFromParent();
OldF.getParent()->getFunctionList().insertAfter(OldF.getIterator(), NewF);
NewF->takeName(&OldF);
NewF->copyAttributesFrom(&OldF);
// Adjust the callsite uses to the new return type. We pre-filtered cases
// where the original call type was incorrectly non-void.
for (User *U : make_early_inc_range(OldF.users())) {
// Only users that call OldF directly are adjusted here; other users are
// handled by the replaceAllUsesWith below.
if (auto *CB = dyn_cast<CallBase>(U);
CB && CB->getCalledOperand() == &OldF) {
if (CB->getType()->isVoidTy()) {
FunctionType *CallType = CB->getFunctionType();
// The callsite may not match the new function type, in an undefined
// behavior way. Only mutate the local return type.
FunctionType *NewCallType = FunctionType::get(
NewRetTy, CallType->params(), CallType->isVarArg());
CB->mutateType(NewRetTy);
CB->setCalledFunction(NewCallType, NewF);
} else {
assert(CB->getType() == NewRetTy &&
"only handle exact return type match with non-void returns");
}
}
}
// Move the body of OldF into NewF and redirect every remaining use of OldF.
NewF->splice(NewF->begin(), &OldF);
OldF.replaceAllUsesWith(NewF);
// Preserve the parameters of OldF.
for (auto Z : zip_first(OldF.args(), NewF->args())) {
Argument &OldArg = std::get<0>(Z);
Argument &NewArg = std::get<1>(Z);
// The moved body still refers to OldF's arguments; retarget those uses and
// carry the argument names over.
OldArg.replaceAllUsesWith(&NewArg);
NewArg.takeName(&OldArg);
}
// OldF is dead at this point and is removed from the module.
OldF.eraseFromParent();
}
// Check if all the callsites of the void function are void, or happen to
// incorrectly use the new return type.
//
// TODO: We could make better effort to handle call type mismatches.
static bool canReplaceFuncUsers(const Function &F, Type *NewRetTy) {
for (const Use &U : F.uses()) {
const CallBase *CB = dyn_cast<CallBase>(U.getUser());
if (!CB)
continue;
// Normal pointer uses are trivially replacable.
if (!CB->isCallee(&U))
continue;
// We can trivially replace the correct void call sites.
if (CB->getType()->isVoidTy())
continue;
// We can trivially replace the call if the return type happened to match
// the new return type.
if (CB->getType() == NewRetTy)
continue;
// TODO: If all callsites have no uses, we could mutate the type of all the
// callsites. This will complicate the visit and rewrite ordering though.
LLVM_DEBUG(dbgs() << "Cannot replace used callsite with wrong type: " << *CB
<< '\n');
return false;
}
return true;
}
/// Decide whether it is worth making \p Replacement the return value for
/// \p BB.
///
/// \returns false only when \p BB already ends in a return instruction whose
/// return value is \p Replacement; true in every other case, including when
/// the terminator is not a return at all.
static bool shouldReplaceNonVoidReturnValue(const BasicBlock &BB,
                                            const Value *Replacement) {
  const auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator());
  return !Ret || Ret->getReturnValue() != Replacement;
}
static bool canHandleSuccessors(const BasicBlock &BB) {
// TODO: Handle invoke and other exotic terminators
if (!isa<ReturnInst, UnreachableInst, BranchInst, SwitchInst>(
BB.getTerminator()))
return false;
for (const BasicBlock *Succ : successors(&BB)) {
if (!Succ->canSplitPredecessors())
return false;
}
return true;
}
static bool shouldForwardValueToReturn(const BasicBlock &BB, const Value *V,
Type *RetTy) {
if (!isReallyValidReturnType(V->getType()))
return false;
return (RetTy->isVoidTy() || shouldReplaceNonVoidReturnValue(BB, V)) &&
canReplaceFuncUsers(*BB.getParent(), V->getType());
}
/// Look for one instruction in \p F to turn into the function's return value.
///
/// Blocks are visited in function order and instructions in block order. The
/// first eligible instruction for which \p O.shouldKeep() returns false is
/// appended to \p FuncsToReplace together with \p F, and the search stops.
///
/// \returns true if a candidate was recorded, false otherwise.
static bool tryForwardingInstructionsToReturn(
    Function &F, Oracle &O,
    std::vector<std::pair<Function *, Value *>> &FuncsToReplace) {
  // TODO: Should we try to expand returns to aggregate for function that
  // already have a return value?
  Type *CurrentRetTy = F.getReturnType();

  for (BasicBlock &Block : F) {
    if (canHandleSuccessors(Block)) {
      for (Instruction &Inst : Block) {
        if (!shouldForwardValueToReturn(Block, &Inst, CurrentRetTy))
          continue;

        // The oracle is only consulted for eligible instructions.
        if (O.shouldKeep())
          continue;

        FuncsToReplace.emplace_back(&F, &Inst);
        return true;
      }
    }
  }

  return false;
}
/// Look for one argument of \p F to return directly from the entry block.
///
/// Arguments are visited in order. The first eligible argument for which
/// \p O.shouldKeep() returns false is appended to \p FuncsToReplace together
/// with \p F, and the search stops.
///
/// \returns true if a candidate was recorded, false otherwise.
static bool tryForwardingArgumentsToReturn(
    Function &F, Oracle &O,
    std::vector<std::pair<Function *, Value *>> &FuncsToReplace) {
  Type *CurrentRetTy = F.getReturnType();
  BasicBlock &Entry = F.getEntryBlock();

  for (Argument &Arg : F.args()) {
    if (!shouldForwardValueToReturn(Entry, &Arg, CurrentRetTy))
      continue;

    // The oracle is only consulted for eligible arguments.
    if (O.shouldKeep())
      continue;

    FuncsToReplace.emplace_back(&F, &Arg);
    return true;
  }

  return false;
}
/// Delta pass that tries to make functions return one of their arguments.
///
/// For every function definition in the module of \p WorkItem that passes
/// canUseNonVoidReturnType, at most one (function, argument) pair is collected
/// with the help of \p O; the collected functions are rewritten afterwards.
void llvm::reduceArgumentsToReturnDeltaPass(Oracle &O,
                                            ReducerWorkItem &WorkItem) {
  Module &M = WorkItem.getModule();

  // Rewriting a function also hacks on its users inside other functions, so
  // gather the full worklist first and only mutate once the scan is done.
  std::vector<std::pair<Function *, Value *>> Worklist;

  for (Function &Func : M.functions()) {
    if (Func.isDeclaration() || !canUseNonVoidReturnType(Func))
      continue;
    tryForwardingArgumentsToReturn(Func, O, Worklist);
  }

  for (const auto &Entry : Worklist)
    rewriteFuncWithReturnType(*Entry.first, Entry.second);
}
/// Delta pass that tries to make functions return early with the value of one
/// of their instructions.
///
/// For every function definition in the module of \p WorkItem that passes
/// canUseNonVoidReturnType, at most one (function, instruction) pair is
/// collected with the help of \p O; the collected functions are rewritten
/// afterwards.
void llvm::reduceInstructionsToReturnDeltaPass(Oracle &O,
                                               ReducerWorkItem &WorkItem) {
  Module &M = WorkItem.getModule();

  // Rewriting a function also hacks on its users inside other functions, so
  // gather the full worklist first and only mutate once the scan is done.
  std::vector<std::pair<Function *, Value *>> Worklist;

  for (Function &Func : M.functions()) {
    if (Func.isDeclaration() || !canUseNonVoidReturnType(Func))
      continue;
    tryForwardingInstructionsToReturn(Func, O, Worklist);
  }

  for (const auto &Entry : Worklist)
    rewriteFuncWithReturnType(*Entry.first, Entry.second);
}