This commit was manufactured by cvs2svn to create tag 'usenixsec09_final'.

llvm-svn: 88654
diff --git a/safecode/configure b/safecode/configure
index 02cb033..b277977 100755
--- a/safecode/configure
+++ b/safecode/configure
@@ -807,7 +807,6 @@
                           NO)
   --enable-safeio         Enable safe I/O checks (default is NO)
   --enable-safemmu        Enable safe MMU checks (default is NO)
-  --enable-safekstacks    Enable safe kernel stack checks (default is NO)
 
 Optional Packages:
   --with-PACKAGE[=ARG]    use PACKAGE [ARG=yes]
@@ -2472,15 +2471,6 @@
 
 fi;
 
-# Check whether --enable-safekstacks or --disable-safekstacks was given.
-if test "${enable_safekstacks+set}" = set; then
-  enableval="$enable_safekstacks"
-  cat >>confdefs.h <<\_ACEOF
-#define SVA_KSTACKS 1
-_ACEOF
-
-fi;
-
 
 
 # Check whether --with-poolalloc-srcdir or --without-poolalloc-srcdir was given.
diff --git a/safecode/lib/IndirectCallChecks/IndirectCallChecks.cpp b/safecode/lib/IndirectCallChecks/IndirectCallChecks.cpp
new file mode 100755
index 0000000..dcaa487
--- /dev/null
+++ b/safecode/lib/IndirectCallChecks/IndirectCallChecks.cpp
@@ -0,0 +1,600 @@
+#include "safecode/Config/config.h"
+#include "llvm/Pass.h"
+#include "llvm/Module.h"
+#include "llvm/Function.h"
+#include "llvm/Instructions.h"
+#include "llvm/GlobalValue.h"
+#include "llvm/Support/CallSite.h"
+#include "llvm/InlineAsm.h"
+#include "llvm/CallingConv.h"
+#include "llvm/ParameterAttributes.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Constants.h"
+
+#include "IndirectCallChecks.h"
+
+
+#include <algorithm>
+#include <cassert>
+#include <cstdlib>
+#include <fstream>
+#include <list>
+#include <set>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#define ENABLE_DSA 1
+#define USING_SVA 0 //0=safecode user-space, 1=safecode kernel-space
+#define IC_DEBUG 0 //print additional messages such as jump table in .s file
+
+///////////////
+#define OUTPUT_ASM_FILE "pass.s"
+#define JUMP_TABLE_PREFIX "__"
+#define JUMP_TABLE_BEGIN JUMP_TABLE_PREFIX "jump_table_begin" << jumpTableId <<
+#define JUMP_TABLE_END JUMP_TABLE_PREFIX "jump_table_end" << jumpTableId <<
+#define JUMP_TABLE_COLLECTION JUMP_TABLE_PREFIX "jump_table_collection"
+
+#if ENABLE_DSA
+#include <map>
+#include <memory>
+
+#include "dsa/CallTargets.h"
+#endif
+
+#if USING_SVA
+#define LLVM_VERSION 19
+#else
+#define LLVM_VERSION 23
+#endif
+
+using namespace llvm;
+
+#if LLVM_VERSION >= 20
+using llvm::cerr;
+#else
+#include <iostream>
+using std::cerr;
+#endif
+
+#if LLVM_VERSION >= 23
+#define CREATE_LLVM_OBJECT(T, args) T::Create args
+#else
+#define CREATE_LLVM_OBJECT(T, args) new T args
+#endif
+
+#if IC_DEBUG
+#define IC_DMSG(msg) cerr << "[DEBUG]: " << msg << "\n";
+#define IC_PRINT(obj) (obj)->print(cerr);
+#else
+#define IC_DMSG(msg)
+#define IC_PRINT(obj)
+#endif
+
+#define IC_WARN(msg) cerr << "[WARNING]: " << msg << "\n";
+#define IC_PRINTWARN(obj) cerr << "[WARNING]: "; (obj)->print(cerr);
+
+namespace {
+
+#if ENABLE_DSA
+    typedef std::set<Function*> function_set_type;
+
+    // One slot of a jump table.
+    //
+    // Construction declares a new external function "__<target name>" with
+    // the target's type in the target's module. writeToStream later emits
+    // the body of that symbol as assembly that jumps to the target, so
+    // taking the address of the declaration yields an address inside the
+    // jump table that still reaches the original function.
+    class JumpTableEntry {
+        public:
+
+        // Declares the indirect ("__"-prefixed) function for `target` and
+        // registers it in target's module. Explicit: an entry has module
+        // side effects and must never be created by implicit conversion.
+        explicit JumpTableEntry(Function *target)
+            : indirectFunction(NULL), target(target) {
+            indirectFunction = CREATE_LLVM_OBJECT(Function, (
+                    target->getFunctionType(),
+                    GlobalValue::ExternalLinkage,
+                    buildName(),
+                    target->getParent()
+            ));
+        }
+
+        // Emits this entry as assembly:
+        //   .global <indirect name>
+        //   <indirect name>:
+        //   jmp <target name>
+        // The indirect name is read back from the IR declaration, so the
+        // label always matches the symbol the IR refers to.
+        void writeToStream(std::ostream &out) const {
+            assert(target && indirectFunction);
+
+            const std::string &funcName = indirectFunction->getName();
+
+            IC_DMSG("writeToStream called for " << funcName << " entry" );
+
+            out << ".global " << funcName << "\n"
+                << funcName << ":\n"
+                << "jmp " << target->getName() << "\n"
+                ;
+        }
+
+        // The "__<target>" declaration standing in for the target's address.
+        Function *getIndirectFunction() const {
+            return indirectFunction;
+        }
+
+        // The real function this entry jumps to.
+        Function *getTarget() const {
+            return target;
+        }
+
+        private:
+        Function *indirectFunction;
+        Function *target;
+
+        // Requested name of the indirect function: JUMP_TABLE_PREFIX + target name.
+        std::string buildName() const {
+            std::stringstream stream;
+
+            stream <<  JUMP_TABLE_PREFIX << target->getName();
+
+            return stream.str();
+        }
+    };
+
+    // A contiguous jump table, emitted as assembly:
+    //
+    //   __jump_table_begin<id>:    (lower bound marker)
+    //   __<f1>: jmp f1
+    //   ...
+    //   __jump_table_end<id>:      (upper bound marker, first byte past the table)
+    //
+    // The two markers are declared as IR functions so IR code can take
+    // their addresses and range-check a call target against them.
+    struct JumpTable {
+
+        private:
+            typedef std::vector<JumpTableEntry> entries_t;
+        public:
+
+        // Empty table with no markers. All members are initialized, so the
+        // accessors return NULL rather than indeterminate pointers.
+        JumpTable() : jumpTableId(0), lowerBound(NULL), upperBound(NULL) {}
+
+        // Creates one entry per function in [targetsBegin, targetsEnd) and
+        // declares the begin/end marker functions in the targets' module.
+        // tableId must differ between tables so the marker symbol names are
+        // unique. The range must not be empty; *iter must yield Function*.
+        template <class InputIterator>
+        JumpTable(InputIterator targetsBegin, InputIterator targetsEnd, int tableId)
+            : jumpTableId(tableId), lowerBound(NULL), upperBound(NULL) {
+            Module *M = NULL;
+
+            //create the entries
+            for (InputIterator iter = targetsBegin; iter != targetsEnd; ++iter) {
+                Function *target = *iter;
+                M = target->getParent();
+
+                entries.push_back(JumpTableEntry(target));
+            }
+
+            assert(M && "a jump table needs at least one target");
+
+            std::vector<const Type *> emptyFuncTyArgs;
+            FunctionType *emptyFuncTy = FunctionType::get(Type::VoidTy, emptyFuncTyArgs, false);
+
+            lowerBound = CREATE_LLVM_OBJECT(Function, (
+                    emptyFuncTy,
+                    GlobalValue::ExternalLinkage,
+                    getName(),
+                    M
+                    ));
+
+            upperBound = CREATE_LLVM_OBJECT(Function, (
+                    emptyFuncTy,
+                    GlobalValue::ExternalLinkage,
+                    getNameEnd(),
+                    M
+                    ));
+        }
+
+        // Serializes the table: .text, the begin marker, every entry in
+        // insertion order, then the end marker.
+        void writeToStream(std::ostream &out) const {
+
+            IC_DMSG("writeToStream called for " << getName() );
+
+            out << ".text\n"
+                << ".global " << lowerBound->getName() << "\n"
+                << lowerBound->getName() << ":\n"
+                ;
+
+            for (entries_t::const_iterator iter = entries.begin(), end = entries.end();
+                 iter != end; ++iter) {
+                iter->writeToStream(out);
+            }
+
+            out << ".global " << upperBound->getName() << "\n"
+                << upperBound->getName() << ":\n"
+                ;
+        }
+
+        // Returns the first entry whose target is `target`, or NULL if this
+        // table has none.
+        const JumpTableEntry *lookupEntry(Function *target) const {
+            for (entries_t::const_iterator iter = entries.begin(), end = entries.end();
+                 iter != end; ++iter) {
+                if (iter->getTarget() == target)
+                    return &*iter;
+            }
+            return NULL;
+        }
+
+        // Returns the entry for `target`, which must be in this table.
+        // A missing entry breaks a pass invariant. This aborts rather than
+        // returning a reference bound to NULL, which is undefined behavior
+        // that a later `assert(&entry)` cannot reliably detect.
+        const JumpTableEntry &findEntry(Function *target) const {
+            const JumpTableEntry *entry = lookupEntry(target);
+            if (!entry) {
+                assert(0 && "function is not in this jump table");
+                std::abort();
+            }
+            return *entry;
+        }
+
+        // Marker declared at the start of the table (address of the first entry).
+        Function *getLowerBound() const {
+            return lowerBound;
+        }
+
+        // Marker declared just after the last entry.
+        Function *getUpperBound() const {
+            return upperBound;
+        }
+
+        private:
+        int jumpTableId; //need this to emit unique begin/end labels
+
+        //the entries in this jump table
+        entries_t entries;
+
+        Function *lowerBound;
+        Function *upperBound;
+
+        // "__jump_table_begin<id>"
+        std::string getName() const {
+            std::stringstream stream;
+
+            stream <<  "" JUMP_TABLE_BEGIN "";
+
+            return stream.str();
+        }
+
+        // "__jump_table_end<id>"
+        std::string getNameEnd() const {
+
+            std::stringstream stream;
+
+            stream <<  "" JUMP_TABLE_END "";
+
+            return stream.str();
+        }
+    };
+
+    // Owns every JumpTable built for the module and records, for each
+    // target function, the single table that currently contains it.
+    class JumpTableCollection {
+        private:
+            typedef std::vector<JumpTable*> vec_tbl_t;
+
+            typedef std::map<const Function *, JumpTable*> map_tbl_t;
+
+            // functions held by each table, in the order they were added
+            typedef std::map<JumpTable*, std::vector<Function*> > members_tbl_t;
+
+        public:
+        JumpTableCollection() : counter(0) {}
+
+        ~JumpTableCollection() {
+            for (vec_tbl_t::iterator iter = tables.begin(), end = tables.end(); iter != end; ++iter)
+                delete *iter;
+        }
+
+        // Returns a jump table containing every function in
+        // [targetsBegin, targetsEnd). The range must not be empty.
+        //
+        // A function is kept in at most one table, so overlapping target
+        // sets are merged:
+        //  - if all targets already sit in one table, that table is returned;
+        //  - otherwise a new table is built from the targets plus all members
+        //    of every existing table that shares a function with them. The
+        //    old tables are retired, and IR that referenced their markers or
+        //    entry declarations is redirected to the new table. Range checks
+        //    emitted earlier then test a superset range and still accept all
+        //    of their original targets.
+        template <class InputIterator>
+        JumpTable *createTable(InputIterator targetsBegin, InputIterator targetsEnd) {
+            assert(targetsBegin != targetsEnd);
+
+            std::vector<Function*> members;      // union of targets, no duplicates
+            std::set<const Function*> seen;
+            std::vector<JumpTable*> overlapping; // existing tables sharing a target
+            bool allCovered = true;              // every target already has a table
+
+            for (InputIterator iter = targetsBegin; iter != targetsEnd; ++iter) {
+                Function *f = *iter;
+                if (!seen.insert(f).second)
+                    continue;
+                members.push_back(f);
+
+                JumpTable *existing = findTable(f);
+                if (!existing)
+                    allCovered = false;
+                else if (std::find(overlapping.begin(), overlapping.end(), existing) == overlapping.end())
+                    overlapping.push_back(existing);
+            }
+
+            //already have a single jump table holding all of these?
+            if (allCovered && overlapping.size() == 1)
+                return overlapping.front();
+
+            // Absorb the members of the tables being replaced. Their entry
+            // declarations are dropped now if unused, so the new entries can
+            // take the same "__<name>" symbols. Entry declarations are only
+            // referenced once address-taken uses are rewritten, which
+            // runOnModule does after all tables exist. Any that are already
+            // in use are redirected below instead.
+            std::vector<std::pair<Function*, Function*> > staleInUse;
+            for (std::vector<JumpTable*>::iterator t = overlapping.begin(); t != overlapping.end(); ++t) {
+                JumpTable *old = *t;
+                std::vector<Function*> &oldMembers = membersByTable[old];
+                for (std::vector<Function*>::iterator m = oldMembers.begin(); m != oldMembers.end(); ++m) {
+                    Function *f = *m;
+                    if (seen.insert(f).second)
+                        members.push_back(f);
+
+                    Function *stale = old->findEntry(f).getIndirectFunction();
+                    if (stale->use_empty())
+                        stale->eraseFromParent();
+                    else
+                        staleInUse.push_back(std::make_pair(f, stale));
+                }
+            }
+
+            JumpTable *jt = new JumpTable(members.begin(), members.end(), counter++);
+            tables.push_back(jt);
+
+            // Entry declarations that were already referenced now point at
+            // the matching entry of the new table.
+            for (std::vector<std::pair<Function*, Function*> >::iterator p = staleInUse.begin();
+                 p != staleInUse.end(); ++p) {
+                redirectAndErase(p->second, jt->findEntry(p->first).getIndirectFunction());
+            }
+
+            // Retire the replaced tables. Checks already emitted against their
+            // markers now test the merged table's (wider) range.
+            for (std::vector<JumpTable*>::iterator t = overlapping.begin(); t != overlapping.end(); ++t) {
+                JumpTable *old = *t;
+                redirectAndErase(old->getLowerBound(), jt->getLowerBound());
+                redirectAndErase(old->getUpperBound(), jt->getUpperBound());
+
+                vec_tbl_t::iterator pos = std::find(tables.begin(), tables.end(), old);
+                assert(pos != tables.end());
+                tables.erase(pos);
+                membersByTable.erase(old);
+                delete old;
+            }
+
+            //register all functions with the new jump table
+            membersByTable[jt] = members;
+            for (std::vector<Function*>::iterator m = members.begin(); m != members.end(); ++m)
+                tablesByFunction[*m] = jt;
+
+            return jt;
+        }
+
+        // Returns the table containing `target`, or NULL if the function is
+        // not in any jump table.
+        JumpTable *findTable(const Function *target) const {
+            map_tbl_t::const_iterator iter = tablesByFunction.find(target);
+
+            if(iter == tablesByFunction.end())
+                return NULL;
+            else
+                return iter->second;
+        }
+
+        // Serializes all live jump tables, in creation order.
+        void writeToStream(std::ostream &out) const {
+            IC_DMSG("writeToStream called for collection");
+
+            for (vec_tbl_t::const_iterator iter = tables.begin(), end = tables.end(); iter != end; ++iter)
+                (*iter)->writeToStream(out);
+        }
+
+        // Emits every table as one inline-asm call inside a new external
+        // function named JUMP_TABLE_COLLECTION ("__jump_table_collection").
+        // That places the table labels in the module's generated code.
+        // Call this once, after every table has been created: a second
+        // call would emit the same labels twice.
+        void createInlineAsm(Module &M) const {
+            std::vector<const Type *> emptyFuncTyArgs;
+            FunctionType *emptyFuncTy = FunctionType::get(Type::VoidTy, emptyFuncTyArgs, false);
+
+            std::stringstream stream;
+            writeToStream(stream);
+
+            InlineAsm *assembly = InlineAsm::get(
+                    emptyFuncTy,
+                    stream.str(),
+                    "~{dirflag},~{fpsr},~{flags}",
+                    true
+            );
+
+            Function *F = CREATE_LLVM_OBJECT(Function, (
+                    emptyFuncTy,
+                    GlobalValue::ExternalLinkage,
+                    JUMP_TABLE_COLLECTION,
+                    &M
+                    ));
+            BasicBlock *BB = CREATE_LLVM_OBJECT(BasicBlock, ("entry", F));
+
+            CallInst *callAsm = CREATE_LLVM_OBJECT(CallInst, (assembly, "", BB));
+            callAsm->setCallingConv(CallingConv::C);
+            callAsm->setTailCall(true);
+
+            CREATE_LLVM_OBJECT(ReturnInst, (BB));
+        }
+
+        private:
+        // The collection owns the JumpTable pointers; a copy would delete
+        // them twice. Declared but not implemented.
+        JumpTableCollection(const JumpTableCollection &);
+        JumpTableCollection &operator=(const JumpTableCollection &);
+
+        // Points every use of a retired declaration at its replacement
+        // (same type), then removes the retired declaration from the module.
+        static void redirectAndErase(Function *from, Function *to) {
+            if (!from->use_empty())
+                from->replaceAllUsesWith(to);
+            from->eraseFromParent();
+        }
+
+        int counter; // next table id, keeps marker names unique
+
+        vec_tbl_t tables;
+        map_tbl_t tablesByFunction;
+        members_tbl_t membersByTable;
+
+    };
+#endif
+
+    // Module pass that constrains indirect calls to their DSA-computed targets.
+    //
+    // The pass handles each indirect call site whose target set
+    // CallTargetFinder reports as complete and non-empty. For such a site it:
+    //  - puts the possible callees in a jump table;
+    //  - inserts a run-time range check before the call (bchk_ind_fail is
+    //    called on a miss);
+    //  - redirects address-taken uses of each callee to its jump-table entry.
+    // The tables themselves are emitted as inline asm in a holder function.
+    struct IndirectCall : public ModulePass {
+
+        static char ID;
+
+        std::ofstream *asmStream; // IC_DEBUG only: receives a copy of the jump-table asm
+        Module *module;           // module being transformed by runOnModule
+
+        JumpTableCollection tableCollection;
+
+#if ENABLE_DSA/*{{{*/
+        typedef std::list<CallSite>::iterator CallSiteIterator;
+        typedef std::vector<Function*>::iterator CalleeIterator;
+
+        virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+          AU.addRequired<CallTargetFinder>();
+        }
+
+#endif/*}}}*/
+
+#if LLVM_VERSION >= 20
+        IndirectCall() : ModulePass((intptr_t) &ID), asmStream(NULL), module(NULL)
+#else
+        IndirectCall() : asmStream(NULL), module(NULL)
+#endif
+        {
+#if IC_DEBUG
+            asmStream = new std::ofstream(OUTPUT_ASM_FILE);
+#endif
+        }
+
+        ~IndirectCall() {
+#if IC_DEBUG
+            delete asmStream;
+#endif
+        }
+
+        // Builds the jump tables and guards indirect calls. It then rewrites
+        // address-taken uses of table members and emits the tables as
+        // inline asm. Always returns true: createInlineAsm adds a function
+        // to the module on every run, so the module is always modified.
+        virtual bool runOnModule(Module &m) {
+            module = &m;
+
+            // Snapshot the function list before any table is built: tables
+            // declare new functions, which this pass must not visit.
+            std::vector<Function*> functions;
+            for (Module::iterator fi = m.begin(), fe = m.end(); fi != fe; ++fi)
+                functions.push_back(fi);
+
+#if !ENABLE_DSA
+            //without DSA we just throw in all functions together
+            tableCollection.createTable(functions.begin(), functions.end());
+#else
+            //create jump tables using DSA
+            CallTargetFinder* CTF = &getAnalysis<CallTargetFinder>();
+            CallSiteIterator cs_iter, cs_end = CTF->cs_end();
+            for(cs_iter = CTF->cs_begin(); cs_iter != cs_end; ++cs_iter) {
+                CallSite cs = *cs_iter;
+
+                if(!isIndirectCall(cs))
+                    continue;
+
+                //handle incomplete callsites or 0-target callsites
+                if(!CTF->isComplete(cs)) {
+                    IC_WARN("Call site is not complete, skipping bounds checks");
+                    IC_PRINTWARN(cs.getInstruction());
+                    continue;
+                }
+                else if(CTF->begin(cs) == CTF->end(cs)) {
+                    IC_WARN("Callsite has no targets, skipping bounds checks");
+                    IC_PRINTWARN(cs.getInstruction());
+                    continue;
+                }
+                else {
+                    IC_DMSG("Currently inspecting callsite: ");
+                    IC_PRINT(cs.getInstruction());
+                }
+
+                JumpTable *jt = tableCollection.createTable(CTF->begin(cs), CTF->end(cs));
+                assert(jt);
+
+                insertBoundaryChecks(cs, jt);
+            }
+#endif
+
+            // Point address-taken uses of every table member at its entry.
+            for (std::vector<Function*>::iterator iter = functions.begin(), end = functions.end();
+                 iter != end; ++iter) {
+                Function *f = *iter;
+
+                JumpTable *jt = tableCollection.findTable(f);
+                if(!jt) continue; //skip functions that are not in any jump table
+
+                runOnFunction(*f, jt->findEntry(f));
+            }
+
+#if IC_DEBUG
+            //write to a pass.s file
+            tableCollection.writeToStream(*asmStream);
+#endif
+            tableCollection.createInlineAsm(m);
+
+            return true;
+        }
+
+        // Splits the call site's block and guards the call with a range
+        // check against jt:
+        //
+        //   top:
+        //     %t = bitcast %target to i8*
+        //     if (%t < jumpTableBegin || %t >= jumpTableEnd)
+        //       goto failed_ind_check
+        //     else
+        //       goto do_indirect_call
+        //   failed_ind_check:
+        //     call bchk_ind_fail(%t)
+        //     goto do_indirect_call
+        //   do_indirect_call:
+        //     %x = call %target(...)
+        //
+        // The end marker is the first address past the last entry. That is
+        // also where the next emitted table begins, so it must be rejected
+        // (hence >=). The check is range-only and does not verify that %t is
+        // the start of an entry. After bchk_ind_fail returns, execution
+        // continues with the call.
+        void insertBoundaryChecks(CallSite cs, JumpTable *jt) {
+            const Type *VoidPtrTy = PointerType::get(Type::Int8Ty, 0);
+
+            // void bchk_ind_fail(i8*) is only declared here; its definition
+            // is presumably supplied by the SAFECode runtime (confirm).
+            // The varargs type list ends with a typed null pointer, because a
+            // plain NULL may be int-sized when passed through "...".
+            Constant *indirectFuncFail = module->getOrInsertFunction(
+                    "bchk_ind_fail",
+                    Type::VoidTy,
+                    VoidPtrTy,
+                    static_cast<const Type *>(NULL)
+            );
+
+            Instruction *I = cs.getInstruction();
+
+            BasicBlock *topBB = I->getParent();
+            BasicBlock *bottomBB = topBB->splitBasicBlock(I, "do_indirect_call");
+
+            // splitBasicBlock ends topBB with an unconditional branch to
+            // bottomBB; drop it, the conditional branch below replaces it.
+            topBB->getTerminator()->eraseFromParent();
+
+            BitCastInst *castTarget = new BitCastInst(
+                    cs.getCalledValue(),
+                    VoidPtrTy,
+                    "",
+                    topBB
+            );
+
+            ICmpInst *belowTable = new ICmpInst(
+                    ICmpInst::ICMP_ULT,
+                    castTarget,
+                    ConstantExpr::getBitCast(jt->getLowerBound(), VoidPtrTy),
+                    "",
+                    topBB
+            );
+            ICmpInst *pastTable = new ICmpInst(
+                    ICmpInst::ICMP_UGE,
+                    castTarget,
+                    ConstantExpr::getBitCast(jt->getUpperBound(), VoidPtrTy),
+                    "",
+                    topBB
+            );
+
+            BinaryOperator *outside = BinaryOperator::createOr(belowTable, pastTable, "", topBB);
+
+            BasicBlock *failedCheckBB = CREATE_LLVM_OBJECT(BasicBlock, (
+                    "failed_ind_check",
+                    bottomBB->getParent(),
+                    bottomBB
+            ));
+            CREATE_LLVM_OBJECT(CallInst, (indirectFuncFail, castTarget, "", failedCheckBB));
+            CREATE_LLVM_OBJECT(BranchInst, (bottomBB, failedCheckBB));
+
+            CREATE_LLVM_OBJECT(BranchInst, (failedCheckBB, bottomBB, outside, topBB));
+        }
+
+        /*
+         * Redirects every address-taken use of f to its jump-table entry
+         * (__f, whose asm body is "jmp f"). Direct calls to f (f as operand
+         * 0, the callee, of a call/invoke) are left alone. Returns true if
+         * any use was redirected.
+         *
+         * Implemented as replace-all-uses followed by restoring direct-call
+         * callees. replaceAllUsesWith correctly updates constant users (f in
+         * a global initializer or a constant cast), which must not be
+         * patched with setOperand. It also avoids editing f's use list while
+         * iterating over it.
+         */
+        bool runOnFunction(Function &f, const JumpTableEntry &entry) {
+            Function *indirect = entry.getIndirectFunction();
+
+            IC_DMSG("Function: " << f.getName());
+
+            f.replaceAllUsesWith(indirect);
+
+            // Collect first, then patch: setOperand edits indirect's use list.
+            std::vector<Instruction*> directCalls;
+            Function::use_iterator iter, end;
+            for (iter = indirect->use_begin(), end = indirect->use_end(); iter != end; ++iter) {
+                User *user = *iter;
+                if ((isa<CallInst>(user) || isa<InvokeInst>(user)) &&
+                    user->getOperand(0) == indirect)
+                    directCalls.push_back(cast<Instruction>(user));
+            }
+
+            for (std::vector<Instruction*>::iterator ci = directCalls.begin(), ce = directCalls.end();
+                 ci != ce; ++ci)
+                (*ci)->setOperand(0, &f);
+
+            return !indirect->use_empty();
+        }
+
+        // Returns true if the callee is neither a known Function nor inline
+        // asm. LLVM only allows inline asm as a direct callee, so it cannot
+        // be cast to i8* and range-checked.
+        bool isIndirectCall(CallSite &cs) {
+            if (cs.getCalledFunction())
+                return false;
+            return !isa<InlineAsm>(cs.getCalledValue());
+        }
+
+    }; //end of struct IndirectCall
+
+    // Pass identity: the constructor hands &ID to ModulePass, so the
+    // address, not the value, identifies the pass.
+    char IndirectCall::ID = 0;
+
+    // Registers the pass under the name "indirect-call".
+    RegisterPass<IndirectCall> registerIndirectCallPass("indirect-call",
+                                                        "Indirect Call Pass");
+}
+
+namespace llvm {
+    // Creates a new indirect-call checking pass. The returned object is
+    // heap-allocated; ownership passes to the caller (presumably a
+    // PassManager, which deletes its passes -- confirm at call sites).
+    ModulePass *createIndirectCallChecksPass() {
+        IndirectCall *pass = new IndirectCall();
+        return pass;
+    }
+}
diff --git a/safecode/lib/IndirectCallChecks/Makefile b/safecode/lib/IndirectCallChecks/Makefile
new file mode 100755
index 0000000..4ce86d4
--- /dev/null
+++ b/safecode/lib/IndirectCallChecks/Makefile
@@ -0,0 +1,9 @@
+
+# Relative path from safecode/lib/IndirectCallChecks to the safecode root.
+LEVEL = ../../
+
+# Name of the library built from the sources in this directory.
+LIBRARYNAME=indirectcalls
+
+# NOTE(review): presumably asks the LLVM build rules to also produce a
+# relinked (single object) form of the library -- confirm in Makefile.rules.
+BUILD_RELINKED=1
+
+# Pull in the project's common build rules from $(LEVEL).
+include $(LEVEL)/Makefile.common
+