Fix the case when structs have multiple function pointer members

This commit is contained in:
topjohnwu 2018-12-14 12:11:00 -05:00
parent a619a9b3d7
commit e1dd408c4a
2 changed files with 56 additions and 49 deletions

View File

@ -41,7 +41,7 @@ struct CPI : public ModulePass {
for (auto s : M.getIdentifiedStructTypes()) {
for (unsigned i = 0; i < s->getNumElements(); ++i) {
if (isFunctionPtr(s->getElementType(i))) {
sensitiveStructs[s].insert(i);
ssMap[s].push_back(i);
}
}
}
@ -66,13 +66,14 @@ private:
PointerType *voidPT;
PointerType *voidPPT;
std::map<Type*, std::set<int> > sensitiveStructs;
// A map of StructType to the list of entries numbers that are function pointers
std::map<Type*, std::vector<int> > ssMap;
void runOnFunction(Function &F) {
bool hasInject = false;
for (auto &bb : F) {
hasInject |= swapFunctionPtrAlloca(bb);
hasInject |= swapStructAlloca(bb);
hasInject |= handleStructAlloca(bb);
}
// Stack maintenance
@ -140,32 +141,29 @@ private:
return !v.empty();
}
bool swapStructAlloca(BasicBlock &bb) {
bool handleStructAlloca(BasicBlock &bb) {
bool hasInject = false;
for (auto alloc : getSSAlloca(bb)) {
std::vector<GetElementPtrInst *> rmList;
for (auto user : alloc->users()) {
/* Find struct entry to function pointer (2nd GEP index, or 3rd operand) */
GetElementPtrInst *gep;
if ((gep = dyn_cast<GetElementPtrInst>(user))) {
ConstantInt *ci;
if (gep->getNumOperands() >= 3 && (ci = dyn_cast<ConstantInt>(gep->getOperand(2)))) {
int entry = ci->getSExtValue();
if (sensitiveStructs[alloc->getAllocatedType()].count(entry)) {
for (auto alloc: getSSAlloca(bb)) {
auto elist = ssMap.find(alloc->getAllocatedType());
if (elist != ssMap.end()) {
for (int fpentry : elist->second) {
std::vector<GetElementPtrInst *> rmList;
for (auto user : alloc->users()) {
auto *gep = dyn_cast<GetElementPtrInst>(user);
if (isSensitiveGEP(gep, fpentry))
rmList.push_back(gep);
}
if (!rmList.empty()) {
hasInject = true;
IRBuilder<> b(alloc->getNextNode());
auto idx = b.CreateCall(smAlloca, None, alloc->getName() + "." + std::to_string(fpentry));
DEBUG(dbgs() << "ADD:" << *idx << "\n");
for (auto u: rmList) {
swapPtr(u, idx);
}
}
}
}
if (!rmList.empty()) {
hasInject = true;
IRBuilder<> b(alloc->getNextNode());
auto idx = b.CreateCall(smAlloca);
DEBUG(dbgs() << "ADD:" << *idx << "\n");
for (auto u: rmList) {
swapPtr(u, idx);
}
}
}
return hasInject;
}
@ -186,17 +184,24 @@ private:
std::vector<AllocaInst *> v;
AllocaInst *ai;
for (auto &I : bb) {
if ((ai = dyn_cast<AllocaInst>(&I))) {
auto i = sensitiveStructs.find(ai->getAllocatedType());
if (i != sensitiveStructs.end()) {
DEBUG(dbgs() << "SENS:" << *ai << "\n");
v.push_back(ai);
}
if ((ai = dyn_cast<AllocaInst>(&I)) && ssMap.count(ai->getAllocatedType())) {
DEBUG(dbgs() << "SENS:" << *ai << "\n");
v.push_back(ai);
}
}
return v;
}
/* Check struct entry to function pointer (2nd GEP index, or 3rd operand) */
bool isSensitiveGEP(GetElementPtrInst *gep, int fpentry) {
if (gep == nullptr)
return false;
ConstantInt *ci;
return gep->getNumOperands() >= 3 &&
(ci = dyn_cast<ConstantInt>(gep->getOperand(2))) &&
ci->getSExtValue() == fpentry;
}
bool isFunctionPtr(Type *T) {
PointerType *t;
return (t = dyn_cast<PointerType>(T)) && t->getElementType()->isFunctionTy();

View File

@ -3,28 +3,32 @@
#include <string.h>
struct foo {
int i;
void (*func)();
int i;
void (*func)();
};
struct bar {
void (*f1)();
int i;
void (*f2)();
void (*f1)();
int i;
void (*f2)();
};
static void T() {
printf("true\n");
printf("true\n");
}
static void F() {
printf("false\n");
printf("false\n");
}
static void test_2(int i);
static void test_1(int i) {
void (*fptr)();
fptr = i ? T : F;
fptr();
printf("* test_2\n");
test_2(i);
}
static void test_2(int i) {
@ -35,12 +39,12 @@ static void test_2(int i) {
fptr = T;
f.func = T;
b.f1 = T;
b.f2 = T;
b.f2 = F;
} else {
fptr = F;
f.func = F;
b.f1 = F;
b.f2 = F;
b.f2 = T;
}
fptr();
f.func();
@ -49,18 +53,16 @@ static void test_2(int i) {
}
int main(int argc, char const *argv[]) {
/* Prevent segfault */
if (argc < 2)
return 1;
/* Prevent segfault */
if (argc < 2)
return 1;
int val = atoi(argv[1]);
int val = atoi(argv[1]);
for (int i = 0; i < 3; ++i) {
printf("* test_1\n");
test_1(val);
printf("* test_2\n");
test_2(val);
}
printf("* test_1\n");
test_1(val);
printf("* test_1\n");
test_1(val);
return 0;
return 0;
}