// blob: 27487172b58822bb7895bc27b3e939337e2fc504 [file] [log] [blame]
//===- IndirectCallChecks.cpp: Fast Indirect Call Check Instrumentation ---===//
// The SAFECode Compiler
//
// This file was developed by the LLVM research group and is distributed under
// the University of Illinois Open Source License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass instruments code with fast checks for indirect function calls.
//
//===----------------------------------------------------------------------===//
#include "safecode/Config/config.h"
#include "llvm/Pass.h"
#include "llvm/Module.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/GlobalValue.h"
#include "llvm/Support/CallSite.h"
#include "llvm/InlineAsm.h"
#include "llvm/CallingConv.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Constants.h"
#include "IndirectCallChecks.h"
#include "SCUtils.h"
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <list>
#include <set>
#include <sstream>
#include <string>
#include <vector>
// ---------------------------------------------------------------------------
// Build-time switches for this pass.
// ---------------------------------------------------------------------------
// ENABLE_DSA: when non-zero, the DSA CallTargets header is included, the
// jump-table helper classes in this file are compiled, and the pass asks
// CallTargetFinder for the targets of each call site.
#define ENABLE_DSA 1
#define USING_SVA 0 //0=safecode user-space, 1=safecode kernel-space
#define IC_DEBUG 0 //print additional messages such as jump table in .s file
///////////////
// File the pass opens in its constructor (only when IC_DEBUG is set) and
// writes the serialized jump tables to.
#define OUTPUT_ASM_FILE "pass.s"
// Prefix put in front of every symbol name this file generates.
#define JUMP_TABLE_PREFIX "__"
// Stream fragments for the per-table begin/end label names. They are meant to
// be used as `stream << "" JUMP_TABLE_BEGIN "";`: the leading string literals
// concatenate, a variable named jumpTableId must be in scope at the use site,
// and the trailing `<<` is completed by the literal that follows the macro.
#define JUMP_TABLE_BEGIN JUMP_TABLE_PREFIX "jump_table_begin" << jumpTableId <<
#define JUMP_TABLE_END JUMP_TABLE_PREFIX "jump_table_end" << jumpTableId <<
// Name of the function that carries the inline assembly for all jump tables.
#define JUMP_TABLE_COLLECTION JUMP_TABLE_PREFIX "jump_table_collection"
#if ENABLE_DSA
#include <map>
#include <memory>
#include "dsa/CallTargets.h"
#endif
// LLVM_VERSION is derived from USING_SVA and is only compared against 20 and
// 23 below. NOTE(review): presumably a two-digit major/minor encoding of the
// LLVM release being targeted - confirm.
#if USING_SVA
#define LLVM_VERSION 19
#else
#define LLVM_VERSION 23
#endif
using namespace llvm;
// Pick the error stream: llvm::cerr for LLVM_VERSION >= 20, std::cerr otherwise.
#if LLVM_VERSION >= 20
using llvm::cerr;
#else
#include <iostream>
using std::cerr;
#endif
// CREATE_LLVM_OBJECT(T, (a, b, ...)) builds an LLVM object either through the
// static T::Create factory or through plain `new`, depending on LLVM_VERSION.
// The argument list must be passed already wrapped in parentheses.
#if LLVM_VERSION >= 23
#define CREATE_LLVM_OBJECT(T, args) T::Create args
#else
#define CREATE_LLVM_OBJECT(T, args) new T args
#endif
// Debug helpers; they expand to nothing unless IC_DEBUG is set.
#if IC_DEBUG
#define IC_DMSG(msg) cerr << "[DEBUG]: " << msg << "\n";
#define IC_PRINT(obj) (obj)->print(cerr);
#else
#define IC_DMSG(msg)
#define IC_PRINT(obj)
#endif
// Warning helpers; always active. Both expansions already end in ';', and
// IC_PRINTWARN expands to two statements, so neither is safe as the sole body
// of an unbraced if/else.
#define IC_WARN(msg) cerr << "[WARNING]: " << msg << "\n";
#define IC_PRINTWARN(obj) cerr << "[WARNING]: "; (obj)->print(cerr);
//
// Basic LLVM Types
//
// Cached LLVM void type. It is null until IndirectCall::runOnModule() assigns
// it from the global context; the jump-table helpers read it when they build
// the no-argument, void-returning function type used for their marker symbols.
static const Type * VoidType = 0;
namespace {
#if ENABLE_DSA
// Set of functions. NOTE(review): apart from a commented-out typedef, nothing
// in the visible part of this file refers to this name - it may be unused.
typedef std::set<Function*> function_set_type;
// One slot of a jump table. It pairs a real function (the target) with a
// freshly declared stub function; when serialized, the stub is a global label
// followed by a single `jmp` to the target.
class JumpTableEntry {
public:
  // Remembers `target` and declares its stub in `module`. The stub gets the
  // target's function type, external linkage and the name from buildName().
  JumpTableEntry(const Function *target, Module* module) : target(target) {
    // Assigned in the body on purpose: indirectFunction is declared before
    // target, so a member initializer for it would run before target is set,
    // and buildName() reads target.
    indirectFunction = CREATE_LLVM_OBJECT(Function, (
      target->getFunctionType(),
      GlobalValue::ExternalLinkage,
      buildName(),
      module
    ));
  }

