prog, sys: add csum type, embed checksums for ipv4 packets

This change adds a `csum[kind, type]` type.
The only available kind right now is `ipv4`.
Using `csum[ipv4, int16be]` in `ipv4_header` makes syzkaller calculate
and embed correct checksums into ipv4 packets.
This commit is contained in:
Andrey Konovalov 2017-01-25 16:18:05 +01:00
parent c8d03a05f3
commit 63b16a5d5c
12 changed files with 412 additions and 40 deletions

View File

@ -150,6 +150,36 @@ func foreachArg(c *Call, f func(arg, base *Arg, parent *[]*Arg)) {
foreachArgArray(&c.Args, nil, f)
}
func foreachSubargOffset(arg *Arg, f func(arg *Arg, offset uintptr)) {
var rec func(*Arg, uintptr) uintptr
rec = func(arg1 *Arg, offset uintptr) uintptr {
switch arg1.Kind {
case ArgGroup:
var totalSize uintptr
for _, arg2 := range arg1.Inner {
size := rec(arg2, offset)
if arg2.Type.BitfieldLength() == 0 || arg2.Type.BitfieldLast() {
offset += size
totalSize += size
}
}
if totalSize > arg1.Size() {
panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1))
}
case ArgUnion:
size := rec(arg1.Option, offset)
offset += size
if size > arg1.Size() {
panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type))
}
default:
f(arg1, offset)
}
return arg1.Size()
}
rec(arg, 0)
}
func sanitizeCall(c *Call) {
switch c.Meta.CallName {
case "mmap":

157
prog/checksum.go Normal file
View File

@ -0,0 +1,157 @@
// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"fmt"
"unsafe"
"github.com/google/syzkaller/sys"
)
type IPChecksum struct {
acc uint32
}
func (csum *IPChecksum) Update(data []byte) {
length := len(data) - 1
for i := 0; i < length; i += 2 {
csum.acc += uint32(data[i]) << 8
csum.acc += uint32(data[i+1])
}
if len(data)%2 == 1 {
csum.acc += uint32(data[length]) << 8
}
for csum.acc > 0xffff {
csum.acc = (csum.acc >> 16) + (csum.acc & 0xffff)
}
}
func (csum *IPChecksum) Digest() uint16 {
return ^uint16(csum.acc)
}
func ipChecksum(data []byte) uint16 {
var csum IPChecksum
csum.Update(data)
return csum.Digest()
}
func bitmaskLen(bfLen uint64) uint64 {
return (1 << bfLen) - 1
}
func bitmaskLenOff(bfOff, bfLen uint64) uint64 {
return bitmaskLen(bfLen) << bfOff
}
func storeByBitmask8(addr *uint8, value uint8, bfOff uint64, bfLen uint64) {
if bfOff == 0 && bfLen == 0 {
*addr = value
} else {
newValue := *addr
newValue &= ^uint8(bitmaskLenOff(bfOff, bfLen))
newValue |= (value & uint8(bitmaskLen(bfLen))) << bfOff
*addr = newValue
}
}
func storeByBitmask16(addr *uint16, value uint16, bfOff uint64, bfLen uint64) {
if bfOff == 0 && bfLen == 0 {
*addr = value
} else {
newValue := *addr
newValue &= ^uint16(bitmaskLenOff(bfOff, bfLen))
newValue |= (value & uint16(bitmaskLen(bfLen))) << bfOff
*addr = newValue
}
}
func storeByBitmask32(addr *uint32, value uint32, bfOff uint64, bfLen uint64) {
if bfOff == 0 && bfLen == 0 {
*addr = value
} else {
newValue := *addr
newValue &= ^uint32(bitmaskLenOff(bfOff, bfLen))
newValue |= (value & uint32(bitmaskLen(bfLen))) << bfOff
*addr = newValue
}
}
func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) {
if bfOff == 0 && bfLen == 0 {
*addr = value
} else {
newValue := *addr
newValue &= ^uint64(bitmaskLenOff(bfOff, bfLen))
newValue |= (value & uint64(bitmaskLen(bfLen))) << bfOff
*addr = newValue
}
}
func encodeStruct(arg *Arg, pid int) []byte {
bytes := make([]byte, arg.Size())
foreachSubargOffset(arg, func(arg *Arg, offset uintptr) {
switch arg.Kind {
case ArgConst:
addr := unsafe.Pointer(&bytes[offset])
val := arg.Value(pid)
bfOff := uint64(arg.Type.BitfieldOffset())
bfLen := uint64(arg.Type.BitfieldLength())
switch arg.Size() {
case 1:
storeByBitmask8((*uint8)(addr), uint8(val), bfOff, bfLen)
case 2:
storeByBitmask16((*uint16)(addr), uint16(val), bfOff, bfLen)
case 4:
storeByBitmask32((*uint32)(addr), uint32(val), bfOff, bfLen)
case 8:
storeByBitmask64((*uint64)(addr), uint64(val), bfOff, bfLen)
default:
panic(fmt.Sprintf("bad arg size %v, arg: %+v\n", arg.Size(), arg))
}
case ArgData:
copy(bytes[offset:], arg.Data)
default:
panic(fmt.Sprintf("bad arg kind %v, arg: %+v, type: %+v", arg.Kind, arg, arg.Type))
}
})
return bytes
}
func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
var csumField *Arg
for _, field := range arg.Inner {
if _, ok := field.Type.(*sys.CsumType); ok {
csumField = field
break
}
}
if csumField == nil {
panic(fmt.Sprintf("failed to find csum field in %v", arg.Type.Name()))
}
if csumField.Value(pid) != 0 {
panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField))
}
bytes := encodeStruct(arg, pid)
csum := ipChecksum(bytes)
newCsumField := *csumField
newCsumField.Val = uintptr(csum)
return csumField, &newCsumField
}
func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
var m map[*Arg]*Arg
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
// syz_csum_ipv4 struct is used in tests
if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4" {
if m == nil {
m = make(map[*Arg]*Arg)
}
k, v := calcChecksumIPv4(arg, pid)
m[k] = v
}
})
return m
}

