diff --git a/prog/checksum.go b/prog/checksum.go index 3806c59e..9df541dc 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -4,6 +4,7 @@ package prog import ( + "encoding/binary" "fmt" "unsafe" @@ -90,7 +91,7 @@ func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) { } } -func encodeStruct(arg *Arg, pid int) []byte { +func encodeArg(arg *Arg, pid int) []byte { bytes := make([]byte, arg.Size()) foreachSubargOffset(arg, func(arg *Arg, offset uintptr) { switch arg.Kind { @@ -120,38 +121,96 @@ func encodeStruct(arg *Arg, pid int) []byte { return bytes } -func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) { - var csumField *Arg +func getFieldByName(arg *Arg, name string) *Arg { for _, field := range arg.Inner { - if _, ok := field.Type.(*sys.CsumType); ok { - csumField = field - break + if field.Type.FieldName() == name { + return field } } - if csumField == nil { - panic(fmt.Sprintf("failed to find csum field in %v", arg.Type.Name())) + panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name())) +} + +func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) { + csumField := getFieldByName(arg, "csum") + if typ, ok := csumField.Type.(*sys.CsumType); !ok { + panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) + } else if typ.Kind != sys.CsumIPv4 { + panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) } if csumField.Value(pid) != 0 { panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField)) } - bytes := encodeStruct(arg, pid) + bytes := encodeArg(arg, pid) csum := ipChecksum(bytes) newCsumField := *csumField newCsumField.Val = uintptr(csum) return csumField, &newCsumField } +func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg, *Arg) { + srcAddr := getFieldByName(arg, "src_ip") + if srcAddr.Size() != 4 { + panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name())) + } + dstAddr := getFieldByName(arg, "dst_ip") + if dstAddr.Size() != 4 { + panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name())) + } + protocol := getFieldByName(arg, "protocol") + if protocol.Size() != 1 { + panic(fmt.Sprintf("protocol field in %v must be 1 byte", arg.Type.Name())) + } + return srcAddr, dstAddr, protocol +} + +func calcChecksumTCP(tcpPacket, srcAddr, dstAddr, protocol *Arg, pid int) (*Arg, *Arg) { + tcpHeaderField := getFieldByName(tcpPacket, "header") + csumField := getFieldByName(tcpHeaderField, "csum") + if typ, ok := csumField.Type.(*sys.CsumType); !ok { + panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) + } else if typ.Kind != sys.CsumTCP { + panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) + } + + var csum IPChecksum + csum.Update(encodeArg(srcAddr, pid)) + csum.Update(encodeArg(dstAddr, pid)) + csum.Update([]byte{0, byte(protocol.Value(pid))}) + length := []byte{0, 0} + binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size())) + csum.Update(length) + csum.Update(encodeArg(tcpPacket, pid)) + + newCsumField := *csumField + newCsumField.Val = uintptr(csum.Digest()) + return csumField, &newCsumField +} + func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { - var m map[*Arg]*Arg + var csumMap map[*Arg]*Arg + ipv4HeaderParsed := false + var ipv4SrcAddr *Arg + var ipv4DstAddr *Arg + var ipv4Protocol *Arg foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) { - // syz_csum_ipv4 struct is used in tests - if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4" { - if m == nil { - m = make(map[*Arg]*Arg) + // syz_csum_ipv4_header struct is used in tests + if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4_header" { + if csumMap == nil { + csumMap = make(map[*Arg]*Arg) } - k, v := calcChecksumIPv4(arg, pid) - m[k] = v + csumField, newCsumField := calcChecksumIPv4(arg, pid) + csumMap[csumField] = newCsumField + ipv4SrcAddr, ipv4DstAddr, ipv4Protocol = extractHeaderParamsIPv4(arg) + ipv4HeaderParsed = true + } + // syz_csum_tcp_packet struct is used in tests + if arg.Type.Name() == "tcp_packet" || arg.Type.Name() == "syz_csum_tcp_packet" { + if !ipv4HeaderParsed { + panic("tcp_packet is being parsed before ipv4_header") + } + csumField, newCsumField := calcChecksumTCP(arg, ipv4SrcAddr, ipv4DstAddr, ipv4Protocol, pid) + csumMap[csumField] = newCsumField } }) - return m + return csumMap } diff --git a/prog/checksum_test.go b/prog/checksum_test.go index bade7f72..56561d4e 100644 --- a/prog/checksum_test.go +++ b/prog/checksum_test.go @@ -6,6 +6,8 @@ package prog import ( "bytes" "testing" + + "github.com/google/syzkaller/sys" ) func TestChecksumIP(t *testing.T) { @@ -53,6 +55,10 @@ func TestChecksumIP(t *testing.T) { "\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", 0xe143, }, + { + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\xab\xcd", + 0x542e, + }, } for _, test := range tests { @@ -102,7 +108,7 @@ func TestChecksumEncode(t *testing.T) { if err != nil { t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) } - encoded := encodeStruct(p.Calls[0].Args[0].Res, 0) + encoded := encodeArg(p.Calls[0].Args[0].Res, 0) if !bytes.Equal(encoded, []byte(test.encoded)) { t.Fatalf("incorrect encoding for prog #%v, got: %+v, want: %+v", i, encoded, []byte(test.encoded)) } @@ -115,7 +121,7 @@ func TestChecksumIPv4Calc(t *testing.T) { csum uint16 }{ { - "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}})", + "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}, 0x0, 0x0, 0x0})", 0xe143, }, } @@ -133,6 +139,37 @@ func TestChecksumIPv4Calc(t *testing.T) { } } +func TestChecksumTCPCalc(t *testing.T) { + tests := []struct { + prog string + csum uint16 + }{ + { + "syz_test$csum_ipv4_tcp(&(0x7f0000000000)={{0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}, 0x0, 0x0, 0x0}, {{0x0}, \"abcd\"}})", + 0x542e, + }, + } + for i, test := range tests { + p, err := Deserialize([]byte(test.prog)) + if err != nil { + t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) + } + csumMap := calcChecksumsCall(p.Calls[0], i % 32) + for oldField, newField := range csumMap { + if typ, ok := newField.Type.(*sys.CsumType); ok { + if typ.Kind == sys.CsumTCP { + csum := newField.Value(i % 32) + if csum != uintptr(test.csum) { + t.Fatalf("failed to calc tcp checksum, got %x, want %x, prog: '%v'", csum, test.csum, test.prog) + } + } + } else { + t.Fatalf("non csum key %+v in csum map %+v", oldField, csumMap) + } + } + } +} + func TestChecksumCalcRandom(t *testing.T) { rs, iters := initTest(t) for i := 0; i < iters; i++ { diff --git a/sys/decl.go b/sys/decl.go index 57d78857..ba180683 100644 --- a/sys/decl.go +++ b/sys/decl.go @@ -194,6 +194,7 @@ type CsumKind int const ( CsumIPv4 CsumKind = iota + CsumTCP ) type CsumType struct { diff --git a/sys/socket.txt b/sys/socket.txt index c123a5db..9e86c4de 100644 --- a/sys/socket.txt +++ b/sys/socket.txt @@ -138,6 +138,7 @@ sockaddr_in { family const[AF_INET, int16] port proc[int16be, 20000, 4] addr in_addr + pad array[const[0, int8], 8] } sockaddr_storage_in { @@ -193,6 +194,28 @@ cmsghdr { +# AF_INET: TCP support + +resource sock_tcp[sock] + +socket$tcp(domain const[AF_INET], type const[SOCK_STREAM], proto const[0]) sock_tcp +socketpair$tcp(domain const[AF_INET], type const[SOCK_STREAM], proto const[0], fds ptr[out, tcp_pair]) +accept$tcp(fd sock_tcp, peer ptr[out, sockaddr_in, opt], peerlen ptr[inout, len[peer, int32]]) sock_tcp +accept4$tcp(fd sock_tcp, peer ptr[out, sockaddr_in, opt], peerlen ptr[inout, len[peer, int32]], flags flags[accept_flags]) sock_tcp +bind$tcp(fd sock_tcp, addr ptr[in, sockaddr_in], addrlen len[addr]) +connect$tcp(fd sock_tcp, addr ptr[in, sockaddr_in], addrlen len[addr]) +sendto$tcp(fd sock_tcp, buf buffer[in], len len[buf], f flags[send_flags], addr ptr[in, sockaddr_in, opt], addrlen len[addr]) +recvfrom$tcp(fd sock_tcp, buf buffer[out], len len[buf], f flags[recv_flags], addr ptr[in, sockaddr_in, opt], addrlen len[addr]) +getsockname$tcp(fd sock_tcp, addr ptr[out, sockaddr_in], addrlen ptr[inout, len[addr, int32]]) +getpeername$tcp(fd sock_tcp, peer ptr[out, sockaddr_in], peerlen ptr[inout, len[peer, int32]]) + +tcp_pair { + f0 sock_tcp + f1 sock_tcp +} + + + # AF_UNIX support. resource sock_unix[sock] diff --git a/sys/sys.txt b/sys/sys.txt index be43f078..b4e6bdf5 100644 --- a/sys/sys.txt +++ b/sys/sys.txt @@ -941,7 +941,7 @@ in_addr [ loopback const[0x7f000001, int32be] # 255.255.255.255 broadcast const[0xffffffff, int32be] -} +] in6_addr_empty { a0 const[0, int64be] diff --git a/sys/test.txt b/sys/test.txt index 75b8428f..932ddac0 100644 --- a/sys/test.txt +++ b/sys/test.txt @@ -393,7 +393,8 @@ syz_test$bf1(a0 ptr[in, syz_bf_struct1]) # Checksums syz_test$csum_encode(a0 ptr[in, syz_csum_encode]) -syz_test$csum_ipv4(a0 ptr[in, syz_csum_ipv4]) +syz_test$csum_ipv4(a0 ptr[in, syz_csum_ipv4_header]) +syz_test$csum_ipv4_tcp(a0 ptr[in, syz_csum_ipv4_tcp_packet]) syz_csum_encode { f0 int16 @@ -404,7 +405,24 @@ syz_csum_encode { f5 array[int8, 4] } [packed] -syz_csum_ipv4 { - f0 csum[ipv4, int16] - f1 syz_csum_encode +syz_csum_ipv4_header { + csum csum[ipv4, int16] + data syz_csum_encode + protocol int8 + src_ip int32be + dst_ip int32be +} [packed] + +syz_csum_tcp_header { + csum csum[tcp, int16] +} [packed] + +syz_csum_tcp_packet { + header syz_csum_tcp_header + payload array[int8] +} [packed] + +syz_csum_ipv4_tcp_packet { + header syz_csum_ipv4_header + payload syz_csum_tcp_packet } [packed] diff --git a/sys/vnet.txt b/sys/vnet.txt index 296d0479..5e9becc7 100644 --- a/sys/vnet.txt +++ b/sys/vnet.txt @@ -257,6 +257,139 @@ ipv4_packet { payload ip_payload } [packed] +################################################################################ +###################################### IP ###################################### +################################################################################ + ip_payload { - dummy array[int8, 0:128] + tcp tcp_packet +} [packed] + +################################################################################ +###################################### TCP ##################################### +################################################################################ + +# https://en.wikipedia.org/wiki/Transmission_Control_Protocol#TCP_segment_structure +# http://www.iana.org/assignments/tcp-parameters/tcp-parameters.xhtml + +include +include + +tcp_option [ + generic tcp_generic_option + nop tcp_nop_option + eol tcp_eol_option + mss tcp_mss_option + window tcp_window_option + sack_perm tcp_sack_perm_option + sack tcp_sack_option + timestamp tcp_timestamp_option + md5sig tcp_md5sig_option + fastopen tcp_fastopen_option +# TODO: TCPOPT_EXP option +] [varlen] + +tcp_option_types = TCPOPT_NOP, TCPOPT_EOL, TCPOPT_MSS, TCPOPT_WINDOW, TCPOPT_SACK_PERM, TCPOPT_SACK, TCPOPT_TIMESTAMP, TCPOPT_MD5SIG, TCPOPT_FASTOPEN, TCPOPT_EXP + +tcp_generic_option { + type flags[tcp_option_types, int8] + length len[parent, int8] + data array[int8, 0:16] +} [packed] + +# https://tools.ietf.org/html/rfc793#section-3.1 +tcp_nop_option { + type const[TCPOPT_NOP, int8] +} [packed] + +# https://tools.ietf.org/html/rfc793#section-3.1 +tcp_eol_option { + type const[TCPOPT_EOL, int8] +} [packed] + +# https://tools.ietf.org/html/rfc793#section-3.1 +tcp_mss_option { + type const[TCPOPT_MSS, int8] + length len[parent, int8] + seg_size int16 +} [packed] + +# https://tools.ietf.org/html/rfc7323#section-2 +tcp_window_option { + type const[TCPOPT_WINDOW, int8] + length len[parent, int8] + shift int8 +} [packed] + +# https://tools.ietf.org/html/rfc2018#section-2 +tcp_sack_perm_option { + type const[TCPOPT_SACK_PERM, int8] + length len[parent, int8] +} [packed] + +# https://tools.ietf.org/html/rfc2018#section-3 +tcp_sack_option { + type const[TCPOPT_SACK, int8] + length len[parent, int8] + data array[int32be] +} [packed] + +# https://tools.ietf.org/html/rfc7323#section-3 +tcp_timestamp_option { + type const[TCPOPT_TIMESTAMP, int8] + length len[parent, int8] + tsval int32be + tsecr int32be +} [packed] + +# https://tools.ietf.org/html/rfc2385#section-3.0 +tcp_md5sig_option { + type const[TCPOPT_MD5SIG, int8] + length len[parent, int8] + md5 array[int8, 16] +} [packed] + +# https://tools.ietf.org/html/rfc7413#section-4.1.1 +tcp_fastopen_option { + type const[TCPOPT_FASTOPEN, int8] + length len[parent, int8] + data array[int8, 0:16] +} [packed] + +tcp_options { + options array[tcp_option] +} [packed, align_4] + +# TODO: extract sequence numbers from packets +tcp_seq_num [ + init const[0x56565656, int32be] + next const[0x56565657, int32be] + nextn int32be[0x56565656:0x56566000] + random int32be +] + +tcp_flags = 0, TCPHDR_FIN, TCPHDR_SYN, TCPHDR_RST, TCPHDR_PSH, TCPHDR_ACK, TCPHDR_URG, TCPHDR_ECE, TCPHDR_CWR, TCPHDR_SYN_ECN + +tcp_header { + src_port proc[int16be, 20000, 4] + dst_port proc[int16be, 20000, 4] + seq_num tcp_seq_num + ack_num tcp_seq_num + ns int8:1 + reserved const[0, int8:3] + data_off bytesize4[parent, int8:4] + flags flags[tcp_flags, int8] + window_size int16be + csum csum[tcp, int16be] + urg_ptr int16be + options tcp_options +} [packed] + +tcp_packet { + header tcp_header + payload tcp_payload +} [packed] + +tcp_payload { + payload array[int8] } [packed] diff --git a/sys/vnet_amd64.const b/sys/vnet_amd64.const index cce212a4..9d9a8855 100644 --- a/sys/vnet_amd64.const +++ b/sys/vnet_amd64.const @@ -101,3 +101,22 @@ IPPROTO_TCP = 6 IPPROTO_TP = 29 IPPROTO_UDP = 17 IPPROTO_UDPLITE = 136 +TCPHDR_ACK = 16 +TCPHDR_CWR = 128 +TCPHDR_ECE = 64 +TCPHDR_FIN = 1 +TCPHDR_PSH = 8 +TCPHDR_RST = 4 +TCPHDR_SYN = 2 +TCPHDR_SYN_ECN = 194 +TCPHDR_URG = 32 +TCPOPT_EOL = 0 +TCPOPT_EXP = 254 +TCPOPT_FASTOPEN = 34 +TCPOPT_MD5SIG = 19 +TCPOPT_MSS = 2 +TCPOPT_NOP = 1 +TCPOPT_SACK = 5 +TCPOPT_SACK_PERM = 4 +TCPOPT_TIMESTAMP = 8 +TCPOPT_WINDOW = 3 diff --git a/sys/vnet_arm64.const b/sys/vnet_arm64.const index cce212a4..9d9a8855 100644 --- a/sys/vnet_arm64.const +++ b/sys/vnet_arm64.const @@ -101,3 +101,22 @@ IPPROTO_TCP = 6 IPPROTO_TP = 29 IPPROTO_UDP = 17 IPPROTO_UDPLITE = 136 +TCPHDR_ACK = 16 +TCPHDR_CWR = 128 +TCPHDR_ECE = 64 +TCPHDR_FIN = 1 +TCPHDR_PSH = 8 +TCPHDR_RST = 4 +TCPHDR_SYN = 2 +TCPHDR_SYN_ECN = 194 +TCPHDR_URG = 32 +TCPOPT_EOL = 0 +TCPOPT_EXP = 254 +TCPOPT_FASTOPEN = 34 +TCPOPT_MD5SIG = 19 +TCPOPT_MSS = 2 +TCPOPT_NOP = 1 +TCPOPT_SACK = 5 +TCPOPT_SACK_PERM = 4 +TCPOPT_TIMESTAMP = 8 +TCPOPT_WINDOW = 3 diff --git a/sys/vnet_ppc64le.const b/sys/vnet_ppc64le.const index cce212a4..9d9a8855 100644 --- a/sys/vnet_ppc64le.const +++ b/sys/vnet_ppc64le.const @@ -101,3 +101,22 @@ IPPROTO_TCP = 6 IPPROTO_TP = 29 IPPROTO_UDP = 17 IPPROTO_UDPLITE = 136 +TCPHDR_ACK = 16 +TCPHDR_CWR = 128 +TCPHDR_ECE = 64 +TCPHDR_FIN = 1 +TCPHDR_PSH = 8 +TCPHDR_RST = 4 +TCPHDR_SYN = 2 +TCPHDR_SYN_ECN = 194 +TCPHDR_URG = 32 +TCPOPT_EOL = 0 +TCPOPT_EXP = 254 +TCPOPT_FASTOPEN = 34 +TCPOPT_MD5SIG = 19 +TCPOPT_MSS = 2 +TCPOPT_NOP = 1 +TCPOPT_SACK = 5 +TCPOPT_SACK_PERM = 4 +TCPOPT_TIMESTAMP = 8 +TCPOPT_WINDOW = 3 diff --git a/sysgen/sysgen.go b/sysgen/sysgen.go index 832f5d4b..3da19247 100644 --- a/sysgen/sysgen.go +++ b/sysgen/sysgen.go @@ -509,6 +509,8 @@ func generateArg( switch a[0] { case "ipv4": kind = "CsumIPv4" + case "tcp": + kind = "CsumTCP" default: failf("unknown checksum kind '%v'", a[0]) }