  // Emits the assembly for this slot to `out`: a `.global` directive and label
  // for the stub, then the jump to the target.
  void writeToStream(std::ostream &out) const {
    assert(target && indirectFunction);
    const std::string stubName = indirectFunction->getName().str();
    IC_DMSG("writeToStream called for " << stubName << " entry" );
    out << ".global " << stubName << "\n";
    out << stubName << ":\n";
    out << "jmp " << target->getName().str() << "\n";
  }

  // The stub function declared for the target.
  Function *getIndirectFunction() const {
    return indirectFunction;
  }

  // The real function this slot jumps to.
  const Function *getTarget() const {
    return target;
  }

private:
  Function *indirectFunction;
  const Function *target;

  // Name requested for the stub: JUMP_TABLE_PREFIX followed by the target's name.
  std::string buildName() const {
    return std::string(JUMP_TABLE_PREFIX) + target->getName().str();
  }
};
struct JumpTable {
private:
typedef std::vector<JumpTableEntry> entries_t;
Module* module;
public:
JumpTable(Module* m) : module(m) {}
template <class InputIterator>
JumpTable(InputIterator targetsBegin, InputIterator targetsEnd, int tableId) {
jumpTableId = tableId;
//create the entries
InputIterator iter = targetsBegin, end = targetsEnd;
for(; iter != end; ++iter) {
const Function *target = *iter;
entries.push_back(JumpTableEntry(target, module));
}
std::vector<const Type *> emptyFuncTyArgs;
FunctionType *emptyFuncTy = FunctionType::get(VoidType, emptyFuncTyArgs, false);
lowerBound = CREATE_LLVM_OBJECT(Function, (
emptyFuncTy,
GlobalValue::ExternalLinkage,
getName(),
module
));
upperBound = CREATE_LLVM_OBJECT(Function, (
emptyFuncTy,
GlobalValue::ExternalLinkage,
getNameEnd(),
module
));
}
//serializes the jump table
void writeToStream(std::ostream &out) const {
IC_DMSG("writeToStream called for " << getName() );
out << ".text\n"
<< ".global " << lowerBound->getName().str() << "\n"
<< lowerBound->getName().str() << ":\n"
;
entries_t::const_iterator iter = entries.begin(), end = entries.end();
for(; iter != end; ++iter) {
iter->writeToStream(out);
}
out << ".global " << upperBound->getName().str() << "\n"
<< upperBound->getName().str() << ":\n"
;
}
const JumpTableEntry &findEntry(Function *target) const {
entries_t::const_iterator iter = entries.begin(), end = entries.end();
for(; iter != end; ++iter) {
if(iter->getTarget() == target) {
return *iter;
}
}
//in the unlikely case we dont find an entry
return *(static_cast<JumpTableEntry*>(NULL));
}
Function *getLowerBound() const {
return lowerBound;
}
Function *getUpperBound() const {
return upperBound;
}
private:
int jumpTableId; //need this to emit unique begin/end labels
//the entries in this jump table
entries_t entries;
Function *lowerBound;
Function *upperBound;
std::string getName() const {
std::stringstream stream;
stream << "" JUMP_TABLE_BEGIN "";
return stream.str();
}
std::string getNameEnd() const {
std::stringstream stream;
stream << "" JUMP_TABLE_END "";
return stream.str();
}
};
class JumpTableCollection {
private:
//typedef hash_map<function_set_type, JumpTable, hashFunctionSet> jt_hash_type;
typedef std::vector<JumpTable*> vec_tbl_t;
typedef std::map<const Function *, JumpTable*> map_tbl_t;
Module*& module;
public:
JumpTableCollection(Module*& m) : module(m), counter(0) {}
~JumpTableCollection() {
vec_tbl_t::iterator iter, end;
for(iter = tables.begin(), end = tables.end(); iter != end; ++iter) {
JumpTable *jt = *iter;
delete jt;
}
}
//inserts this table into the collection
//note that if the targets set was already in a previous table
//then we do nothing
//
//if the set of targets is fresh, insert into collection
//
//returns the jump table for these targets
template <class InputIterator>
JumpTable *createTable(InputIterator targetsBegin, InputIterator targetsEnd) {
InputIterator iter = targetsBegin, end = targetsEnd;
assert(iter != end);
const Function *f = *iter;
//already have a jump table for this?
map_tbl_t::iterator map_iter = tablesByFunction.find(f);
if(map_iter != tablesByFunction.end()) {
return map_iter->second;
}
//dont have a jump table for this, lets create one
JumpTable *jt = new JumpTable(targetsBegin, targetsEnd, counter++);
tables.push_back(jt);
//register all functions with the new jump table
for(; iter != end; ++iter) {
f = *iter;
tablesByFunction[f] = jt;
}
return jt;
}
//tries to find the Jump Table by the function in it
//
//return null if this function is not in a jump table
JumpTable *findTable(const Function *target) const {
map_tbl_t::const_iterator iter = tablesByFunction.find(target);
if(iter == tablesByFunction.end())
return NULL;
else
return iter->second;
}
//serializes all the jump tables
void writeToStream(std::ostream &out) const {
vec_tbl_t::const_iterator iter, end;
IC_DMSG("writeToStream called for collection");
for(iter = tables.begin(), end = tables.end(); iter != end; ++iter) {
const JumpTable *jt = *iter;
jt->writeToStream(out);
}
}
void createInlineAsm(Module &M) const {
std::vector<const Type *> emptyFuncTyArgs;
FunctionType *emptyFuncTy = FunctionType::get(VoidType, emptyFuncTyArgs, false);
std::stringstream stream;
writeToStream(stream);
InlineAsm *assembly = InlineAsm::get(
emptyFuncTy,
stream.str(),
"~{dirflag},~{fpsr},~{flags}",
true
);
Function *F = CREATE_LLVM_OBJECT(Function, (
emptyFuncTy,
GlobalValue::ExternalLinkage,
JUMP_TABLE_COLLECTION,
&M
));
BasicBlock *BB = CREATE_LLVM_OBJECT(BasicBlock, (getGlobalContext(), "entry", F));
CallInst *callAsm = CREATE_LLVM_OBJECT(CallInst, (assembly, "", BB));
callAsm->setCallingConv(CallingConv::C);
callAsm->setTailCall(true);
CREATE_LLVM_OBJECT(ReturnInst, (getGlobalContext(), BB));
}
private:
int counter;
vec_tbl_t tables;
map_tbl_t tablesByFunction;
};
#endif
struct IndirectCall : public ModulePass {
static char ID;
std::ofstream *asmStream;
Module *module;
JumpTableCollection tableCollection;
#if ENABLE_DSA/*{{{*/
typedef std::list<CallSite>::iterator CallSiteIterator;
typedef std::vector<Function*>::iterator CalleeIterator;
virtual void getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<CallTargetFinder>();
}
#endif/*}}}*/
#if LLVM_VERSION >= 20
IndirectCall() : ModulePass((intptr_t) &ID), tableCollection(module)
#else
IndirectCall()
#endif
{
#if IC_DEBUG
asmStream = new std::ofstream(OUTPUT_ASM_FILE);
#endif
}
~IndirectCall() {
#if IC_DEBUG
delete asmStream;
#endif
}
virtual bool runOnModule(Module &m) {
bool changed = false;
module = &m;
//
// Create needed LLVM types.
//
VoidType = Type::getVoidTy(getGlobalContext());
std::vector<Function*> functions;
{
//get all the functions in advance
//otherwise when we declare indirect functions we will get into infinite loop
Module::iterator iter, end;
for(iter = m.begin(), end = m.end(); iter != end; ++iter) {
functions.push_back(iter);
}
}
#if !ENABLE_DSA
//without DSA we just throw in all functions together
tableCollection.createTable(functions.begin(), functions.end());
#else
//create jump tables using DSA
CallTargetFinder* CTF = &getAnalysis<CallTargetFinder>();
CallSiteIterator cs_iter, cs_end = CTF->cs_end();
for(cs_iter = CTF->cs_begin(); cs_iter != cs_end; ++cs_iter) {
CallSite cs = *cs_iter;
if(!isIndirectCall(cs))
continue;
//handle incomplete callsites or 0-target callsites
if(!CTF->isComplete(cs)) {
IC_WARN("Call site is not complete, skipping bounds checks");
#if 0
IC_PRINTWARN(cs.getInstruction());
#endif
continue;
}
else if(CTF->begin(cs) == CTF->end(cs)) {
IC_WARN("Callsite has no targets, skipping bounds checks");
#if 0
IC_PRINTWARN(cs.getInstruction());
#endif
continue;
}
else {
IC_DMSG("Currently inspecting callsite: ");
IC_PRINT(cs.getInstruction());
}
JumpTable *jt = tableCollection.createTable(CTF->begin(cs), CTF->end(cs));
assert(jt);
insertBoundaryChecks(cs, jt);
}
#endif
//then go to use of a function and update it to a jump table entry
std::vector<Function*>::iterator iter, end;
for(iter = functions.begin(), end = functions.end(); iter != end; ++iter) {
Function *f = *iter;
JumpTable *jt = tableCollection.findTable(f);
if(!jt) continue; //skip functions that arent ever used indirectly
const JumpTableEntry &entry = jt->findEntry(f);
assert(&entry);
changed = runOnFunction(*f, entry) || changed;
}
#if IC_DEBUG
//write to a pass.s file
tableCollection.writeToStream(*asmStream);
#endif
tableCollection.createInlineAsm(m);
return changed;
}
//Split up the BasicBlock of the callsite into two and insert
//the boundary checks for the targets of the callsite
void insertBoundaryChecks(CallSite cs, JumpTable *jt) {
const Type *VoidPtrTy = getVoidPtrType();
Constant *indirectFuncFail = module->getOrInsertFunction (
"bchk_ind_fail",
VoidType,
VoidPtrTy,
NULL
);
//%x = call %target (...)
Instruction *I = cs.getInstruction();
BasicBlock *topBB = I->getParent();
BasicBlock *bottomBB = topBB->splitBasicBlock(I, "do_indirect_call");
//we have an unconditional branch to bottomBB
//but remove it since we'll create a conditional branch later
topBB->getTerminator()->eraseFromParent();
//if outside of bounds call bchk_ind_fail(target)
//then resume execution
Value *targetPointer = cs.getCalledValue();
/* top:
* ...
* if (target <= jumpTableBegin || target >= jumpTableEnd)
* goto failed_ind_check
* else
* goto bottom
* failed_ind_check:
* bchk_ind_failed(%target);
* goto bottom
* bottom:
* %x = call %target(...)
*/
BitCastInst *castTarget = new BitCastInst(
targetPointer,
VoidPtrTy,
"",
topBB
);
ICmpInst *LT = new ICmpInst(*topBB,
ICmpInst::ICMP_ULT,
castTarget,
ConstantExpr::getBitCast(jt->getLowerBound(), VoidPtrTy),
""
);
ICmpInst *GT = new ICmpInst(*topBB,
ICmpInst::ICMP_UGT,
castTarget,
ConstantExpr::getBitCast(jt->getUpperBound(), VoidPtrTy),
""
);
BinaryOperator *OR = BinaryOperator::CreateOr(LT, GT, "", topBB);
BasicBlock *failedCheckBB = CREATE_LLVM_OBJECT(BasicBlock, (getGlobalContext(),
"failed_ind_check",
bottomBB->getParent(),
bottomBB
));
CREATE_LLVM_OBJECT(CallInst, (indirectFuncFail, castTarget, "", failedCheckBB));
CREATE_LLVM_OBJECT(BranchInst, (bottomBB, failedCheckBB));
CREATE_LLVM_OBJECT(BranchInst, (failedCheckBB, bottomBB, OR, topBB));
}
/*
* if f's address is ever taken,
* replace that use of f with __f
*
* __f will be inside a jump table
* with value 'jmp f'
*/
bool runOnFunction(Function &f, const JumpTableEntry &entry) {
bool changed = false;
cerr << "Function: " << f.getName().str() << "\n";
Function::use_iterator iter, end;
Function *indirect = entry.getIndirectFunction();
//go through all uses of this function
for(iter = f.use_begin(), end = f.use_end(); iter != end; ++iter) {
User *user = *iter;
//dont replace direct calls to this func with indirect calls
unsigned low = isa<CallInst>(user) || isa<InvokeInst>(user);
//replace all address-taken(f) with indirect address
unsigned high = user->getNumOperands();
for(unsigned i = low; i < high; ++i) {
Value *value = user->getOperand(i);
//replace f with __f
if(value == &f) {
user->setOperand(i, indirect);
changed = true;
}
}
}
return changed;
}
//returns true if the callsite is indirect, false if its direct
bool isIndirectCall(CallSite &cs) {
return !cs.getCalledFunction();
}
}; //end of struct IndirectCall
// Storage for the pass identifier; the constructor hands the ADDRESS of this
// variable to ModulePass, so the value itself is not significant here.
char IndirectCall::ID = 0;
// Registers the pass with the name "indirect-call" and the description
// "Indirect Call Pass".
RegisterPass<IndirectCall> X("indirect-call", "Indirect Call Pass");
}
namespace llvm {
  // Factory for the indirect-call check pass. Returns a newly heap-allocated
  // IndirectCall as a ModulePass pointer; nothing in this file deletes it.
  ModulePass *createIndirectCallChecksPass() {
    ModulePass *pass = new IndirectCall();
    return pass;
  }
}