blob: e22398888bcbba2a1413881600287ddb7984948a [file] [log] [blame]
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/CallSite.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "vmkit/compiler.h"
#include "vmkit/system.h"
#include "vmkit/vmkit.h"
#include <dlfcn.h>
namespace vmkit {
/// One pending entry of the FunctionInliner work list: a basic block still to
/// be visited, stored together with the symbol and the threshold value that
/// are passed in when the entry is created.
class InlineBB {
public:
  Symbol*           symbol;     ///< symbol recorded for this block (may be 0)
  llvm::BasicBlock* bb;         ///< block to visit
  uint64_t          threshold;  ///< threshold value recorded for this block

  /// Stores the three values as given; nothing else is done.
  InlineBB(Symbol* _symbol, llvm::BasicBlock* _bb, uint64_t _threshold)
    : symbol(_symbol),
      bb(_bb),
      threshold(_threshold) {
  }
};
class FunctionInliner {
public:
llvm::Function* function;
llvm::SmallPtrSet<llvm::BasicBlock*, 32> visited;
llvm::SmallVector<InlineBB, 8> visitStack;
CompilationUnit* originalUnit;
Symbol* curSymbol;
bool onlyAlwaysInline;
uint64_t inlineThreshold;
FunctionInliner(CompilationUnit* unit, llvm::Function* _function, uint64_t _inlineThreshold, bool _onlyAlwaysInline) {
function = _function;
originalUnit = unit;
onlyAlwaysInline = _onlyAlwaysInline;
inlineThreshold = _inlineThreshold;
push(0, &function->getEntryBlock());
}
void push(Symbol* symbol, llvm::BasicBlock* bb) {
if(visited.insert(bb))
visitStack.push_back(InlineBB(symbol, bb, inlineThreshold));
}
llvm::BasicBlock* pop() {
InlineBB top = visitStack.pop_back_val();
curSymbol = top.symbol;
inlineThreshold = top.threshold;
return top.bb;
}
Symbol* tryInline(llvm::Function* callee, uint64_t* weight) {
if(callee->isIntrinsic())
return 0;
const char* id = callee->getName().data();
CompilationUnit* unit = curSymbol ? curSymbol->unit() : originalUnit;
if(!unit)
unit = originalUnit;
Symbol* symbol = unit->getSymbol(id, 0);
llvm::Function* bc;
//fprintf(stderr, " processing: %s => %p\n", id, symbol);
if(symbol) {
bc = symbol->llvmFunction();
if(!bc)
return 0;
} else {
bc = callee;
if(callee->isDeclaration() && callee->isMaterializable())
callee->Materialize();
if(callee->isDeclaration())
return 0;
uint8_t* addr = (uint8_t*)dlsym(SELF_HANDLE, id);
symbol = new(unit->allocator()) NativeSymbol(callee, addr);
unit->addSymbol(id, symbol);
}
if(bc->hasFnAttribute(llvm::Attribute::NoInline))
return 0;
if(bc->hasFnAttribute(llvm::Attribute::AlwaysInline))
return symbol;
if(onlyAlwaysInline)
return 0;
*weight = symbol->inlineWeight();
if(*weight == -1)
return 0;
//fprintf(stderr, " %s weight: %lld/%lld\n", bc->getName().data(), *weight, inlineThreshold);
return *weight < inlineThreshold ? symbol : 0;
}
bool visitBB(llvm::BasicBlock* bb) {
bool changed = 0;
bool takeNext = 0;
//fprintf(stderr, " visit basic block: %s\n", bb->getName().data());
for(llvm::BasicBlock::iterator it=bb->begin(), prev=0; it!=bb->end(); takeNext && (prev=it++)) {
llvm::Instruction *insn = it;
takeNext = 1;
//fprintf(stderr, " visit insn: ");
//insn->dump();
//fprintf(stderr, " %d operands\n", insn->getNumOperands());
for(unsigned i=0; i<insn->getNumOperands(); i++) {
llvm::Value* op = insn->getOperand(i);
//fprintf(stderr, " ----> ");
//op->dump();
//fprintf(stderr, " => %s\n", llvm::isa<llvm::GlobalValue>(op) ? "global" : "not global");
if(llvm::isa<llvm::GlobalValue>(op)) {
llvm::GlobalValue* gv = llvm::cast<llvm::GlobalValue>(op);
if(gv->getParent() != function->getParent()) {
llvm::Value* copy =
llvm::isa<llvm::Function>(gv) ?
function->getParent()->getOrInsertFunction(gv->getName().data(),
llvm::cast<llvm::Function>(gv)->getFunctionType()) :
function->getParent()->getOrInsertGlobal(gv->getName().data(), gv->getType()->getContainedType(0));
if(curSymbol && curSymbol->unit()) {
Symbol* remoteSymbol = curSymbol->unit()->getSymbol(gv->getName().data(), 0);
if(remoteSymbol)
originalUnit->addSymbol(gv->getName().data(), remoteSymbol);
}
//fprintf(stderr, "<<<reimporting>>>: %s\n", gv->getName().data());
insn->setOperand(i, copy);
}
}
}
if(insn->getOpcode() != llvm::Instruction::Call &&
insn->getOpcode() != llvm::Instruction::Invoke) {
continue;
}
llvm::CallSite call(insn);
llvm::Function* callee = call.getCalledFunction();
if(!callee)
continue;
uint64_t weight;
Symbol* symbol = tryInline(callee, &weight);
if(symbol) {
llvm::Function* bc = symbol->llvmFunction();
//fprintf(stderr, " inlining %s in %s\n", bc->getName().data(), function->getName().data());
if(llvm::isa<llvm::TerminatorInst>(insn)) {
llvm::TerminatorInst* terminator = llvm::cast<llvm::TerminatorInst>(insn);
for(unsigned i=0; i<terminator->getNumSuccessors(); i++)
push(curSymbol, terminator->getSuccessor(i));
} else {
size_t len = strlen(bc->getName().data());
char buf[len + 14];
memcpy(buf, bc->getName().data(), len);
memcpy(buf+len, ".after-inline", 14);
push(curSymbol, bb->splitBasicBlock(insn->getNextNode(), buf));
}
if(bc != callee)
call.setCalledFunction(bc);
llvm::InlineFunctionInfo ifi(0);
bool isInlined = llvm::InlineFunction(call, ifi, false);
//fprintf(stderr, " ok?: %d\n", isInlined);
changed |= isInlined;
if(isInlined) {
curSymbol = symbol;
inlineThreshold -= weight;
if(prev)
it = prev;
else {
takeNext = 0;
it = bb->begin();
}
} else {
symbol->markAsNeverInline();
if(bc != callee)
call.setCalledFunction(callee);
}
}
}
return changed;
}
bool proceed() {
bool changed = 0;
//fprintf(stderr, "visit function: %s\n", function->getName().data());
while(visitStack.size()) {
llvm::BasicBlock* bb = pop();
changed |= visitBB(bb);
llvm::TerminatorInst* terminator = bb->getTerminator();
if(terminator) {
for(unsigned i=0; i<terminator->getNumSuccessors(); i++)
push(curSymbol, terminator->getSuccessor(i));
}
}
if(changed) {
//function->dump();
}
return changed;
}
};
/// llvm::FunctionPass wrapper around FunctionInliner: for every function it
/// is run on, a fresh FunctionInliner is built with the stored unit,
/// threshold and always-inline flag, and its result is reported back.
class FunctionInlinerPass : public llvm::FunctionPass {
public:
  static char ID;
  CompilationUnit* unit;
  llvm::InlineCostAnalysis costAnalysis;
  unsigned int inlineThreshold; // 225 in llvm
  bool onlyAlwaysInline;

