2017-01-25 15:18:05 +00:00
|
|
|
// Copyright 2017 syzkaller project authors. All rights reserved.
|
|
|
|
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
|
|
|
|
|
|
|
|
package prog
|
|
|
|
|
|
|
|
import (
|
2017-01-30 18:17:38 +00:00
|
|
|
"encoding/binary"
|
2017-01-25 15:18:05 +00:00
|
|
|
"fmt"
|
|
|
|
"unsafe"
|
|
|
|
|
|
|
|
"github.com/google/syzkaller/sys"
|
|
|
|
)
|
|
|
|
|
|
|
|
// IPChecksum incrementally computes the Internet checksum (RFC 1071): the
// ones' complement of the ones' complement sum of big-endian 16-bit words.
// The zero value is ready to use.
type IPChecksum struct {
	// acc holds the running sum, always folded to at most 16 significant bits
	// after each Update.
	acc uint32
}

// Update adds data to the running checksum. data is read as a sequence of
// big-endian 16-bit words; a trailing odd byte is treated as the high byte
// of a word whose low byte is zero.
//
// The odd-byte padding is applied per call, so splitting a buffer across
// several Update calls gives the same result as one call only when every
// chunk except the last has even length.
func (csum *IPChecksum) Update(data []byte) {
	// Accumulate in 64 bits so that no carries are lost on large inputs; a
	// 32-bit accumulator overflows after ~64K words (~128KB of data).
	sum := uint64(csum.acc)
	even := len(data) &^ 1
	for i := 0; i < even; i += 2 {
		sum += uint64(data[i])<<8 | uint64(data[i+1])
	}
	if len(data)%2 == 1 {
		sum += uint64(data[len(data)-1]) << 8
	}
	// Fold carries back into the low 16 bits (end-around carry).
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	csum.acc = uint32(sum)
}

// Digest returns the checksum of all data passed to Update so far: the ones'
// complement of the folded sum. It does not reset the accumulator.
func (csum *IPChecksum) Digest() uint16 {
	return ^uint16(csum.acc)
}

