| //===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===// |
| // |
| // This file is distributed under the University of Illinois Open Source |
| // License. See LICENSE.TXT for details. |
| // |
| //===----------------------------------------------------------------------===// |
| /// |
| /// \file |
| /// \brief A pass that instruments code with fast checks for indirect calls and |
| /// hooks for a function to check violations. |
| /// |
| //===----------------------------------------------------------------------===// |
| |
| #define DEBUG_TYPE "cfi" |
| |
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/JumpInstrTableInfo.h"
#include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
#include "llvm/CodeGen/JumpInstrTables.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
| |
| using namespace llvm; |
| |
| STATISTIC(NumCFIIndirectCalls, |
| "Number of indirect call sites rewritten by the CFI pass"); |
| |
| char ForwardControlFlowIntegrity::ID = 0; |
| INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi", |
| "Control-Flow Integrity", true, true) |
| INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo); |
| INITIALIZE_PASS_DEPENDENCY(JumpInstrTables); |
| INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi", |
| "Control-Flow Integrity", true, true) |
| |
| ModulePass *llvm::createForwardControlFlowIntegrityPass() { |
| return new ForwardControlFlowIntegrity(); |
| } |
| |
| ModulePass *llvm::createForwardControlFlowIntegrityPass( |
| JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, |
| StringRef CFIFuncName) { |
| return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing, |
| CFIFuncName); |
| } |
| |
| // Checks to see if a given CallSite is making an indirect call, including |
| // cases where the indirect call is made through a bitcast. |
| static bool isIndirectCall(CallSite &CS) { |
| if (CS.getCalledFunction()) |
| return false; |
| |
| // Check the value to see if it is merely a bitcast of a function. In |
| // this case, it will translate to a direct function call in the resulting |
| // assembly, so we won't treat it as an indirect call here. |
| const Value *V = CS.getCalledValue(); |
| if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { |
| return !(CE->isCast() && isa<Function>(CE->getOperand(0))); |
| } |
| |
| // Otherwise, since we know it's a call, it must be an indirect call |
| return true; |
| } |
| |
// Name of the built-in CFI failure function. When CFIFuncName is empty,
// addWarningFunction provides a function with this name (its body just
// returns) and insertWarning emits calls to it when a check fails.
static const char cfi_failure_func_name[] =
    "__llvm_cfi_pointer_warning";