  /// Stores the three parameters; `ID` is forwarded to the base class.
  FunctionInlinerPass(CompilationUnit* _unit, unsigned int _inlineThreshold, bool _onlyAlwaysInline)
    : FunctionPass(ID),
      unit(_unit),
      inlineThreshold(_inlineThreshold),
      onlyAlwaysInline(_onlyAlwaysInline) {
  }

  /// Name reported for this pass.
  virtual const char* getPassName() const {
    return "VMKit inliner";
  }

  /// Runs the inliner on `function`; returns whether anything was inlined.
  bool runOnFunction(llvm::Function& function) {
    return FunctionInliner(unit, &function, inlineThreshold, onlyAlwaysInline).proceed();
  }
};
// Definition of the static pass identifier that the FunctionInlinerPass
// constructor hands to llvm::FunctionPass.
char FunctionInlinerPass::ID = 0;
// Static registration of the pass under the name "FunctionInlinerPass".
// Compiled out by the #if 0 guard.
#if 0
llvm::RegisterPass<FunctionInlinerPass> X("FunctionInlinerPass",
"Inlining Pass that inlines evaluator's functions.");
#endif
/// Builds a FunctionInlinerPass for `compiler` with the fixed inline budget
/// used by VMKit.
///
/// @param compiler          compilation unit handed to the pass
/// @param onlyAlwaysInline  forwarded to the pass unchanged
/// @return a newly allocated pass; the caller receives the raw pointer.
llvm::FunctionPass* createFunctionInlinerPass(CompilationUnit* compiler, bool onlyAlwaysInline) {
  /* aka 50 llvm instructions */
  const unsigned int threshold = 50*256;
  llvm::FunctionPass* pass = new FunctionInlinerPass(compiler, threshold, onlyAlwaysInline);
  return pass;
}
}