Bug 1262671 - IPC sentinel checking (r=froydnj)

This commit is contained in:
Bill McCloskey 2016-04-27 11:13:53 -07:00
parent 291c555f34
commit b7441af61a
3 changed files with 92 additions and 20 deletions

View File

@ -18,6 +18,10 @@
#include "nsDebug.h"
#if !defined(RELEASE_BUILD) || defined(DEBUG)
#define SENTINEL_CHECKING
#endif
//------------------------------------------------------------------------------
static_assert(MOZ_ALIGNOF(Pickle::memberAlignmentType) >= MOZ_ALIGNOF(uint32_t),
@ -442,6 +446,26 @@ bool Pickle::ReadData(PickleIterator* iter, const char** data, int* length) cons
return ReadBytes(iter, data, *length);
}
bool Pickle::ReadSentinel(PickleIterator* iter, uint32_t sentinel) const {
#ifdef SENTINEL_CHECKING
uint32_t found;
if (!ReadUInt32(iter, &found)) {
return false;
}
return found == sentinel;
#else
return true;
#endif
}
bool Pickle::WriteSentinel(uint32_t sentinel) {
#ifdef SENTINEL_CHECKING
return WriteUInt32(sentinel);
#else
return true;
#endif
}
char* Pickle::BeginWrite(uint32_t length, uint32_t alignment) {
DCHECK(alignment % 4 == 0) << "Must be at least 32-bit aligned!";

View File

@ -122,6 +122,8 @@ class Pickle {
// Use it for reading the object sizes.
MOZ_MUST_USE bool ReadLength(PickleIterator* iter, int* result) const;
MOZ_WARN_UNUSED_RESULT bool ReadSentinel(PickleIterator* iter, uint32_t sentinel) const;
void EndRead(PickleIterator& iter) const {
DCHECK(iter.iter_ == end_of_payload());
}
@ -186,6 +188,8 @@ class Pickle {
bool WriteBytes(const void* data, int data_len,
uint32_t alignment = sizeof(memberAlignmentType));
bool WriteSentinel(uint32_t sentinel);
// Same as WriteData, but allows the caller to write directly into the
// Pickle. This saves a copy in cases where the data is not already
// available in a buffer. The caller should take care to not write more

View File

@ -51,6 +51,11 @@ lowered form of |tu|'''
## Helper code
##
def hashfunc(value):
h = hash(value) % 2**32
if h < 0: h += 2**32
return h
_NULL_ACTOR_ID = ExprLiteral.ZERO
_FREED_ACTOR_ID = ExprLiteral.ONE
@ -696,6 +701,7 @@ class StructDecl(ipdl.ast.StructDecl, HasFQName):
class _StructField(_CompoundTypeComponent):
def __init__(self, ipdltype, name, sd, side=None):
self.basename = name
fname = name
special = _hasVisibleActor(ipdltype)
if special:
@ -4403,12 +4409,13 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
ExprLiteral.ZERO),
cond=ExprBinary(ivar, '<', lenvar),
update=ExprPrefixUnop(ivar, '++'))
forwrite.addstmt(StmtExpr(
self.write(eltipdltype, ExprIndex(var, ivar), msgvar)))
forwrite.addstmt(
self.checkedWrite(eltipdltype, ExprIndex(var, ivar), msgvar,
sentinelKey=arraytype.name()))
write.addstmts([
StmtDecl(Decl(Type.UINT32, lenvar.name),
init=_callCxxArrayLength(var)),
StmtExpr(self.write(None, lenvar, msgvar)),
self.checkedWrite(None, lenvar, msgvar, sentinelKey=('length', arraytype.name())),
Whitespace.NL,
forwrite
])
@ -4422,13 +4429,15 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
forread.addstmt(
self.checkedRead(eltipdltype, ExprAddrOf(ExprIndex(favar, ivar)),
msgvar, itervar, errfnRead,
'\'' + eltipdltype.name() + '[i]\''))
'\'' + eltipdltype.name() + '[i]\'',
sentinelKey=arraytype.name()))
read.addstmts([
StmtDecl(Decl(_cxxArrayType(_cxxBareType(arraytype.basetype, self.side)), favar.name)),
StmtDecl(Decl(Type.UINT32, lenvar.name)),
self.checkedRead(None, ExprAddrOf(lenvar),
msgvar, itervar, errfnArrayLength,
[ arraytype.name() ]),
[ arraytype.name() ],
sentinelKey=('length', arraytype.name())),
Whitespace.NL,
StmtExpr(_callCxxArraySetLength(favar, lenvar)),
forread,
@ -4560,10 +4569,10 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
for f in sd.fields:
desc = '\'' + f.getMethod().name + '\' (' + f.ipdltype.name() + \
') member of \'' + intype.name + '\''
writefield = StmtExpr(self.write(f.ipdltype, get('.', f), msgvar))
writefield = self.checkedWrite(f.ipdltype, get('.', f), msgvar, sentinelKey=f.basename)
readfield = self.checkedRead(f.ipdltype,
ExprAddrOf(get('->', f)),
msgvar, itervar, errfnRead, desc)
msgvar, itervar, errfnRead, desc, sentinelKey=f.basename)
if f.special and f.side != self.side:
writefield = Whitespace(
"// skipping actor field that's meaningless on this side\n", indent=1)
@ -4596,13 +4605,14 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
ct = c.ipdltype
isactor = (ct.isIPDL() and ct.isActor())
caselabel = CaseLabel(typename +'::'+ c.enum())
origenum = c.enum()
writecase = StmtBlock()
if c.special and c.side != self.side:
writecase.addstmt(_fatalError('wrong side!'))
else:
wexpr = ExprCall(ExprSelect(var, '.', c.getTypeName()))
writecase.addstmt(StmtExpr(self.write(ct, wexpr, msgvar)))
writecase.addstmt(self.checkedWrite(ct, wexpr, msgvar, sentinelKey=c.enum()))
writecase.addstmt(StmtReturn())
writeswitch.addcase(caselabel, writecase)
@ -4622,11 +4632,12 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
readcase.addstmts([
StmtDecl(Decl(ct, tmpvar.name), init=c.defaultValue()),
StmtExpr(ExprAssn(ExprDeref(var), tmpvar)),
StmtReturn(self.read(
self.checkedRead(
c.ipdltype,
ExprAddrOf(ExprCall(ExprSelect(var, '->',
c.getTypeName()))),
msgvar, itervar))
msgvar, itervar, errfnRead, 'Union type', sentinelKey=origenum),
StmtReturn(ExprLiteral.TRUE)
])
readswitch.addcase(caselabel, readcase)
@ -4640,8 +4651,9 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
write = MethodDefn(self.writeMethodDecl(intype, var))
write.addstmts([
uniontdef,
StmtExpr(self.write(
None, ExprCall(Type.INT, args=[ ud.callType(var) ]), msgvar)),
self.checkedWrite(
None, ExprCall(Type.INT, args=[ ud.callType(var) ]), msgvar,
sentinelKey=uniontype.name()),
Whitespace.NL,
writeswitch
])
@ -4652,7 +4664,8 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
StmtDecl(Decl(Type.INT, typevar.name)),
self.checkedRead(
None, ExprAddrOf(typevar), msgvar, itervar, errfnUnionType,
[ uniontype.name() ]),
[ uniontype.name() ],
sentinelKey=uniontype.name()),
Whitespace.NL,
readswitch,
])
@ -4698,6 +4711,20 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
return self.maybeAddNullabilityArg(
ipdltype, ExprCall(read, args=[ expr, from_, iterexpr ]))
def checkedWrite(self, ipdltype, expr, msgvar, sentinelKey, this=None):
assert sentinelKey
write = StmtExpr(self.write(ipdltype, expr, msgvar, this))
sentinel = StmtExpr(ExprCall(ExprSelect(msgvar, '->', 'WriteSentinel'),
args=[ ExprLiteral.Int(hashfunc(sentinelKey)) ]))
block = Block()
block.addstmts([
write,
Whitespace('// Sentinel = ' + repr(sentinelKey) + '\n', indent=1),
sentinel ])
return block
def visitMessageDecl(self, md):
isctor = md.decl.type.isCtor()
@ -5075,7 +5102,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
init=ExprCall(ExprVar(md.pqMsgCtorFunc()),
args=[ routingId ])) ]
+ [ Whitespace.NL ]
+ [ StmtExpr(self.write(p.ipdltype, p.var(), msgvar, this))
+ [ self.checkedWrite(p.ipdltype, p.var(), msgvar, sentinelKey=p.name, this=this)
for p in md.params ]
+ [ Whitespace.NL ]
+ self.setMessageFlags(md, msgvar, reply=0))
@ -5094,7 +5121,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
[ StmtExpr(ExprAssn(
replyvar, ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[ routingId ]))),
Whitespace.NL ]
+ [ StmtExpr(self.write(r.ipdltype, r.var(), replyvar))
+ [ self.checkedWrite(r.ipdltype, r.var(), replyvar, sentinelKey=r.name)
for r in md.returns ]
+ self.setMessageFlags(md, replyvar, reply=1)
+ [ self.logMessage(md, replyvar, 'Sending reply ') ])
@ -5150,7 +5177,8 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
decls = [ StmtDecl(Decl(handletype, handlevar.name)) ]
reads = [ self.checkedRead(None, ExprAddrOf(handlevar), msgexpr,
ExprAddrOf(self.itervar),
errfn, "'%s'" % handletype.name) ]
errfn, "'%s'" % handletype.name,
sentinelKey='actor') ]
start = 1
stmts.extend((
@ -5162,7 +5190,8 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
+ [ Whitespace.NL ]
+ reads + [ self.checkedRead(p.ipdltype, ExprAddrOf(p.var()),
msgexpr, ExprAddrOf(itervar),
errfn, "'%s'" % p.bareType(side).name)
errfn, "'%s'" % p.bareType(side).name,
sentinelKey=p.name)
for p in md.params[start:] ]
+ [ self.endRead(msgvar, itervar) ]))
@ -5185,7 +5214,8 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
+ [ self.checkedRead(r.ipdltype, r.var(),
ExprAddrOf(self.replyvar),
ExprAddrOf(self.itervar),
errfn, "'%s'" % r.bareType(side).name)
errfn, "'%s'" % r.bareType(side).name,
sentinelKey=r.name)
for r in md.returns ]
+ [ self.endRead(self.replyvar, itervar) ])
@ -5336,14 +5366,28 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor):
ifbad.addifstmts(_badTransition())
return [ ifbad ]
def checkedRead(self, ipdltype, expr, msgexpr, iterexpr, errfn, paramtype):
def checkedRead(self, ipdltype, expr, msgexpr, iterexpr, errfn, paramtype, sentinelKey, sentinel=True):
ifbad = StmtIf(ExprNot(self.read(ipdltype, expr, msgexpr, iterexpr)))
if isinstance(paramtype, list):
errorcall = errfn(*paramtype)
else:
errorcall = errfn('Error deserializing ' + paramtype)
ifbad.addifstmts(errorcall)
return ifbad
block = Block()
block.addstmt(ifbad)
if sentinel:
assert sentinelKey
block.addstmt(Whitespace('// Sentinel = ' + repr(sentinelKey) + '\n', indent=1))
read = ExprCall(ExprSelect(msgexpr, '->', 'ReadSentinel'),
args=[ iterexpr, ExprLiteral.Int(hashfunc(sentinelKey)) ])
ifsentinel = StmtIf(ExprNot(read))
ifsentinel.addifstmts(errorcall)
block.addstmt(ifsentinel)
return block
def endRead(self, msgexpr, iterexpr):
return StmtExpr(ExprCall(ExprSelect(msgexpr, '.', 'EndRead'),