150
prog/checksum_test.go Normal file
View File

@ -0,0 +1,150 @@
// Copyright 2016 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"bytes"
"testing"
)
func TestChecksumIP(t *testing.T) {
tests := []struct {
data string
csum uint16
}{
{
"",
0xffff,
},
{
"\x00",
0xffff,
},
{
"\x00\x00",
0xffff,
},
{
"\x00\x00\xff\xff",
0x0000,
},
{
"\xfc",
0x03ff,
},
{
"\xfc\x12",
0x03ed,
},
{
"\xfc\x12\x3e",
0xc5ec,
},
{
"\xfc\x12\x3e\x00\xc5\xec",
0x0000,
},
{
"\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd",
0xe143,
},
{
"\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd",
0xe143,
},
}
for _, test := range tests {
csum := ipChecksum([]byte(test.data))
if csum != test.csum {
t.Fatalf("incorrect ip checksum, got: %x, want: %x, data: %+v", csum, test.csum, []byte(test.data))
}
}
}
func TestChecksumIPAcc(t *testing.T) {
rs, iters := initTest(t)
r := newRand(rs)
for i := 0; i < iters; i++ {
bytes := make([]byte, r.Intn(256))
for i := 0; i < len(bytes); i++ {
bytes[i] = byte(r.Intn(256))
}
step := int(r.randRange(1, 8)) * 2
var csumAcc IPChecksum
for i := 0; i < len(bytes)/step; i++ {
csumAcc.Update(bytes[i*step : (i+1)*step])
}
if len(bytes)%step != 0 {
csumAcc.Update(bytes[len(bytes)-(len(bytes)%step) : len(bytes)])
}
csum := ipChecksum(bytes)
if csum != csumAcc.Digest() {
t.Fatalf("inconsistent ip checksum: %x vs %x, step: %v, data: %+v", csum, csumAcc.Digest(), step, bytes)
}
}
}
func TestChecksumEncode(t *testing.T) {
tests := []struct {
prog string
encoded string
}{
{
"syz_test$csum_encode(&(0x7f0000000000)={0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"})",
"\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd",
},
}
for i, test := range tests {
p, err := Deserialize([]byte(test.prog))
if err != nil {
t.Fatalf("failed to deserialize prog %v: %v", test.prog, err)
}
encoded := encodeStruct(p.Calls[0].Args[0].Res, 0)
if !bytes.Equal(encoded, []byte(test.encoded)) {
t.Fatalf("incorrect encoding for prog #%v, got: %+v, want: %+v", i, encoded, []byte(test.encoded))
}
}
}
func TestChecksumIPv4Calc(t *testing.T) {
tests := []struct {
prog string
csum uint16
}{
{
"syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}})",
0xe143,
},
}
for i, test := range tests {
p, err := Deserialize([]byte(test.prog))
if err != nil {
t.Fatalf("failed to deserialize prog %v: %v", test.prog, err)
}
_, csumField := calcChecksumIPv4(p.Calls[0].Args[0].Res, i%32)
// Can't compare serialized progs, since checksums are zerod on serialization.
csum := csumField.Value(i % 32)
if csum != uintptr(test.csum) {
t.Fatalf("failed to calc ipv4 checksum, got %x, want %x, prog: '%v'", csum, test.csum, test.prog)
}
}
}
func TestChecksumCalcRandom(t *testing.T) {
rs, iters := initTest(t)
for i := 0; i < iters; i++ {
p := Generate(rs, 10, nil)
for _, call := range p.Calls {
calcChecksumsCall(call, i%32)
}
for try := 0; try <= 10; try++ {
p.Mutate(rs, 10, nil, nil)
for _, call := range p.Calls {
calcChecksumsCall(call, i%32)
}
}
}
}

