prog: dedup checksumming code

Update #538
This commit is contained in:
Dmitry Vyukov 2018-05-07 14:51:28 +02:00
parent 23b5913da9
commit 8041642739

View File

@ -77,7 +77,6 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo {
}
// Extract ipv4 or ipv6 source and destination addresses.
ipv4HeaderParsed, ipv6HeaderParsed := false, false
var ipSrcAddr, ipDstAddr Arg
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
groupArg, ok := arg.(*GroupArg)
@ -87,14 +86,12 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo {
// syz_csum_* structs are used in tests
switch groupArg.Type().Name() {
case "ipv4_header", "syz_csum_ipv4_header":
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(groupArg)
ipv4HeaderParsed = true
ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 4)
case "ipv6_packet", "syz_csum_ipv6_header":
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(groupArg)
ipv6HeaderParsed = true
ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 16)
}
})
if !ipv4HeaderParsed && !ipv6HeaderParsed {
if ipSrcAddr == nil || ipDstAddr == nil {
panic("no ipv4 nor ipv6 header found")
}
@ -104,7 +101,7 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo {
csummedArg := findCsummedArg(arg, typ, parentsMap)
protocol := uint8(typ.Protocol)
var info CsumInfo
if ipv4HeaderParsed {
if ipSrcAddr.Size() == 4 {
info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol)
} else {
info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol)
@ -151,26 +148,11 @@ func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) Csum
return info
}
func extractHeaderParamsIPv4(arg *GroupArg) (Arg, Arg) {
func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) {
srcAddr := getFieldByName(arg, "src_ip")
if srcAddr.Size() != 4 {
panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name()))
}
dstAddr := getFieldByName(arg, "dst_ip")
if dstAddr.Size() != 4 {
panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type().Name()))
}
return srcAddr, dstAddr
}
func extractHeaderParamsIPv6(arg *GroupArg) (Arg, Arg) {
srcAddr := getFieldByName(arg, "src_ip")
if srcAddr.Size() != 16 {
panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name()))
}
dstAddr := getFieldByName(arg, "dst_ip")
if dstAddr.Size() != 16 {
panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type().Name()))
if srcAddr.Size() != size || dstAddr.Size() != size {
panic(fmt.Sprintf("src/dst_ip fields in %v must be %v bytes", arg.Type().Name(), size))
}
return srcAddr, dstAddr
}