| |
| ForwardControlFlowIntegrity::ForwardControlFlowIntegrity() |
| : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single), |
| CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") { |
| initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); |
| } |
| |
| ForwardControlFlowIntegrity::ForwardControlFlowIntegrity( |
| JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, |
| std::string CFIFuncName) |
| : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType), |
| CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) { |
| initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); |
| } |
| |
| ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {} |
| |
| void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const { |
| AU.addRequired<JumpInstrTableInfo>(); |
| AU.addRequired<JumpInstrTables>(); |
| } |
| |
| void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) { |
| // To get the indirect calls, we iterate over all functions and iterate over |
| // the list of basic blocks in each. We extract a total list of indirect calls |
| // before modifying any of them, since our modifications will modify the list |
| // of basic blocks. |
| for (Function &F : M) { |
| for (BasicBlock &BB : F) { |
| for (Instruction &I : BB) { |
| CallSite CS(&I); |
| if (!(CS && isIndirectCall(CS))) |
| continue; |
| |
| Value *CalledValue = CS.getCalledValue(); |
| |
| // Don't rewrite this instruction if the indirect call is actually just |
| // inline assembly, since our transformation will generate an invalid |
| // module in that case. |
| if (isa<InlineAsm>(CalledValue)) |
| continue; |
| |
| IndirectCalls.push_back(&I); |
| } |
| } |
| } |
| } |
| |
| void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M, |
| CFITables &CFIT) { |
| Type *Int64Ty = Type::getInt64Ty(M.getContext()); |
| for (Instruction *I : IndirectCalls) { |
| CallSite CS(I); |
| Value *CalledValue = CS.getCalledValue(); |
| |
| // Get the function type for this call and look it up in the tables. |
| Type *VTy = CalledValue->getType(); |
| PointerType *PTy = dyn_cast<PointerType>(VTy); |
| Type *EltTy = PTy->getElementType(); |
| FunctionType *FunTy = dyn_cast<FunctionType>(EltTy); |
| FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy); |
| ++NumCFIIndirectCalls; |
| Constant *JumpTableStart = nullptr; |
| Constant *JumpTableMask = nullptr; |
| Constant *JumpTableSize = nullptr; |
| |
| // Some call sites have function types that don't correspond to any |
| // address-taken function in the module. This happens when function pointers |
| // are passed in from external code. |
| auto it = CFIT.find(TransformedTy); |
| if (it == CFIT.end()) { |
| // In this case, make sure that the function pointer will change by |
| // setting the mask and the start to be 0 so that the transformed |
| // function is 0. |
| JumpTableStart = ConstantInt::get(Int64Ty, 0); |
| JumpTableMask = ConstantInt::get(Int64Ty, 0); |
| JumpTableSize = ConstantInt::get(Int64Ty, 0); |
| } else { |
| JumpTableStart = it->second.StartValue; |
| JumpTableMask = it->second.MaskValue; |
| JumpTableSize = it->second.Size; |
| } |
| |
| rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask, |
| JumpTableSize); |
| } |
| |
| return; |
| } |
| |
| bool ForwardControlFlowIntegrity::runOnModule(Module &M) { |
| JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>(); |
| Type *Int64Ty = Type::getInt64Ty(M.getContext()); |
| Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); |
| |
| // JumpInstrTableInfo stores information about the alignment of each entry. |
| // The alignment returned by JumpInstrTableInfo is alignment in bytes, not |
| // in the exponent. |
| ByteAlignment = JITI->entryByteAlignment(); |
| LogByteAlignment = llvm::Log2_64(ByteAlignment); |
| |
| // Set up tables for control-flow integrity based on information about the |
| // jump-instruction tables. |
| CFITables CFIT; |
| for (const auto &KV : JITI->getTables()) { |
| uint64_t Size = static_cast<uint64_t>(KV.second.size()); |
| uint64_t TableSize = NextPowerOf2(Size); |
| |
| int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment; |
| Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue); |
| Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size); |
| |
| // The base of the table is defined to be the first jumptable function in |
| // the table. |
| Function *First = KV.second.begin()->second; |
| Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy); |
| CFIT[KV.first].StartValue = JumpTableStartValue; |
| CFIT[KV.first].MaskValue = JumpTableMaskValue; |
| CFIT[KV.first].Size = JumpTableSize; |
| } |
| |
| if (CFIT.empty()) |
| return false; |
| |
| getIndirectCalls(M); |
| |
| if (!CFIEnforcing) { |
| addWarningFunction(M); |
| } |
| |
| // Update the instructions with the check and the indirect jump through our |
| // table. |
| updateIndirectCalls(M, CFIT); |
| |
| return true; |
| } |
| |
| void ForwardControlFlowIntegrity::addWarningFunction(Module &M) { |
| PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext()); |
| |
| // Get the type of the Warning Function: void (i8*, i8*), |
| // where the first argument is the name of the function in which the violation |
| // occurs, and the second is the function pointer that violates CFI. |
| SmallVector<Type *, 2> WarningFunArgs; |
| WarningFunArgs.push_back(CharPtrTy); |
| WarningFunArgs.push_back(CharPtrTy); |
| FunctionType *WarningFunTy = |
| FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false); |
| |
| if (!CFIFuncName.empty()) { |
| Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy); |
| if (!FailureFun) |
| report_fatal_error("Could not get or insert the function specified by" |
| " -cfi-func-name"); |
| } else { |
| // The default warning function swallows the warning and lets the call |
| // continue, since there's no generic way for it to print out this |
| // information. |
| Function *WarningFun = M.getFunction(cfi_failure_func_name); |
| if (!WarningFun) { |
| WarningFun = |
| Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage, |
| cfi_failure_func_name, &M); |
| } |
| |
| BasicBlock *Entry = |
| BasicBlock::Create(M.getContext(), "entry", WarningFun, 0); |
| ReturnInst::Create(M.getContext(), Entry); |
| } |
| } |
| |
| void ForwardControlFlowIntegrity::rewriteFunctionPointer( |
| Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart, |
| Constant *JumpTableMask, Constant *JumpTableSize) { |
| IRBuilder<> TempBuilder(I); |
| |
| Type *OrigFunType = FunPtr->getType(); |
| |
| BasicBlock *CurBB = cast<BasicBlock>(I->getParent()); |
| Function *CurF = cast<Function>(CurBB->getParent()); |
| Type *Int64Ty = Type::getInt64Ty(M.getContext()); |
| |
| Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty); |
| Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty); |
| |
| Value *NewFunPtr = nullptr; |
| Value *Check = nullptr; |
| switch (CFIType) { |
| case CFIntegrity::Sub: { |
| // This is the subtract, mask, and add version. |
| // Subtract from the base. |
| Value *Sub = TempBuilder.CreateSub(TI, TStartInt); |
| |
| // Mask the difference to force this to be a table offset. |
| Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask); |
| |
| // Add it back to the base. |
| Value *Result = TempBuilder.CreateAdd(And, TStartInt); |
| |
| // Convert it back into a function pointer that we can call. |
| NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); |
| break; |
| } |
| case CFIntegrity::Ror: { |
| // This is the subtract and rotate version. |
| // Rotate right by the alignment value. The optimizer should recognize |
| // this sequence as a rotation. |
| |
| // This cast is safe, since unsigned is always a subset of uint64_t. |
| uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment); |
| Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64); |
| Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64); |
| |
| // Subtract from the base. |
| Value *Sub = TempBuilder.CreateSub(TI, TStartInt); |
| |
| // Create the equivalent of a rotate-right instruction. |
| Value *Shr = TempBuilder.CreateLShr(Sub, RightShift); |
| Value *Shl = TempBuilder.CreateShl(Sub, LeftShift); |
| Value *Or = TempBuilder.CreateOr(Shr, Shl); |
| |
| // Perform unsigned comparison to check for inclusion in the table. |
| Check = TempBuilder.CreateICmpULT(Or, JumpTableSize); |
| NewFunPtr = FunPtr; |
| break; |
| } |
| case CFIntegrity::Add: { |
| // This is the mask and add version. |
| // Mask the function pointer to turn it into an offset into the table. |
| Value *And = TempBuilder.CreateAnd(TI, JumpTableMask); |
| |
| // Then or this offset to the base and get the pointer value. |
| Value *Result = TempBuilder.CreateAdd(And, TStartInt); |
| |
| // Convert it back into a function pointer that we can call. |
| NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); |
| break; |
| } |
| } |
| |
| if (!CFIEnforcing) { |
| // If a check hasn't been added (in the rotation version), then check to see |
| // if it's the same as the original function. This check determines whether |
| // or not we call the CFI failure function. |
| if (!Check) |
| Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr); |
| BasicBlock *InvalidPtrBlock = |
| BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0); |
| BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I); |
| |
| // Remove the unconditional branch that connects the two blocks. |
| TerminatorInst *TermInst = CurBB->getTerminator(); |
| TermInst->eraseFromParent(); |
| |
| // Add a conditional branch that depends on the Check above. |
| BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB); |
| |
| // Call the warning function for this pointer, then continue. |
| Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock); |
| insertWarning(M, InvalidPtrBlock, BI, FunPtr); |
| } else { |
| // Modify the instruction to call this value. |
| CallSite CS(I); |
| CS.setCalledFunction(NewFunPtr); |
| } |
| } |
| |
| void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block, |
| Instruction *I, Value *FunPtr) { |
| Function *ParentFun = cast<Function>(Block->getParent()); |
| |
| // Get the function to call right before the instruction. |
| Function *WarningFun = nullptr; |
| if (CFIFuncName.empty()) { |
| WarningFun = M.getFunction(cfi_failure_func_name); |
| } else { |
| WarningFun = M.getFunction(CFIFuncName); |
| } |
| |
| assert(WarningFun && "Could not find the CFI failure function"); |
| |
| Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); |
| |
| IRBuilder<> WarningInserter(I); |
| // Create a mergeable GlobalVariable containing the name of the function. |
| Value *ParentNameGV = |
| WarningInserter.CreateGlobalString(ParentFun->getName()); |
| Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy); |
| Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy); |
| WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr); |
| } |