diff --git a/pass/CPI.cpp b/pass/CPI.cpp index 9eff278..128a459 100644 --- a/pass/CPI.cpp +++ b/pass/CPI.cpp @@ -1,4 +1,7 @@ #include +#include +#include +#include #include "llvm/Pass.h" #include "llvm/Support/Debug.h" @@ -33,9 +36,18 @@ struct CPI : public ModulePass { // Add function references in libsafe_rt smAlloca = cast(M.getOrInsertFunction("smAlloca", intT)); - smStore = cast(M.getOrInsertFunction("smStore", voidT, intT, voidPT)); + /*smStore = cast(M.getOrInsertFunction("smStore", voidT, intT, voidPT)); smLoad = cast(M.getOrInsertFunction("smLoad", voidPT, intT)); - smDeref = cast(M.getOrInsertFunction("smDeref", voidPPT, intT)); + smDeref = cast(M.getOrInsertFunction("smDeref", voidPPT, intT));*/ + + // Find all sensitive structs + for (auto s : M.getIdentifiedStructTypes()) { + for (unsigned i = 0; i < s->getNumElements(); ++i) { + if (isFunctionPtr(s->getElementType(i))) { + sensitiveStructs[s].insert(i); + } + } + } // Loop through all functions for (auto &F: M.getFunctionList()) { @@ -49,9 +61,6 @@ struct CPI : public ModulePass { private: Function *smAlloca; - Function *smStore; - Function *smLoad; - Function *smDeref; /* Temporarily unused */ Value *smPool; Value *smSp; @@ -60,33 +69,18 @@ private: PointerType *voidPT; PointerType *voidPPT; + /* Unused */ + /*Function *smStore; + Function *smLoad; + Function *smDeref;*/ + + std::map > sensitiveStructs; + void runOnFunction(Function &F) { bool hasInject = false; for (auto &bb : F) { - auto v = getCPSPtrs(bb); - hasInject |= !v.empty(); - for (auto I : v) { - IRBuilder<> b(I); - - // Get index from smAlloca - auto idx = b.CreateCall(smAlloca, None, I->getName()); - - // Swap out all uses (store and load) - for (auto a : I->users()) { - StoreInst *s; - LoadInst *l; - if ((s = dyn_cast(a))) { - swapStore(idx, s); - } else if ((l = dyn_cast(a))) { - swapLoad(idx, l); - } else { - /* TODO: Dunno when this will happen, spit logs for now */ - DEBUG(dbgs() << *I << "\n"); - } - } - if (I->getNumUses() == 0) - I->eraseFromParent(); - } + hasInject |= swapFunctionPtrAlloca(bb); + hasInject |= swapStructAlloca(bb); } // Stack maintenance @@ -108,6 +102,7 @@ private: auto off = b.CreateGEP(voidPT, pool, idx); auto cast = b.CreatePointerCast(store->getValueOperand(), voidPT); b.CreateStore(cast, off); + DEBUG(dbgs() << "SWAP:" << *store << "\n"); store->eraseFromParent(); } @@ -117,23 +112,107 @@ private: auto off = b.CreateGEP(voidPT, pool, idx); auto raw = b.CreateLoad(off); auto cast = b.CreatePointerCast(raw, load->getType()); + DEBUG(dbgs() << "SWAP:" << *load << "\n"); BasicBlock::iterator ii(load); ReplaceInstWithValue(load->getParent()->getInstList(), ii, cast); } - /* TODO: Might need better CPS pointer detection - * Currently only covers function pointer alloca */ - std::vector getCPSPtrs(BasicBlock &bb) { + void swapPtr(Instruction *from, Instruction *to) { + // Swap out all uses (store and load) + for (auto a : from->users()) { + StoreInst *s; + LoadInst *l; + if ((s = dyn_cast(a))) { + swapStore(to, s); + } else if ((l = dyn_cast(a))) { + swapLoad(to, l); + } else { + /* TODO: Dunno when this will happen, spit logs for now */ + DEBUG(dbgs() << "Unknown:" << *from << "\n"); + } + } + DEBUG(dbgs() << "RM:" << *from << "\n"); + if (from->getNumUses() == 0) + from->eraseFromParent(); + } + + bool swapFunctionPtrAlloca(BasicBlock &bb) { + bool hasInject = false; + auto v = getFunctionPtrAlloca(bb); + hasInject |= !v.empty(); + for (auto I : v) { + IRBuilder<> b(I); + auto idx = b.CreateCall(smAlloca, None, I->getName()); + DEBUG(dbgs() << "ADD:" << *idx << "\n"); + swapPtr(I, idx); + } + return hasInject; + } + + bool swapStructAlloca(BasicBlock &bb) { + bool hasInject = false; + auto v = getSSAlloca(bb); + for (auto alloc : v) { + std::vector rmList; + for (auto user : alloc->users()) { + /* Find dereference to function pointer */ + GetElementPtrInst *gep; + if ((gep = dyn_cast(user))) { + int i = 0; + for (auto &a : gep->indices()) { + /* Only get second dereference */ + if (i == 1) { + ConstantInt *ci; + if ((ci = dyn_cast(a))) { + int idx = ci->getSExtValue(); + if (sensitiveStructs[alloc->getAllocatedType()].count(idx)) { + rmList.push_back(gep); + } + } + break; + } + ++i; + } + } + } + if (!rmList.empty()) { + IRBuilder<> b(alloc->getNextNode()); + auto idx = b.CreateCall(smAlloca); + DEBUG(dbgs() << "ADD:" << *idx << "\n"); + for (auto u: rmList) { + swapPtr(u, idx); + } + } + } + return hasInject; + } + + std::vector getFunctionPtrAlloca(BasicBlock &bb) { std::vector v; for (Instruction &I : bb) { if (isAllocaFunctionPtr(I)) { - DEBUG(dbgs() << "CPS: " << I << "\n"); + DEBUG(dbgs() << "SENS:" << I << "\n"); v.push_back(&I); } } return v; } + std::vector getSSAlloca(BasicBlock &bb) { + std::vector v; + AllocaInst *ai; + for (auto &I : bb) { + if ((ai = dyn_cast(&I))) { + auto i = sensitiveStructs.find(ai->getAllocatedType()); + if (i != sensitiveStructs.end()) { + DEBUG(dbgs() << "SENS:" << *ai << "\n"); + v.push_back(ai); + } + } + } + return v; + } + bool isAllocaFunctionPtr(Instruction &I) { AllocaInst *i; return (i = dyn_cast(&I)) && isFunctionPtr(i->getAllocatedType()); diff --git a/tests/test.c b/tests/test.c index 5e84eb5..020abbc 100644 --- a/tests/test.c +++ b/tests/test.c @@ -7,6 +7,11 @@ struct foo { void (*func)(); }; +struct bar { + void (*func)(); + int i; +}; + static void f1() { printf("F1\n"); } @@ -17,7 +22,8 @@ static void f2() { int main(int argc, char const *argv[]) { void (*fptr)(); - struct foo bar; + struct foo f; + struct bar b; /* Prevent segfault */ if (argc < 2) @@ -25,13 +31,13 @@ int main(int argc, char const *argv[]) { if (strcmp(argv[1], "1") == 0) { fptr = f1; - bar.func = f2; + f.func = f2; } else { fptr = f2; - bar.func = f1; + f.func = f1; } fptr(); - bar.func(); + f.func(); return 0; }