From 63b16a5d5cfd3b41f596daccd56d32b2548ec119 Mon Sep 17 00:00:00 2001 From: Andrey Konovalov Date: Wed, 25 Jan 2017 16:18:05 +0100 Subject: [PATCH] prog, sys: add csum type, embed checksums for ipv4 packets This change adds a `csum[kind, type]` type. The only available kind right now is `ipv4`. Using `csum[ipv4, int16be]` in `ipv4_header` makes syzkaller calculate and embed correct checksums into ipv4 packets. --- prog/analysis.go | 30 ++++++++ prog/checksum.go | 157 ++++++++++++++++++++++++++++++++++++++++++ prog/checksum_test.go | 150 ++++++++++++++++++++++++++++++++++++++++ prog/encodingexec.go | 42 ++++------- prog/mutation.go | 7 +- prog/prog.go | 4 +- prog/rand.go | 4 +- prog/validation.go | 6 +- sys/decl.go | 13 +++- sys/test.txt | 19 +++++ sys/vnet.txt | 7 +- sysgen/sysgen.go | 13 ++++ 12 files changed, 412 insertions(+), 40 deletions(-) create mode 100644 prog/checksum.go create mode 100644 prog/checksum_test.go diff --git a/prog/analysis.go b/prog/analysis.go index d008f9c..a267f7d 100644 --- a/prog/analysis.go +++ b/prog/analysis.go @@ -150,6 +150,36 @@ func foreachArg(c *Call, f func(arg, base *Arg, parent *[]*Arg)) { foreachArgArray(&c.Args, nil, f) } +func foreachSubargOffset(arg *Arg, f func(arg *Arg, offset uintptr)) { + var rec func(*Arg, uintptr) uintptr + rec = func(arg1 *Arg, offset uintptr) uintptr { + switch arg1.Kind { + case ArgGroup: + var totalSize uintptr + for _, arg2 := range arg1.Inner { + size := rec(arg2, offset) + if arg2.Type.BitfieldLength() == 0 || arg2.Type.BitfieldLast() { + offset += size + totalSize += size + } + } + if totalSize > arg1.Size() { + panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1)) + } + case ArgUnion: + size := rec(arg1.Option, offset) + offset += size + if size > arg1.Size() { + panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type)) + } + default: + f(arg1, offset) + } + return arg1.Size() + } 
+ rec(arg, 0) +} + func sanitizeCall(c *Call) { switch c.Meta.CallName { case "mmap": diff --git a/prog/checksum.go b/prog/checksum.go new file mode 100644 index 0000000..3806c59 --- /dev/null +++ b/prog/checksum.go @@ -0,0 +1,157 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package prog + +import ( + "fmt" + "unsafe" + + "github.com/google/syzkaller/sys" +) + +type IPChecksum struct { + acc uint32 +} + +func (csum *IPChecksum) Update(data []byte) { + length := len(data) - 1 + for i := 0; i < length; i += 2 { + csum.acc += uint32(data[i]) << 8 + csum.acc += uint32(data[i+1]) + } + if len(data)%2 == 1 { + csum.acc += uint32(data[length]) << 8 + } + for csum.acc > 0xffff { + csum.acc = (csum.acc >> 16) + (csum.acc & 0xffff) + } +} + +func (csum *IPChecksum) Digest() uint16 { + return ^uint16(csum.acc) +} + +func ipChecksum(data []byte) uint16 { + var csum IPChecksum + csum.Update(data) + return csum.Digest() +} + +func bitmaskLen(bfLen uint64) uint64 { + return (1 << bfLen) - 1 +} + +func bitmaskLenOff(bfOff, bfLen uint64) uint64 { + return bitmaskLen(bfLen) << bfOff +} + +func storeByBitmask8(addr *uint8, value uint8, bfOff uint64, bfLen uint64) { + if bfOff == 0 && bfLen == 0 { + *addr = value + } else { + newValue := *addr + newValue &= ^uint8(bitmaskLenOff(bfOff, bfLen)) + newValue |= (value & uint8(bitmaskLen(bfLen))) << bfOff + *addr = newValue + } +} + +func storeByBitmask16(addr *uint16, value uint16, bfOff uint64, bfLen uint64) { + if bfOff == 0 && bfLen == 0 { + *addr = value + } else { + newValue := *addr + newValue &= ^uint16(bitmaskLenOff(bfOff, bfLen)) + newValue |= (value & uint16(bitmaskLen(bfLen))) << bfOff + *addr = newValue + } +} + +func storeByBitmask32(addr *uint32, value uint32, bfOff uint64, bfLen uint64) { + if bfOff == 0 && bfLen == 0 { + *addr = value + } else { + newValue := *addr + newValue &= 
^uint32(bitmaskLenOff(bfOff, bfLen)) + newValue |= (value & uint32(bitmaskLen(bfLen))) << bfOff + *addr = newValue + } +} + +func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) { + if bfOff == 0 && bfLen == 0 { + *addr = value + } else { + newValue := *addr + newValue &= ^uint64(bitmaskLenOff(bfOff, bfLen)) + newValue |= (value & uint64(bitmaskLen(bfLen))) << bfOff + *addr = newValue + } +} + +func encodeStruct(arg *Arg, pid int) []byte { + bytes := make([]byte, arg.Size()) + foreachSubargOffset(arg, func(arg *Arg, offset uintptr) { + switch arg.Kind { + case ArgConst: + addr := unsafe.Pointer(&bytes[offset]) + val := arg.Value(pid) + bfOff := uint64(arg.Type.BitfieldOffset()) + bfLen := uint64(arg.Type.BitfieldLength()) + switch arg.Size() { + case 1: + storeByBitmask8((*uint8)(addr), uint8(val), bfOff, bfLen) + case 2: + storeByBitmask16((*uint16)(addr), uint16(val), bfOff, bfLen) + case 4: + storeByBitmask32((*uint32)(addr), uint32(val), bfOff, bfLen) + case 8: + storeByBitmask64((*uint64)(addr), uint64(val), bfOff, bfLen) + default: + panic(fmt.Sprintf("bad arg size %v, arg: %+v\n", arg.Size(), arg)) + } + case ArgData: + copy(bytes[offset:], arg.Data) + default: + panic(fmt.Sprintf("bad arg kind %v, arg: %+v, type: %+v", arg.Kind, arg, arg.Type)) + } + }) + return bytes +} + +func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) { + var csumField *Arg + for _, field := range arg.Inner { + if _, ok := field.Type.(*sys.CsumType); ok { + csumField = field + break + } + } + if csumField == nil { + panic(fmt.Sprintf("failed to find csum field in %v", arg.Type.Name())) + } + if csumField.Value(pid) != 0 { + panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField)) + } + bytes := encodeStruct(arg, pid) + csum := ipChecksum(bytes) + newCsumField := *csumField + newCsumField.Val = uintptr(csum) + return csumField, &newCsumField +} + +func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { + var m 
map[*Arg]*Arg + foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) { + // syz_csum_ipv4 struct is used in tests + if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4" { + if m == nil { + m = make(map[*Arg]*Arg) + } + k, v := calcChecksumIPv4(arg, pid) + m[k] = v + } + }) + return m +} diff --git a/prog/checksum_test.go b/prog/checksum_test.go new file mode 100644 index 0000000..bade7f7 --- /dev/null +++ b/prog/checksum_test.go @@ -0,0 +1,150 @@ +// Copyright 2016 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package prog + +import ( + "bytes" + "testing" +) + +func TestChecksumIP(t *testing.T) { + tests := []struct { + data string + csum uint16 + }{ + { + "", + 0xffff, + }, + { + "\x00", + 0xffff, + }, + { + "\x00\x00", + 0xffff, + }, + { + "\x00\x00\xff\xff", + 0x0000, + }, + { + "\xfc", + 0x03ff, + }, + { + "\xfc\x12", + 0x03ed, + }, + { + "\xfc\x12\x3e", + 0xc5ec, + }, + { + "\xfc\x12\x3e\x00\xc5\xec", + 0x0000, + }, + { + "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", + 0xe143, + }, + { + "\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", + 0xe143, + }, + } + + for _, test := range tests { + csum := ipChecksum([]byte(test.data)) + if csum != test.csum { + t.Fatalf("incorrect ip checksum, got: %x, want: %x, data: %+v", csum, test.csum, []byte(test.data)) + } + } +} + +func TestChecksumIPAcc(t *testing.T) { + rs, iters := initTest(t) + r := newRand(rs) + + for i := 0; i < iters; i++ { + bytes := make([]byte, r.Intn(256)) + for i := 0; i < len(bytes); i++ { + bytes[i] = byte(r.Intn(256)) + } + step := int(r.randRange(1, 8)) * 2 + var csumAcc IPChecksum + for i := 0; i < len(bytes)/step; i++ { + csumAcc.Update(bytes[i*step : (i+1)*step]) + } + if len(bytes)%step != 0 { + csumAcc.Update(bytes[len(bytes)-(len(bytes)%step) : len(bytes)]) + } + csum := ipChecksum(bytes) + if csum != 
csumAcc.Digest() { + t.Fatalf("inconsistent ip checksum: %x vs %x, step: %v, data: %+v", csum, csumAcc.Digest(), step, bytes) + } + } +} + +func TestChecksumEncode(t *testing.T) { + tests := []struct { + prog string + encoded string + }{ + { + "syz_test$csum_encode(&(0x7f0000000000)={0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"})", + "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", + }, + } + for i, test := range tests { + p, err := Deserialize([]byte(test.prog)) + if err != nil { + t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) + } + encoded := encodeStruct(p.Calls[0].Args[0].Res, 0) + if !bytes.Equal(encoded, []byte(test.encoded)) { + t.Fatalf("incorrect encoding for prog #%v, got: %+v, want: %+v", i, encoded, []byte(test.encoded)) + } + } +} + +func TestChecksumIPv4Calc(t *testing.T) { + tests := []struct { + prog string + csum uint16 + }{ + { + "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}})", + 0xe143, + }, + } + for i, test := range tests { + p, err := Deserialize([]byte(test.prog)) + if err != nil { + t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) + } + _, csumField := calcChecksumIPv4(p.Calls[0].Args[0].Res, i%32) + // Can't compare serialized progs, since checksums are zeroed on serialization. 
+ csum := csumField.Value(i % 32) + if csum != uintptr(test.csum) { + t.Fatalf("failed to calc ipv4 checksum, got %x, want %x, prog: '%v'", csum, test.csum, test.prog) + } + } +} + +func TestChecksumCalcRandom(t *testing.T) { + rs, iters := initTest(t) + for i := 0; i < iters; i++ { + p := Generate(rs, 10, nil) + for _, call := range p.Calls { + calcChecksumsCall(call, i%32) + } + for try := 0; try <= 10; try++ { + p.Mutate(rs, 10, nil, nil) + for _, call := range p.Calls { + calcChecksumsCall(call, i%32) + } + } + } +} diff --git a/prog/encodingexec.go b/prog/encodingexec.go index 304440d..9a9cc4b 100644 --- a/prog/encodingexec.go +++ b/prog/encodingexec.go @@ -47,55 +47,32 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error { args: make(map[*Arg]argInfo), } for _, c := range p.Calls { + // Calculate checksums. + csumMap := calcChecksumsCall(c, pid) // Calculate arg offsets within structs. // Generate copyin instructions that fill in data into pointer arguments. foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) { if arg.Kind == ArgPointer && arg.Res != nil { - var rec func(*Arg, uintptr) uintptr - rec = func(arg1 *Arg, offset uintptr) uintptr { + foreachSubargOffset(arg.Res, func(arg1 *Arg, offset uintptr) { if len(arg1.Uses) != 0 { w.args[arg1] = argInfo{Offset: offset} } - if arg1.Kind == ArgGroup { - var totalSize uintptr - for _, arg2 := range arg1.Inner { - size := rec(arg2, offset) - if arg2.Type.BitfieldLength() == 0 || arg2.Type.BitfieldLast() { - offset += size - totalSize += size - } - } - if totalSize > arg1.Size() { - panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1)) - } - return arg1.Size() - } - if arg1.Kind == ArgUnion { - size := rec(arg1.Option, offset) - offset += size - if size > arg1.Size() { - panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type)) - } - return arg1.Size() - } if !sys.IsPad(arg1.Type) && !(arg1.Kind == 
ArgData && len(arg1.Data) == 0) && arg1.Type.Dir() != sys.DirOut { w.write(ExecInstrCopyin) w.write(physicalAddr(arg) + offset) - w.writeArg(arg1, pid) + w.writeArg(arg1, pid, csumMap) instrSeq++ } - return arg1.Size() - } - rec(arg.Res, 0) + }) } }) // Generate the call itself. w.write(uintptr(c.Meta.ID)) w.write(uintptr(len(c.Args))) for _, arg := range c.Args { - w.writeArg(arg, pid) + w.writeArg(arg, pid, csumMap) } if len(c.Ret.Uses) != 0 { w.args[c.Ret] = argInfo{Idx: instrSeq} @@ -173,9 +150,14 @@ func (w *execContext) write(v uintptr) { w.buf = w.buf[8:] } -func (w *execContext) writeArg(arg *Arg, pid int) { +func (w *execContext) writeArg(arg *Arg, pid int, csumMap map[*Arg]*Arg) { switch arg.Kind { case ArgConst: + if _, ok := arg.Type.(*sys.CsumType); ok { + if arg, ok = csumMap[arg]; !ok { + panic("csum arg is not in csum map") + } + } w.write(ExecArgConst) w.write(arg.Size()) w.write(arg.Value(pid)) diff --git a/prog/mutation.go b/prog/mutation.go index 04465a4..358a2b1 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -197,6 +197,8 @@ func (p *Prog) Mutate(rs rand.Source, ncalls int, ct *ChoiceTable, corpus []*Pro p.replaceArg(c, arg, arg1, calls) case *sys.LenType: panic("bad arg returned by mutationArgs: LenType") + case *sys.CsumType: + panic("bad arg returned by mutationArgs: CsumType") case *sys.ConstType: panic("bad arg returned by mutationArgs: ConstType") default: @@ -397,7 +399,7 @@ func Minimize(p0 *Prog, callIndex0 int, pred func(*Prog, int) bool, crash bool) } } p0 = p - case *sys.VmaType, *sys.LenType, *sys.ConstType: + case *sys.VmaType, *sys.LenType, *sys.CsumType, *sys.ConstType: // TODO: try to remove offset from vma return false default: @@ -460,6 +462,9 @@ func mutationArgs(c *Call) (args, bases []*Arg) { case *sys.LenType: // Size is updated when the size-of arg change. return + case *sys.CsumType: + // Checksum is updated when the checksummed data changes. + return case *sys.ConstType: // Well, this is const. 
return diff --git a/prog/prog.go b/prog/prog.go index 13265e4..fbd8507 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -95,6 +95,8 @@ func (a *Arg) Value(pid int) uintptr { return encodeValue(a.Val, typ.Size(), typ.BigEndian) case *sys.LenType: return encodeValue(a.Val, typ.Size(), typ.BigEndian) + case *sys.CsumType: + return encodeValue(a.Val, typ.Size(), typ.BigEndian) case *sys.ProcType: val := uintptr(typ.ValuesStart) + uintptr(typ.ValuesPerProc)*uintptr(pid) + a.Val return encodeValue(val, typ.Size(), typ.BigEndian) @@ -105,7 +107,7 @@ func (a *Arg) Value(pid int) uintptr { func (a *Arg) Size() uintptr { switch typ := a.Type.(type) { case *sys.IntType, *sys.LenType, *sys.FlagsType, *sys.ConstType, - *sys.ResourceType, *sys.VmaType, *sys.PtrType, *sys.ProcType: + *sys.ResourceType, *sys.VmaType, *sys.PtrType, *sys.ProcType, *sys.CsumType: return typ.Size() case *sys.BufferType: return uintptr(len(a.Data)) diff --git a/prog/rand.go b/prog/rand.go index 4a7bbe0..3eebe4b 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -765,8 +765,8 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call) arg, calls1 := r.addr(s, a, inner.Size(), inner) calls = append(calls, calls1...) return arg, calls - case *sys.LenType: - // Return placeholder value of 0 while generating len args. + case *sys.LenType, *sys.CsumType: + // Return placeholder value of 0 while generating len and csum args. 
return constArg(a, 0), nil default: panic("unknown argument type") diff --git a/prog/validation.go b/prog/validation.go index 564c0b0..28c6198 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -106,9 +106,13 @@ func (c *Call) validate(ctx *validCtx) error { switch typ1.Kind { case sys.BufferString: if typ1.Length != 0 && len(arg.Data) != int(typ1.Length) { - return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, len(arg.Data), typ1.Length) + return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, typ.Name(), len(arg.Data), typ1.Length) } } + case *sys.CsumType: + if arg.Val != 0 { + return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, typ.Name(), arg.Val) + } } switch arg.Kind { case ArgConst: diff --git a/sys/decl.go b/sys/decl.go index 3c98d10..14f1105 100644 --- a/sys/decl.go +++ b/sys/decl.go @@ -191,6 +191,17 @@ type ProcType struct { ValuesPerProc uint64 } +type CsumKind int + +const ( + CsumIPv4 CsumKind = iota +) + +type CsumType struct { + IntTypeCommon + Kind CsumKind +} + type VmaType struct { TypeCommon RangeBegin int64 // in pages @@ -573,7 +584,7 @@ func ForeachType(meta *Call, f func(Type)) { rec(opt) } case *ResourceType, *BufferType, *VmaType, *LenType, - *FlagsType, *ConstType, *IntType, *ProcType: + *FlagsType, *ConstType, *IntType, *ProcType, *CsumType: default: panic("unknown type") } diff --git a/sys/test.txt b/sys/test.txt index 0ce7ec8..75b8428 100644 --- a/sys/test.txt +++ b/sys/test.txt @@ -389,3 +389,22 @@ syz_bf_struct1 { syz_test$bf0(a0 ptr[in, syz_bf_struct0]) syz_test$bf1(a0 ptr[in, syz_bf_struct1]) + +# Checksums + +syz_test$csum_encode(a0 ptr[in, syz_csum_encode]) +syz_test$csum_ipv4(a0 ptr[in, syz_csum_ipv4]) + +syz_csum_encode { + f0 int16 + f1 int16be + f2 array[int32, 0:4] + f3 int8:4 + f4 int8:4 + f5 array[int8, 4] +} [packed] + +syz_csum_ipv4 { + f0 csum[ipv4, int16] + f1 syz_csum_encode +} [packed] diff --git 
a/sys/vnet.txt b/sys/vnet.txt index 67344b1..a959f89 100644 --- a/sys/vnet.txt +++ b/sys/vnet.txt @@ -149,14 +149,13 @@ ipv4_header { version const[4, int8:4] ecn int8:2 dscp int8:6 - tot_len len[ipv4_packet, int16be] - identification int16be + total_len len[ipv4_packet, int16be] + id int16be frag_off int16:13 flags int16:3 ttl int8 protocol flags[ipv4_types, int8] -# TODO: embed correct checksum - csum const[0, int16] + csum csum[ipv4, int16be] src_ip ipv4_addr dst_ip ipv4_addr options ipv4_options diff --git a/sysgen/sysgen.go b/sysgen/sysgen.go index 6906c76..bfead20 100644 --- a/sysgen/sysgen.go +++ b/sysgen/sysgen.go @@ -500,6 +500,19 @@ func generateArg( byteSize = decodeByteSizeType(typ) } fmt.Fprintf(out, "&LenType{%v, Buf: \"%v\", ByteSize: %v}", intCommon(size, bigEndian, bitfieldLen), a[0], byteSize) + case "csum": + if want := 2; len(a) != want { + failf("wrong number of arguments for %v arg %v, want %v, got %v", typ, name, want, len(a)) + } + size, bigEndian, bitfieldLen := decodeIntType(a[1]) + var kind string + switch a[0] { + case "ipv4": + kind = "CsumIPv4" + default: + failf("unknown checksum kind '%v'", a[0]) + } + fmt.Fprintf(out, "&CsumType{%v, Kind: %v}", intCommon(size, bigEndian, bitfieldLen), kind) case "flags": canBeArg = true size := uint64(ptrSize)