// ipChecksum returns the Internet checksum of data in a single pass.
func ipChecksum(data []byte) uint16 {
	var csum IPChecksum
	csum.Update(data)
	return csum.Digest()
}
|
|
|
|
|
|
|
|
// bitmaskLen returns a mask with the low bfLen bits set. A bfLen of 64 or
// more yields an all-ones mask (Go shifts by >= the width produce zero).
func bitmaskLen(bfLen uint64) uint64 {
	allOnes := ^uint64(0)
	return ^(allOnes << bfLen)
}
|
|
|
|
|
|
|
|
func bitmaskLenOff(bfOff, bfLen uint64) uint64 {
|
|
|
|
return bitmaskLen(bfLen) << bfOff
|
|
|
|
}
|
|
|
|
|
|
|
|
func storeByBitmask8(addr *uint8, value uint8, bfOff uint64, bfLen uint64) {
|
|
|
|
if bfOff == 0 && bfLen == 0 {
|
|
|
|
*addr = value
|
|
|
|
} else {
|
|
|
|
newValue := *addr
|
|
|
|
newValue &= ^uint8(bitmaskLenOff(bfOff, bfLen))
|
|
|
|
newValue |= (value & uint8(bitmaskLen(bfLen))) << bfOff
|
|
|
|
*addr = newValue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func storeByBitmask16(addr *uint16, value uint16, bfOff uint64, bfLen uint64) {
|
|
|
|
if bfOff == 0 && bfLen == 0 {
|
|
|
|
*addr = value
|
|
|
|
} else {
|
|
|
|
newValue := *addr
|
|
|
|
newValue &= ^uint16(bitmaskLenOff(bfOff, bfLen))
|
|
|
|
newValue |= (value & uint16(bitmaskLen(bfLen))) << bfOff
|
|
|
|
*addr = newValue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func storeByBitmask32(addr *uint32, value uint32, bfOff uint64, bfLen uint64) {
|
|
|
|
if bfOff == 0 && bfLen == 0 {
|
|
|
|
*addr = value
|
|
|
|
} else {
|
|
|
|
newValue := *addr
|
|
|
|
newValue &= ^uint32(bitmaskLenOff(bfOff, bfLen))
|
|
|
|
newValue |= (value & uint32(bitmaskLen(bfLen))) << bfOff
|
|
|
|
*addr = newValue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) {
|
|
|
|
if bfOff == 0 && bfLen == 0 {
|
|
|
|
*addr = value
|
|
|
|
} else {
|
|
|
|
newValue := *addr
|
|
|
|
newValue &= ^uint64(bitmaskLenOff(bfOff, bfLen))
|
|
|
|
newValue |= (value & uint64(bitmaskLen(bfLen))) << bfOff
|
|
|
|
*addr = newValue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-01-30 18:17:38 +00:00
|
|
|
func encodeArg(arg *Arg, pid int) []byte {
|
2017-01-25 15:18:05 +00:00
|
|
|
bytes := make([]byte, arg.Size())
|
|
|
|
foreachSubargOffset(arg, func(arg *Arg, offset uintptr) {
|
|
|
|
switch arg.Kind {
|
|
|
|
case ArgConst:
|
|
|
|
addr := unsafe.Pointer(&bytes[offset])
|
|
|
|
val := arg.Value(pid)
|
|
|
|
bfOff := uint64(arg.Type.BitfieldOffset())
|
|
|
|
bfLen := uint64(arg.Type.BitfieldLength())
|
|
|
|
switch arg.Size() {
|
|
|
|
case 1:
|
|
|
|
storeByBitmask8((*uint8)(addr), uint8(val), bfOff, bfLen)
|
|
|
|
case 2:
|
|
|
|
storeByBitmask16((*uint16)(addr), uint16(val), bfOff, bfLen)
|
|
|
|
case 4:
|
|
|
|
storeByBitmask32((*uint32)(addr), uint32(val), bfOff, bfLen)
|
|
|
|
case 8:
|
|
|
|
storeByBitmask64((*uint64)(addr), uint64(val), bfOff, bfLen)
|
|
|
|
default:
|
|
|
|
panic(fmt.Sprintf("bad arg size %v, arg: %+v\n", arg.Size(), arg))
|
|
|
|
}
|
|
|
|
case ArgData:
|
|
|
|
copy(bytes[offset:], arg.Data)
|
|
|
|
default:
|
|
|
|
panic(fmt.Sprintf("bad arg kind %v, arg: %+v, type: %+v", arg.Kind, arg, arg.Type))
|
|
|
|
}
|
|
|
|
})
|
|
|
|
return bytes
|
|
|
|
}
|
|
|
|
|
2017-01-30 18:17:38 +00:00
|
|
|
func getFieldByName(arg *Arg, name string) *Arg {
|
2017-01-25 15:18:05 +00:00
|
|
|
for _, field := range arg.Inner {
|
2017-01-30 18:17:38 +00:00
|
|
|
if field.Type.FieldName() == name {
|
|
|
|
return field
|
2017-01-25 15:18:05 +00:00
|
|
|
}
|
|
|
|
}
|
2017-01-30 18:17:38 +00:00
|
|
|
panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name()))
|
|
|
|
}
|
|
|
|
|
2017-02-06 18:27:17 +00:00
|
|
|
func calcChecksumInet(packet, csumField *Arg, pid int) *Arg {
|
2017-02-02 18:19:32 +00:00
|
|
|
bytes := encodeArg(packet, pid)
|
2017-01-25 15:18:05 +00:00
|
|
|
csum := ipChecksum(bytes)
|
|
|
|
newCsumField := *csumField
|
|
|
|
newCsumField.Val = uintptr(csum)
|
2017-02-02 18:19:32 +00:00
|
|
|
return &newCsumField
|
2017-01-25 15:18:05 +00:00
|
|
|
}
|
|
|
|
|
2017-02-02 14:58:15 +00:00
|
|
|
func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg) {
|
2017-01-30 18:17:38 +00:00
|
|
|
srcAddr := getFieldByName(arg, "src_ip")
|
|
|
|
if srcAddr.Size() != 4 {
|
|
|
|
panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name()))
|
|
|
|
}
|
|
|
|
dstAddr := getFieldByName(arg, "dst_ip")
|
|
|
|
if dstAddr.Size() != 4 {
|
|
|
|
panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name()))
|
|
|
|
}
|
2017-02-02 14:58:15 +00:00
|
|
|
return srcAddr, dstAddr
|
|
|
|
}
|
|
|
|
|
|
|
|
func extractHeaderParamsIPv6(arg *Arg) (*Arg, *Arg) {
|
|
|
|
srcAddr := getFieldByName(arg, "src_ip")
|
|
|
|
if srcAddr.Size() != 16 {
|
|
|
|
panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name()))
|
|
|
|
}
|
|
|
|
dstAddr := getFieldByName(arg, "dst_ip")
|
|
|
|
if dstAddr.Size() != 16 {
|
|
|
|
panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name()))
|
2017-01-30 18:17:38 +00:00
|
|
|
}
|
2017-02-02 14:58:15 +00:00
|
|
|
return srcAddr, dstAddr
|
|
|
|
}
|
|
|
|
|
2017-02-02 18:19:32 +00:00
|
|
|
func composePseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte {
|
2017-02-02 14:58:15 +00:00
|
|
|
header := []byte{}
|
|
|
|
header = append(header, encodeArg(srcAddr, pid)...)
|
|
|
|
header = append(header, encodeArg(dstAddr, pid)...)
|
2017-02-02 18:19:32 +00:00
|
|
|
header = append(header, []byte{0, protocol}...)
|
2017-02-02 14:58:15 +00:00
|
|
|
length := []byte{0, 0}
|
|
|
|
binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size()))
|
|
|
|
header = append(header, length...)
|
|
|
|
return header
|
|
|
|
}
|
|
|
|
|
2017-02-02 18:19:32 +00:00
|
|
|
func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte {
|
2017-02-02 14:58:15 +00:00
|
|
|
header := []byte{}
|
|
|
|
header = append(header, encodeArg(srcAddr, pid)...)
|
|
|
|
header = append(header, encodeArg(dstAddr, pid)...)
|
|
|
|
length := []byte{0, 0, 0, 0}
|
|
|
|
binary.BigEndian.PutUint32(length, uint32(tcpPacket.Size()))
|
|
|
|
header = append(header, length...)
|
2017-02-02 18:19:32 +00:00
|
|
|
header = append(header, []byte{0, 0, 0, protocol}...)
|
2017-02-02 14:58:15 +00:00
|
|
|
return header
|
2017-01-30 18:17:38 +00:00
|
|
|
}
|
|
|
|
|
2017-02-08 14:49:03 +00:00
|
|
|
func calcChecksumPseudo(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg {
|
2017-01-30 18:17:38 +00:00
|
|
|
var csum IPChecksum
|
2017-02-02 14:58:15 +00:00
|
|
|
csum.Update(pseudoHeader)
|
2017-02-02 18:19:32 +00:00
|
|
|
csum.Update(encodeArg(packet, pid))
|
2017-01-30 18:17:38 +00:00
|
|
|
newCsumField := *csumField
|
|
|
|
newCsumField.Val = uintptr(csum.Digest())
|
2017-02-02 18:19:32 +00:00
|
|
|
return &newCsumField
|
2017-01-30 18:17:38 +00:00
|
|
|
}
|
|
|
|
|
2017-02-08 14:49:03 +00:00
|
|
|
func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg {
|
|
|
|
if typ.Buf == "parent" {
|
|
|
|
if csummedArg, ok := parentsMap[arg]; ok {
|
|
|
|
return csummedArg
|
|
|
|
}
|
|
|
|
panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name()))
|
|
|
|
} else {
|
|
|
|
for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] {
|
|
|
|
if typ.Buf == parent.Type.Name() {
|
|
|
|
return parent
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf))
|
|
|
|
}
|
|
|
|
|
2017-01-25 15:18:05 +00:00
|
|
|
func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
|
2017-02-08 14:49:03 +00:00
|
|
|
var inetCsumFields []*Arg
|
|
|
|
var pseudoCsumFields []*Arg
|
2017-02-06 18:27:17 +00:00
|
|
|
|
2017-02-08 14:49:03 +00:00
|
|
|
// Find all csum fields.
|
|
|
|
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
|
|
|
|
if typ, ok := arg.Type.(*sys.CsumType); ok {
|
|
|
|
switch typ.Kind {
|
|
|
|
case sys.CsumInet:
|
|
|
|
inetCsumFields = append(inetCsumFields, arg)
|
|
|
|
case sys.CsumPseudo:
|
|
|
|
pseudoCsumFields = append(pseudoCsumFields, arg)
|
|
|
|
default:
|
|
|
|
panic(fmt.Sprintf("unknown csum kind %v\n", typ.Kind))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
// Return if no csum fields found.
|
|
|
|
if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Build map of each field to its parent struct.
|
|
|
|
parentsMap := make(map[*Arg]*Arg)
|
2017-02-06 18:27:17 +00:00
|
|
|
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
|
|
|
|
if _, ok := arg.Type.(*sys.StructType); ok {
|
|
|
|
for _, field := range arg.Inner {
|
2017-02-08 14:49:03 +00:00
|
|
|
parentsMap[field.InnerArg()] = arg
|
2017-02-06 18:27:17 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
2017-02-08 14:49:03 +00:00
|
|
|
csumMap := make(map[*Arg]*Arg)
|
|
|
|
|
|
|
|
// Calculate inet checksums.
|
|
|
|
for _, arg := range inetCsumFields {
|
|
|
|
typ, _ := arg.Type.(*sys.CsumType)
|
|
|
|
csummedArg := findCsummedArg(arg, typ, parentsMap)
|
|
|
|
newCsumField := calcChecksumInet(csummedArg, arg, pid)
|
|
|
|
csumMap[arg] = newCsumField
|
|
|
|
}
|
|
|
|
|
|
|
|
// No need to continue if there are no pseudo csum fields.
|
|
|
|
if len(pseudoCsumFields) == 0 {
|
|
|
|
return csumMap
|
|
|
|
}
|
|
|
|
|
|
|
|
// Extract ipv4 or ipv6 source and destination addresses.
|
|
|
|
ipv4HeaderParsed := false
|
|
|
|
ipv6HeaderParsed := false
|
|
|
|
var ipSrcAddr *Arg
|
|
|
|
var ipDstAddr *Arg
|
2017-01-25 15:18:05 +00:00
|
|
|
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
|
2017-02-02 18:19:32 +00:00
|
|
|
// syz_csum_* structs are used in tests
|
|
|
|
switch arg.Type.Name() {
|
|
|
|
case "ipv4_header", "syz_csum_ipv4_header":
|
2017-02-02 14:58:15 +00:00
|
|
|
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(arg)
|
2017-01-30 18:17:38 +00:00
|
|
|
ipv4HeaderParsed = true
|
2017-02-02 18:19:32 +00:00
|
|
|
case "ipv6_packet", "syz_csum_ipv6_header":
|
2017-02-02 14:58:15 +00:00
|
|
|
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg)
|
|
|
|
ipv6HeaderParsed = true
|
2017-01-25 15:18:05 +00:00
|
|
|
}
|
|
|
|
})
|
2017-02-08 14:49:03 +00:00
|
|
|
if !ipv4HeaderParsed && !ipv6HeaderParsed {
|
|
|
|
panic("no ipv4 nor ipv6 header found")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Calculate pseudo checksums.
|
|
|
|
for _, arg := range pseudoCsumFields {
|
|
|
|
typ, _ := arg.Type.(*sys.CsumType)
|
|
|
|
csummedArg := findCsummedArg(arg, typ, parentsMap)
|
|
|
|
protocol := uint8(typ.Protocol)
|
|
|
|
var pseudoHeader []byte
|
|
|
|
if ipv4HeaderParsed {
|
|
|
|
pseudoHeader = composePseudoHeaderIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid)
|
|
|
|
} else {
|
|
|
|
pseudoHeader = composePseudoHeaderIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid)
|
|
|
|
}
|
|
|
|
newCsumField := calcChecksumPseudo(csummedArg, arg, pseudoHeader, pid)
|
|
|
|
csumMap[arg] = newCsumField
|
|
|
|
}
|
|
|
|
|
2017-01-30 18:17:38 +00:00
|
|
|
return csumMap
|
2017-01-25 15:18:05 +00:00
|
|
|
}
|