From 85d1218f4108e0fe793f63e57e2edadd5da5764f Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Sun, 18 Feb 2018 13:49:48 +0100 Subject: [PATCH] prog: rework foreachArg Make Foreach* callback accept the arg and a context struct that can contain lots of aux info. This (1) removes lots of unuser base/parent args, (2) provides foundation for stopping recursion, (3) allows to merge foreachSubargOffset. --- prog/analysis.go | 86 ++++++++++++++++++++------------------ prog/checksum.go | 30 ++++++------- prog/encodingexec.go | 7 +--- prog/hints.go | 7 ++-- prog/hints_test.go | 4 +- prog/mutation.go | 62 +++++++++++++-------------- prog/prog.go | 4 +- prog/rand.go | 2 +- prog/size.go | 26 +++++++----- prog/target.go | 6 +-- sys/linux/init_iptables.go | 2 +- 11 files changed, 120 insertions(+), 116 deletions(-) diff --git a/prog/analysis.go b/prog/analysis.go index 629ae1dd..0c4102a4 100644 --- a/prog/analysis.go +++ b/prog/analysis.go @@ -49,7 +49,7 @@ func newState(target *Target, ct *ChoiceTable) *state { } func (s *state) analyze(c *Call) { - foreachArgArray(&c.Args, c.Ret, func(arg, base Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { switch typ := arg.Type().(type) { case *ResourceType: if typ.Dir() != DirIn { @@ -80,45 +80,49 @@ func (s *state) analyze(c *Call) { } } -func foreachSubargImpl(arg Arg, parent *[]Arg, f func(arg, base Arg, parent *[]Arg)) { - var rec func(arg, base Arg, parent *[]Arg) - rec = func(arg, base Arg, parent *[]Arg) { - f(arg, base, parent) - switch a := arg.(type) { - case *GroupArg: - for _, arg1 := range a.Inner { - parent1 := parent - if _, ok := arg.Type().(*StructType); ok { - parent1 = &a.Inner - } - rec(arg1, base, parent1) - } - case *PointerArg: - if a.Res != nil { - rec(a.Res, arg, parent) - } - case *UnionArg: - rec(a.Option, base, parent) +type ArgCtx struct { + Parent *[]Arg // GroupArg.Inner (for structs) or Call.Args containing this arg + Base *PointerArg // pointer to the base of the heap object containing this arg + Offset uint64 // offset of this arg from the base + Stop bool // if set by the callback, subargs of this arg are not visited +} + +func ForeachSubArg(arg Arg, f func(Arg, *ArgCtx)) { + foreachArgImpl(arg, ArgCtx{}, f) +} + +func ForeachArg(c *Call, f func(Arg, *ArgCtx)) { + ctx := ArgCtx{} + if c.Ret != nil { + foreachArgImpl(c.Ret, ctx, f) + } + ctx.Parent = &c.Args + for _, arg := range c.Args { + foreachArgImpl(arg, ctx, f) + } +} + +func foreachArgImpl(arg Arg, ctx ArgCtx, f func(Arg, *ArgCtx)) { + f(arg, &ctx) + if ctx.Stop { + return + } + switch a := arg.(type) { + case *GroupArg: + if _, ok := a.Type().(*StructType); ok { + ctx.Parent = &a.Inner } + for _, arg1 := range a.Inner { + foreachArgImpl(arg1, ctx, f) + } + case *PointerArg: + if a.Res != nil { + ctx.Base = a + foreachArgImpl(a.Res, ctx, f) + } + case *UnionArg: + foreachArgImpl(a.Option, ctx, f) } - rec(arg, nil, parent) -} - -func ForeachSubarg(arg Arg, f func(arg, base Arg, parent *[]Arg)) { - foreachSubargImpl(arg, nil, f) -} - -func foreachArgArray(args *[]Arg, ret Arg, f func(arg, base Arg, parent *[]Arg)) { - for _, arg := range *args { - foreachSubargImpl(arg, args, f) - } - if ret != nil { - foreachSubargImpl(ret, nil, f) - } -} - -func foreachArg(c *Call, f func(arg, base Arg, parent *[]Arg)) { - foreachArgArray(&c.Args, nil, f) } func foreachSubargOffset(arg Arg, f func(arg Arg, offset uint64)) { @@ -153,10 +157,12 @@ func foreachSubargOffset(arg Arg, f func(arg Arg, offset uint64)) { rec(arg, 0) } +// TODO(dvyukov): combine RequiresBitmasks and RequiresChecksums into a single function +// to not walk the tree twice. They are always used together anyway. func RequiresBitmasks(p *Prog) bool { result := false for _, c := range p.Calls { - foreachArg(c, func(arg, _ Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { if a, ok := arg.(*ConstArg); ok { if a.Type().BitfieldOffset() != 0 || a.Type().BitfieldLength() != 0 { result = true @@ -170,7 +176,7 @@ func RequiresBitmasks(p *Prog) bool { func RequiresChecksums(p *Prog) bool { result := false for _, c := range p.Calls { - foreachArg(c, func(arg, _ Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { if _, ok := arg.Type().(*CsumType); ok { result = true } diff --git a/prog/checksum.go b/prog/checksum.go index f1d2c0dd..5062ddc6 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -26,8 +26,8 @@ type CsumChunk struct { Size uint64 // for CsumChunkConst } -func getFieldByName(arg Arg, name string) Arg { - for _, field := range arg.(*GroupArg).Inner { +func getFieldByName(arg *GroupArg, name string) Arg { + for _, field := range arg.Inner { if field.Type().FieldName() == name { return field } @@ -35,7 +35,7 @@ func getFieldByName(arg Arg, name string) Arg { panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name())) } -func extractHeaderParamsIPv4(arg Arg) (Arg, Arg) { +func extractHeaderParamsIPv4(arg *GroupArg) (Arg, Arg) { srcAddr := getFieldByName(arg, "src_ip") if srcAddr.Size() != 4 { panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name())) @@ -47,7 +47,7 @@ func extractHeaderParamsIPv4(arg Arg) (Arg, Arg) { return srcAddr, dstAddr } -func extractHeaderParamsIPv6(arg Arg) (Arg, Arg) { +func extractHeaderParamsIPv6(arg *GroupArg) (Arg, Arg) { srcAddr := getFieldByName(arg, "src_ip") if srcAddr.Size() != 16 { panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name())) @@ -100,7 +100,7 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo { var pseudoCsumFields []Arg // Find all csum fields. - foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { if typ, ok := arg.Type().(*CsumType); ok { switch typ.Kind { case CsumInet: @@ -120,7 +120,7 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo { // Build map of each field to its parent struct. parentsMap := make(map[Arg]Arg) - foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { if _, ok := arg.Type().(*StructType); ok { for _, field := range arg.(*GroupArg).Inner { parentsMap[InnerArg(field)] = arg @@ -146,18 +146,20 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo { } // Extract ipv4 or ipv6 source and destination addresses. - ipv4HeaderParsed := false - ipv6HeaderParsed := false - var ipSrcAddr Arg - var ipDstAddr Arg - foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) { + ipv4HeaderParsed, ipv6HeaderParsed := false, false + var ipSrcAddr, ipDstAddr Arg + ForeachArg(c, func(arg Arg, _ *ArgCtx) { + groupArg, ok := arg.(*GroupArg) + if !ok { + return + } // syz_csum_* structs are used in tests - switch arg.Type().Name() { + switch groupArg.Type().Name() { case "ipv4_header", "syz_csum_ipv4_header": - ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(arg) + ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(groupArg) ipv4HeaderParsed = true case "ipv6_packet", "syz_csum_ipv6_header": - ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg) + ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(groupArg) ipv6HeaderParsed = true } }) diff --git a/prog/encodingexec.go b/prog/encodingexec.go index 1d5b5e87..d0c8f6c8 100644 --- a/prog/encodingexec.go +++ b/prog/encodingexec.go @@ -86,7 +86,7 @@ func (p *Prog) SerializeForExec(buffer []byte) (int, error) { } // Calculate arg offsets within structs. // Generate copyin instructions that fill in data into pointer arguments. - foreachArg(c, func(arg, _ Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { if a, ok := arg.(*PointerArg); ok && a.Res != nil { foreachSubargOffset(a.Res, func(arg1 Arg, offset uint64) { addr := p.Target.PhysicalAddr(arg) + offset @@ -167,7 +167,7 @@ func (p *Prog) SerializeForExec(buffer []byte) (int, error) { w.writeArg(arg) } // Generate copyout instructions that persist interesting return values. - foreachArg(c, func(arg, base Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { if !isUsed(arg) { return } @@ -176,9 +176,6 @@ func (p *Prog) SerializeForExec(buffer []byte) (int, error) { // Idx is already assigned above. case *ConstArg, *ResultArg: // Create a separate copyout instruction that has own Idx. - if _, ok := base.(*PointerArg); !ok { - panic("arg base is not a pointer") - } info := w.args[arg] info.Idx = copyoutSeq copyoutSeq++ diff --git a/prog/hints.go b/prog/hints.go index fb3100c3..e6406124 100644 --- a/prog/hints.go +++ b/prog/hints.go @@ -77,16 +77,17 @@ func (p *Prog) MutateWithHints(callIndex int, comps CompMap, exec func(p *Prog)) } exec(p) } - foreachArg(c, func(arg, _ Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { generateHints(p, comps, c, arg, execValidate) }) } func generateHints(p *Prog, compMap CompMap, c *Call, arg Arg, exec func()) { - if arg.Type().Dir() == DirOut { + typ := arg.Type() + if typ == nil || typ.Dir() == DirOut { return } - switch arg.Type().(type) { + switch typ.(type) { case *ProcType: // Random proc will not pass validation. // We can mutate it, but only if the resulting value is within the legal range. diff --git a/prog/hints_test.go b/prog/hints_test.go index 80a9042d..9a87d301 100644 --- a/prog/hints_test.go +++ b/prog/hints_test.go @@ -364,8 +364,8 @@ func TestHintsRandom(t *testing.T) { func extractValues(c *Call) map[uint64]bool { vals := make(map[uint64]bool) - foreachArg(c, func(arg, _ Arg, _ *[]Arg) { - if arg.Type().Dir() == DirOut { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { + if typ := arg.Type(); typ == nil || typ.Dir() == DirOut { return } switch a := arg.(type) { diff --git a/prog/mutation.go b/prog/mutation.go index 0aa06979..83d73a50 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -67,14 +67,14 @@ outer: retryArg := false for stop := false; !stop || retryArg; stop = r.oneOf(3) { retryArg = false - args, bases, parents := p.Target.mutationArgs(c) + args, ctxes := p.Target.mutationArgs(c) if len(args) == 0 { retry = true continue outer } idx := r.Intn(len(args)) - arg, base, parent := args[idx], bases[idx], parents[idx] - calls, ok := p.Target.mutateArg(r, s, arg, base, parent, &updateSizes) + arg, ctx := args[idx], ctxes[idx] + calls, ok := p.Target.mutateArg(r, s, arg, ctx, &updateSizes) if !ok { retryArg = true continue @@ -106,14 +106,10 @@ outer: } } -func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]Arg, updateSizes *bool) (calls []*Call, ok bool) { +func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updateSizes *bool) (calls []*Call, ok bool) { var baseSize uint64 - if base != nil { - b, ok := base.(*PointerArg) - if !ok || b.Res == nil { - panic("bad base arg") - } - baseSize = b.Res.Size() + if ctx.Base != nil { + baseSize = ctx.Base.Res.Size() } switch t := arg.Type().(type) { case *IntType, *FlagsType: @@ -133,7 +129,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]A } } case *LenType: - if !r.mutateSize(arg.(*ConstArg), *parent) { + if !r.mutateSize(arg.(*ConstArg), *ctx.Parent) { return nil, false } *updateSizes = false @@ -261,15 +257,14 @@ func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]A } // Update base pointer if size has increased. - if base != nil { - b := base.(*PointerArg) - if baseSize < b.Res.Size() { - newArg, newCalls := r.addr(s, b.Type(), b.Res.Size(), b.Res) + if base := ctx.Base; base != nil { + if baseSize < base.Res.Size() { + newArg, newCalls := r.addr(s, base.Type(), base.Res.Size(), base.Res) calls = append(calls, newCalls...) a1 := newArg.(*PointerArg) - b.PageIndex = a1.PageIndex - b.PageOffset = a1.PageOffset - b.PagesNum = a1.PagesNum + base.PageIndex = a1.PageIndex + base.PageOffset = a1.PageOffset + base.PagesNum = a1.PagesNum } } for _, c := range calls { @@ -278,29 +273,27 @@ func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]A return calls, true } -func (target *Target) mutationSubargs(arg0 Arg) (args, bases []Arg, parents []*[]Arg) { - ForeachSubarg(arg0, func(arg, base Arg, parent *[]Arg) { - if target.needMutateArg(arg, base, parent) { +func (target *Target) mutationSubargs(arg0 Arg) (args []Arg, ctxes []ArgCtx) { + ForeachSubArg(arg0, func(arg Arg, ctx *ArgCtx) { + if target.needMutateArg(arg, ctx) { args = append(args, arg) - bases = append(bases, base) - parents = append(parents, parent) + ctxes = append(ctxes, *ctx) } }) return } -func (target *Target) mutationArgs(c *Call) (args, bases []Arg, parents []*[]Arg) { - foreachArg(c, func(arg, base Arg, parent *[]Arg) { - if target.needMutateArg(arg, base, parent) { +func (target *Target) mutationArgs(c *Call) (args []Arg, ctxes []ArgCtx) { + ForeachArg(c, func(arg Arg, ctx *ArgCtx) { + if target.needMutateArg(arg, ctx) { args = append(args, arg) - bases = append(bases, base) - parents = append(parents, parent) + ctxes = append(ctxes, *ctx) } }) return } -func (target *Target) needMutateArg(arg, base Arg, parent *[]Arg) bool { +func (target *Target) needMutateArg(arg Arg, ctx *ArgCtx) bool { switch typ := arg.Type().(type) { case *StructType: if target.SpecialTypes[typ.Name()] == nil { @@ -329,20 +322,21 @@ func (target *Target) needMutateArg(arg, base Arg, parent *[]Arg) bool { } } typ := arg.Type() - if typ.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 { + if typ == nil || typ.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 { return false } - if base != nil { + if ctx.Base != nil { // TODO(dvyukov): need to check parent as well. // Say, timespec can be part of another struct and base // will point to that other struct, not timespec. // Strictly saying, we need to check parents all way up, // or better bail out from recursion when we reach // a special struct. - _, isStruct := base.Type().(*StructType) - _, isUnion := base.Type().(*UnionType) + baseType := ctx.Base.Type() + _, isStruct := baseType.(*StructType) + _, isUnion := baseType.(*UnionType) if (isStruct || isUnion) && - target.SpecialTypes[base.Type().Name()] != nil { + target.SpecialTypes[baseType.Name()] != nil { // These special structs/unions are mutated as a whole. return false } diff --git a/prog/prog.go b/prog/prog.go index 5247ca8c..68179541 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -476,7 +476,7 @@ func (p *Prog) replaceArgCheck(c *Call, arg, arg1 Arg, calls []*Call) { panic("call is already in prog") } } - foreachArg(c0, func(arg0, _ Arg, _ *[]Arg) { + ForeachArg(c0, func(arg0 Arg, _ *ArgCtx) { if arg0 == arg { if c0 != c { panic("wrong call") @@ -501,7 +501,7 @@ func (p *Prog) replaceArgCheck(c *Call, arg, arg1 Arg, calls []*Call) { // removeArg removes all references to/from arg0 from a program. func removeArg(arg0 Arg) { - ForeachSubarg(arg0, func(arg, _ Arg, _ *[]Arg) { + ForeachSubArg(arg0, func(arg Arg, ctx *ArgCtx) { if a, ok := arg.(*ResultArg); ok && a.Res != nil { if !(*a.Res.(ArgUsed).Used())[arg] { panic("broken tree") diff --git a/prog/rand.go b/prog/rand.go index a67ca7d9..5afd00c9 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -349,7 +349,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls [] } // Discard unsuccessful calls. for _, c := range calls { - foreachArg(c, func(arg, _ Arg, _ *[]Arg) { + ForeachArg(c, func(arg Arg, _ *ArgCtx) { if a, ok := arg.(*ResultArg); ok && a.Res != nil { delete(*a.Res.(ArgUsed).Used(), arg) } diff --git a/prog/size.go b/prog/size.go index 67f7ef75..9f2258c8 100644 --- a/prog/size.go +++ b/prog/size.go @@ -94,19 +94,23 @@ func (target *Target) assignSizes(args []Arg, parentsMap map[Arg]Arg) { func (target *Target) assignSizesArray(args []Arg) { parentsMap := make(map[Arg]Arg) - foreachArgArray(&args, nil, func(arg, base Arg, _ *[]Arg) { - if _, ok := arg.Type().(*StructType); ok { - for _, field := range arg.(*GroupArg).Inner { - parentsMap[InnerArg(field)] = arg + for _, arg := range args { + ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) { + if _, ok := arg.Type().(*StructType); ok { + for _, field := range arg.(*GroupArg).Inner { + parentsMap[InnerArg(field)] = arg + } } - } - }) + }) + } target.assignSizes(args, parentsMap) - foreachArgArray(&args, nil, func(arg, base Arg, _ *[]Arg) { - if _, ok := arg.Type().(*StructType); ok { - target.assignSizes(arg.(*GroupArg).Inner, parentsMap) - } - }) + for _, arg := range args { + ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) { + if _, ok := arg.Type().(*StructType); ok { + target.assignSizes(arg.(*GroupArg).Inner, parentsMap) + } + }) + } } func (target *Target) assignSizesCall(c *Call) { diff --git a/prog/target.go b/prog/target.go index eb238981..d9d31825 100644 --- a/prog/target.go +++ b/prog/target.go @@ -196,15 +196,15 @@ func (g *Gen) generateArg(typ Type, pcalls *[]*Call, ignoreSpecial bool) Arg { func (g *Gen) MutateArg(arg0 Arg) (calls []*Call) { updateSizes := true for stop := false; !stop; stop = g.r.oneOf(3) { - args, bases, parents := g.r.target.mutationSubargs(arg0) + args, ctxes := g.r.target.mutationSubargs(arg0) if len(args) == 0 { // TODO(dvyukov): probably need to return this condition // and updateSizes to caller so that Mutate can act accordingly. return } idx := g.r.Intn(len(args)) - arg, base, parent := args[idx], bases[idx], parents[idx] - newCalls, ok := g.r.target.mutateArg(g.r, g.s, arg, base, parent, &updateSizes) + arg, ctx := args[idx], ctxes[idx] + newCalls, ok := g.r.target.mutateArg(g.r, g.s, arg, ctx, &updateSizes) if !ok { continue } diff --git a/sys/linux/init_iptables.go b/sys/linux/init_iptables.go index 89604e0b..d0e77604 100644 --- a/sys/linux/init_iptables.go +++ b/sys/linux/init_iptables.go @@ -85,7 +85,7 @@ func (arch *arch) generateNetfilterTable(g *prog.Gen, typ prog.Type, old prog.Ar hookArg.Val = pos } // Now update standard target jump offsets. - prog.ForeachSubarg(arg, func(arg, _ prog.Arg, _ *[]prog.Arg) { + prog.ForeachSubArg(arg, func(arg prog.Arg, _ *prog.ArgCtx) { if !strings.HasPrefix(arg.Type().Name(), `xt_target_t["", `) { return }