syzkaller/prog/checksum.go

283 lines
8.0 KiB
Go
Raw Normal View History

// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"encoding/binary"
"fmt"
"unsafe"
"github.com/google/syzkaller/sys"
)
// IPChecksum incrementally computes the Internet checksum (RFC 1071): the
// ones' complement of the ones' complement sum of the data viewed as 16-bit
// big-endian words. The zero value is ready to use.
type IPChecksum struct {
	// acc is the running ones' complement sum. Update keeps it folded to at
	// most 0xffff, so Digest only has to complement it.
	acc uint32
}

// Update adds data to the running sum as a sequence of 16-bit big-endian
// words. An odd trailing byte is treated as the high byte of a word whose low
// byte is zero. Each call therefore starts on a word boundary. Only the last
// chunk passed to Update may have odd length if the result is to equal the
// checksum of the concatenated chunks.
func (csum *IPChecksum) Update(data []byte) {
	for i := 0; i+1 < len(data); i += 2 {
		csum.acc += uint32(data[i])<<8 | uint32(data[i+1])
		// Fold the carry back in after every word (end-around carry).
		// This keeps acc <= 0xffff, so the uint32 can never overflow,
		// regardless of how much data is processed.
		csum.acc = (csum.acc >> 16) + (csum.acc & 0xffff)
	}
	if len(data)%2 == 1 {
		csum.acc += uint32(data[len(data)-1]) << 8
		csum.acc = (csum.acc >> 16) + (csum.acc & 0xffff)
	}
}

// Digest returns the checksum of all data passed to Update so far: the
// ones' complement of the folded sum.
func (csum *IPChecksum) Digest() uint16 {
	return ^uint16(csum.acc)
}