View File

@ -47,55 +47,32 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error {
args: make(map[*Arg]argInfo),
}
for _, c := range p.Calls {
// Calculate checksums.
csumMap := calcChecksumsCall(c, pid)
// Calculate arg offsets within structs.
// Generate copyin instructions that fill in data into pointer arguments.
foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) {
if arg.Kind == ArgPointer && arg.Res != nil {
var rec func(*Arg, uintptr) uintptr
rec = func(arg1 *Arg, offset uintptr) uintptr {
foreachSubargOffset(arg.Res, func(arg1 *Arg, offset uintptr) {
if len(arg1.Uses) != 0 {
w.args[arg1] = argInfo{Offset: offset}
}
if arg1.Kind == ArgGroup {
var totalSize uintptr
for _, arg2 := range arg1.Inner {
size := rec(arg2, offset)
if arg2.Type.BitfieldLength() == 0 || arg2.Type.BitfieldLast() {
offset += size
totalSize += size
}
}
if totalSize > arg1.Size() {
panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1))
}
return arg1.Size()
}
if arg1.Kind == ArgUnion {
size := rec(arg1.Option, offset)
offset += size
if size > arg1.Size() {
panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type))
}
return arg1.Size()
}
if !sys.IsPad(arg1.Type) &&
!(arg1.Kind == ArgData && len(arg1.Data) == 0) &&
arg1.Type.Dir() != sys.DirOut {
w.write(ExecInstrCopyin)
w.write(physicalAddr(arg) + offset)
w.writeArg(arg1, pid)
w.writeArg(arg1, pid, csumMap)
instrSeq++
}
return arg1.Size()
}
rec(arg.Res, 0)
})
}
})
// Generate the call itself.
w.write(uintptr(c.Meta.ID))
w.write(uintptr(len(c.Args)))
for _, arg := range c.Args {
w.writeArg(arg, pid)
w.writeArg(arg, pid, csumMap)
}
if len(c.Ret.Uses) != 0 {
w.args[c.Ret] = argInfo{Idx: instrSeq}
@ -173,9 +150,14 @@ func (w *execContext) write(v uintptr) {
w.buf = w.buf[8:]
}
func (w *execContext) writeArg(arg *Arg, pid int) {
func (w *execContext) writeArg(arg *Arg, pid int, csumMap map[*Arg]*Arg) {
switch arg.Kind {
case ArgConst:
if _, ok := arg.Type.(*sys.CsumType); ok {
if arg, ok = csumMap[arg]; !ok {
panic("csum arg is not in csum map")
}
}
w.write(ExecArgConst)
w.write(arg.Size())
w.write(arg.Value(pid))

View File

@ -197,6 +197,8 @@ func (p *Prog) Mutate(rs rand.Source, ncalls int, ct *ChoiceTable, corpus []*Pro
p.replaceArg(c, arg, arg1, calls)
case *sys.LenType:
panic("bad arg returned by mutationArgs: LenType")
case *sys.CsumType:
panic("bad arg returned by mutationArgs: CsumType")
case *sys.ConstType:
panic("bad arg returned by mutationArgs: ConstType")
default:
@ -397,7 +399,7 @@ func Minimize(p0 *Prog, callIndex0 int, pred func(*Prog, int) bool, crash bool)
}
}
p0 = p
case *sys.VmaType, *sys.LenType, *sys.ConstType:
case *sys.VmaType, *sys.LenType, *sys.CsumType, *sys.ConstType:
// TODO: try to remove offset from vma
return false
default:
@ -460,6 +462,9 @@ func mutationArgs(c *Call) (args, bases []*Arg) {
case *sys.LenType:
// Size is updated when the size-of arg change.
return
case *sys.CsumType:
// Checksum is updated when the checksummed data changes.
return
case *sys.ConstType:
// Well, this is const.
return

View File

@ -95,6 +95,8 @@ func (a *Arg) Value(pid int) uintptr {
return encodeValue(a.Val, typ.Size(), typ.BigEndian)
case *sys.LenType:
return encodeValue(a.Val, typ.Size(), typ.BigEndian)
case *sys.CsumType:
return encodeValue(a.Val, typ.Size(), typ.BigEndian)
case *sys.ProcType:
val := uintptr(typ.ValuesStart) + uintptr(typ.ValuesPerProc)*uintptr(pid) + a.Val
return encodeValue(val, typ.Size(), typ.BigEndian)
@ -105,7 +107,7 @@ func (a *Arg) Value(pid int) uintptr {
func (a *Arg) Size() uintptr {
switch typ := a.Type.(type) {
case *sys.IntType, *sys.LenType, *sys.FlagsType, *sys.ConstType,
*sys.ResourceType, *sys.VmaType, *sys.PtrType, *sys.ProcType:
*sys.ResourceType, *sys.VmaType, *sys.PtrType, *sys.ProcType, *sys.CsumType:
return typ.Size()
case *sys.BufferType:
return uintptr(len(a.Data))

View File

@ -765,8 +765,8 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call)
arg, calls1 := r.addr(s, a, inner.Size(), inner)
calls = append(calls, calls1...)
return arg, calls
case *sys.LenType:
// Return placeholder value of 0 while generating len args.
case *sys.LenType, *sys.CsumType:
// Return placeholder value of 0 while generating len and csum args.
return constArg(a, 0), nil
default:
panic("unknown argument type")

View File

@ -106,9 +106,13 @@ func (c *Call) validate(ctx *validCtx) error {
switch typ1.Kind {
case sys.BufferString:
if typ1.Length != 0 && len(arg.Data) != int(typ1.Length) {
return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, len(arg.Data), typ1.Length)
return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, typ.Name(), len(arg.Data), typ1.Length)
}
}
case *sys.CsumType:
if arg.Val != 0 {
return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, typ.Name(), arg.Val)
}
}
switch arg.Kind {
case ArgConst:

View File

@ -191,6 +191,17 @@ type ProcType struct {
ValuesPerProc uint64
}
type CsumKind int
const (
CsumIPv4 CsumKind = iota
)
type CsumType struct {
IntTypeCommon
Kind CsumKind
}
type VmaType struct {
TypeCommon
RangeBegin int64 // in pages
@ -573,7 +584,7 @@ func ForeachType(meta *Call, f func(Type)) {
rec(opt)
}
case *ResourceType, *BufferType, *VmaType, *LenType,
*FlagsType, *ConstType, *IntType, *ProcType:
*FlagsType, *ConstType, *IntType, *ProcType, *CsumType:
default:
panic("unknown type")
}

View File

@ -389,3 +389,22 @@ syz_bf_struct1 {
syz_test$bf0(a0 ptr[in, syz_bf_struct0])
syz_test$bf1(a0 ptr[in, syz_bf_struct1])
# Checksums
syz_test$csum_encode(a0 ptr[in, syz_csum_encode])
syz_test$csum_ipv4(a0 ptr[in, syz_csum_ipv4])
syz_csum_encode {
f0 int16
f1 int16be
f2 array[int32, 0:4]
f3 int8:4
f4 int8:4
f5 array[int8, 4]
} [packed]
syz_csum_ipv4 {
f0 csum[ipv4, int16]
f1 syz_csum_encode
} [packed]

View File

@ -149,14 +149,13 @@ ipv4_header {
version const[4, int8:4]
ecn int8:2
dscp int8:6
tot_len len[ipv4_packet, int16be]
identification int16be
total_len len[ipv4_packet, int16be]
id int16be
frag_off int16:13
flags int16:3
ttl int8
protocol flags[ipv4_types, int8]
# TODO: embed correct checksum
csum const[0, int16]
csum csum[ipv4, int16be]
src_ip ipv4_addr
dst_ip ipv4_addr
options ipv4_options

View File

@ -500,6 +500,19 @@ func generateArg(
byteSize = decodeByteSizeType(typ)
}
fmt.Fprintf(out, "&LenType{%v, Buf: \"%v\", ByteSize: %v}", intCommon(size, bigEndian, bitfieldLen), a[0], byteSize)
case "csum":
if want := 2; len(a) != want {
failf("wrong number of arguments for %v arg %v, want %v, got %v", typ, name, want, len(a))
}
size, bigEndian, bitfieldLen := decodeIntType(a[1])
var kind string
switch a[0] {
case "ipv4":
kind = "CsumIPv4"
default:
failf("unknown checksum kind '%v'", a[0])
}
fmt.Fprintf(out, "&CsumType{%v, Kind: %v}", intCommon(size, bigEndian, bitfieldLen), kind)
case "flags":
canBeArg = true
size := uint64(ptrSize)