// ipChecksum returns the Internet checksum of data in a single call.
func ipChecksum(data []byte) uint16 {
	var csum IPChecksum
	csum.Update(data)
	return csum.Digest()
}
// bitmaskLen returns a uint64 whose low bfLen bits are set. A shift of 64 or
// more yields 0, and the subtraction then wraps to an all-ones mask.
func bitmaskLen(bfLen uint64) uint64 {
	var one uint64 = 1
	return one<<bfLen - 1
}
// bitmaskLenOff returns a mask of bfLen consecutive set bits, the lowest of
// which is at bit position bfOff.
func bitmaskLenOff(bfOff, bfLen uint64) uint64 {
	mask := bitmaskLen(bfLen)
	return mask << bfOff
}
// storeByBitmask8 writes value into *addr.
//
// When bfOff and bfLen are both zero, the whole byte is overwritten.
// Otherwise only the bfLen bits starting at bit bfOff are replaced by the low
// bfLen bits of value, and all other bits of *addr are preserved.
func storeByBitmask8(addr *uint8, value uint8, bfOff uint64, bfLen uint64) {
	if bfOff == 0 && bfLen == 0 {
		*addr = value
		return
	}
	keep := ^uint8(bitmaskLenOff(bfOff, bfLen))
	bits := (value & uint8(bitmaskLen(bfLen))) << bfOff
	*addr = *addr&keep | bits
}
// storeByBitmask16 writes value into *addr.
//
// When bfOff and bfLen are both zero, the whole uint16 is overwritten.
// Otherwise only the bfLen bits starting at bit bfOff are replaced by the low
// bfLen bits of value, and all other bits of *addr are preserved.
func storeByBitmask16(addr *uint16, value uint16, bfOff uint64, bfLen uint64) {
	if bfOff == 0 && bfLen == 0 {
		*addr = value
		return
	}
	keep := ^uint16(bitmaskLenOff(bfOff, bfLen))
	bits := (value & uint16(bitmaskLen(bfLen))) << bfOff
	*addr = *addr&keep | bits
}
// storeByBitmask32 writes value into *addr.
//
// When bfOff and bfLen are both zero, the whole uint32 is overwritten.
// Otherwise only the bfLen bits starting at bit bfOff are replaced by the low
// bfLen bits of value, and all other bits of *addr are preserved.
func storeByBitmask32(addr *uint32, value uint32, bfOff uint64, bfLen uint64) {
	if bfOff == 0 && bfLen == 0 {
		*addr = value
		return
	}
	keep := ^uint32(bitmaskLenOff(bfOff, bfLen))
	bits := (value & uint32(bitmaskLen(bfLen))) << bfOff
	*addr = *addr&keep | bits
}
// storeByBitmask64 writes value into *addr.
//
// When bfOff and bfLen are both zero, the whole uint64 is overwritten.
// Otherwise only the bfLen bits starting at bit bfOff are replaced by the low
// bfLen bits of value, and all other bits of *addr are preserved.
func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) {
	if bfOff == 0 && bfLen == 0 {
		*addr = value
		return
	}
	keep := ^uint64(bitmaskLenOff(bfOff, bfLen))
	bits := (value & uint64(bitmaskLen(bfLen))) << bfOff
	*addr = *addr&keep | bits
}
// encodeArg flattens arg into a newly allocated byte slice of length
// arg.Size().
//
// Each ArgData subarg is copied verbatim at its offset. Each ArgConst subarg
// is stored at its offset through an unsafe pointer cast, so it is written in
// the machine's native byte order, honoring the bitfield offset and length
// reported by its type. Any other subarg kind, or a const whose size is not
// 1, 2, 4 or 8, causes a panic.
func encodeArg(arg *Arg, pid int) []byte {
	out := make([]byte, arg.Size())
	foreachSubargOffset(arg, func(sub *Arg, off uintptr) {
		if sub.Kind == ArgData {
			copy(out[off:], sub.Data)
			return
		}
		if sub.Kind != ArgConst {
			panic(fmt.Sprintf("bad arg kind %v, arg: %+v, type: %+v", sub.Kind, sub, sub.Type))
		}
		ptr := unsafe.Pointer(&out[off])
		v := sub.Value(pid)
		bfOff := uint64(sub.Type.BitfieldOffset())
		bfLen := uint64(sub.Type.BitfieldLength())
		switch size := sub.Size(); size {
		case 1:
			storeByBitmask8((*uint8)(ptr), uint8(v), bfOff, bfLen)
		case 2:
			storeByBitmask16((*uint16)(ptr), uint16(v), bfOff, bfLen)
		case 4:
			storeByBitmask32((*uint32)(ptr), uint32(v), bfOff, bfLen)
		case 8:
			storeByBitmask64((*uint64)(ptr), uint64(v), bfOff, bfLen)
		default:
			panic(fmt.Sprintf("bad arg size %v, arg: %+v\n", size, sub))
		}
	})
	return out
}
// getFieldByName returns the first inner field of arg whose field name equals
// name. It panics if arg has no such field.
func getFieldByName(arg *Arg, name string) *Arg {
	for i := range arg.Inner {
		if f := arg.Inner[i]; f.Type.FieldName() == name {
			return f
		}
	}
	panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name()))
}
// calcChecksumInet computes the Internet checksum of the encoded packet and
// returns a shallow copy of csumField whose Val holds that checksum. The
// csumField passed in is not modified.
func calcChecksumInet(packet, csumField *Arg, pid int) *Arg {
	sum := ipChecksum(encodeArg(packet, pid))
	updated := *csumField
	updated.Val = uintptr(sum)
	return &updated
}
// extractHeaderParamsIPv4 returns the "src_ip" and "dst_ip" fields of an IPv4
// header struct, in that order. It panics if either field is missing or its
// size is not 4 bytes.
func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg) {
	const addrLen = 4
	src := getFieldByName(arg, "src_ip")
	if src.Size() != addrLen {
		panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name()))
	}
	dst := getFieldByName(arg, "dst_ip")
	if dst.Size() != addrLen {
		panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name()))
	}
	return src, dst
}
// extractHeaderParamsIPv6 returns the "src_ip" and "dst_ip" fields of an IPv6
// header struct, in that order. It panics if either field is missing or its
// size is not 16 bytes.
func extractHeaderParamsIPv6(arg *Arg) (*Arg, *Arg) {
	srcAddr := getFieldByName(arg, "src_ip")
	if srcAddr.Size() != 16 {
		// The message previously claimed 4 bytes, copied from the IPv4 variant.
		panic(fmt.Sprintf("src_ip field in %v must be 16 bytes", arg.Type.Name()))
	}
	dstAddr := getFieldByName(arg, "dst_ip")
	if dstAddr.Size() != 16 {
		panic(fmt.Sprintf("dst_ip field in %v must be 16 bytes", arg.Type.Name()))
	}
	return srcAddr, dstAddr
}
// composePseudoHeaderIPv4 builds the IPv4 pseudo header that is prepended to
// a TCP/UDP packet for checksumming. It consists of:
//   - the encoded source address,
//   - the encoded destination address,
//   - a zero byte,
//   - the protocol number,
//   - the size of tcpPacket as a big-endian uint16.
func composePseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte {
	src := encodeArg(srcAddr, pid)
	dst := encodeArg(dstAddr, pid)
	var length [2]byte
	binary.BigEndian.PutUint16(length[:], uint16(tcpPacket.Size()))

	header := make([]byte, 0, len(src)+len(dst)+2+len(length))
	header = append(header, src...)
	header = append(header, dst...)
	header = append(header, 0, protocol)
	return append(header, length[:]...)
}
// composePseudoHeaderIPv6 builds the IPv6 pseudo header that is prepended to
// a TCP/UDP packet for checksumming. It consists of:
//   - the encoded source address,
//   - the encoded destination address,
//   - the size of tcpPacket as a big-endian uint32,
//   - three zero bytes,
//   - the protocol number.
func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte {
	src := encodeArg(srcAddr, pid)
	dst := encodeArg(dstAddr, pid)
	var length [4]byte
	binary.BigEndian.PutUint32(length[:], uint32(tcpPacket.Size()))

	header := make([]byte, 0, len(src)+len(dst)+len(length)+4)
	header = append(header, src...)
	header = append(header, dst...)
	header = append(header, length[:]...)
	return append(header, 0, 0, 0, protocol)
}
// findCsumFieldUDP returns the "csum" field of a UDP packet. It panics unless
// that field's type is *sys.CsumType with kind sys.CsumUDP.
func findCsumFieldUDP(udpPacket *Arg) *Arg {
	field := getFieldByName(udpPacket, "csum")
	typ, isCsum := field.Type.(*sys.CsumType)
	if !isCsum {
		panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", field.Type, field))
	}
	if typ.Kind != sys.CsumUDP {
		panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, field))
	}
	return field
}
// findCsumFieldTCP returns the "csum" field found inside the "header" field of
// a TCP packet. It panics unless that field's type is *sys.CsumType with kind
// sys.CsumTCP.
func findCsumFieldTCP(tcpPacket *Arg) *Arg {
	field := getFieldByName(getFieldByName(tcpPacket, "header"), "csum")
	typ, isCsum := field.Type.(*sys.CsumType)
	if !isCsum {
		panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", field.Type, field))
	}
	if typ.Kind != sys.CsumTCP {
		panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, field))
	}
	return field
}
// calcChecksumTCPUDP computes the Internet checksum over pseudoHeader followed
// by the encoded packet. It returns a shallow copy of csumField whose Val holds
// that checksum. The csumField passed in is not modified.
func calcChecksumTCPUDP(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg {
	var sum IPChecksum
	for _, chunk := range [][]byte{pseudoHeader, encodeArg(packet, pid)} {
		sum.Update(chunk)
	}
	result := *csumField
	result.Val = uintptr(sum.Digest())
	return &result
}
// calcChecksumsCall computes new values for the checksum fields in the
// arguments of call c. It returns a map from each original checksum field to
// a replacement Arg carrying the computed value, or nil if no checksum field
// was found.
//
// Inet checksums are computed first, each over the struct that directly
// contains the CsumInet field. TCP and UDP checksums are computed in a second
// traversal. Each TCP/UDP packet uses a pseudo header built from the addresses
// of the most recently visited IPv4/IPv6 header.
func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
	var csumMap map[*Arg]*Arg
	record := func(field, value *Arg) {
		if csumMap == nil {
			csumMap = make(map[*Arg]*Arg)
		}
		csumMap[field] = value
	}

	// Calculate inet checksums.
	foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
		if _, ok := arg.Type.(*sys.StructType); !ok {
			return
		}
		for _, field := range arg.Inner {
			if typ, ok := field.Type.(*sys.CsumType); ok && typ.Kind == sys.CsumInet {
				record(field, calcChecksumInet(arg, field, pid))
			}
		}
	})

	// Calculate tcp and udp checksums.
	// ipVersion (0 = no header seen yet, 4 or 6) and the address fields
	// describe the most recently visited IP header. Every new header
	// overwrites all three, so a call that mixes IPv4 and IPv6 headers does
	// not build an IPv4 pseudo header from IPv6 addresses.
	ipVersion := 0
	var ipSrcAddr, ipDstAddr *Arg
	foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
		// tcp is decided per packet. Keeping it outside this closure would
		// make every UDP packet after a TCP packet be treated as TCP.
		tcp := false
		// syz_csum_* structs are used in tests
		switch arg.Type.Name() {
		case "ipv4_header", "syz_csum_ipv4_header":
			ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(arg)
			ipVersion = 4
			return
		// NOTE(review): this matches "ipv6_packet" rather than an IPv6
		// header struct; presumably src_ip/dst_ip are direct fields of
		// ipv6_packet in the descriptions — confirm.
		case "ipv6_packet", "syz_csum_ipv6_header":
			ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg)
			ipVersion = 6
			return
		case "tcp_packet", "syz_csum_tcp_packet":
			tcp = true
		case "udp_packet", "syz_csum_udp_packet":
			// tcp stays false.
		default:
			return
		}
		if ipVersion == 0 {
			panic(fmt.Sprintf("%s is being parsed before ipv4 or ipv6 header", arg.Type.Name()))
		}
		var csumField *Arg
		var protocol uint8
		if tcp {
			csumField = findCsumFieldTCP(arg)
			protocol = 6 // IPPROTO_TCP
		} else {
			csumField = findCsumFieldUDP(arg)
			protocol = 17 // IPPROTO_UDP
		}
		var pseudoHeader []byte
		if ipVersion == 4 {
			pseudoHeader = composePseudoHeaderIPv4(arg, ipSrcAddr, ipDstAddr, protocol, pid)
		} else {
			pseudoHeader = composePseudoHeaderIPv6(arg, ipSrcAddr, ipDstAddr, protocol, pid)
		}
		record(csumField, calcChecksumTCPUDP(arg, csumField, pseudoHeader, pid))
	})
	return csumMap
}