prog: rework foreachArg

Make Foreach* callback accept the arg and a context struct
that can contain lots of aux info.
This (1) removes lots of unuser base/parent args,
(2) provides foundation for stopping recursion,
(3) allows to merge foreachSubargOffset.
This commit is contained in:
Dmitry Vyukov 2018-02-18 13:49:48 +01:00
parent 2be2288ee2
commit 85d1218f41
11 changed files with 120 additions and 116 deletions

View File

@ -49,7 +49,7 @@ func newState(target *Target, ct *ChoiceTable) *state {
}
func (s *state) analyze(c *Call) {
foreachArgArray(&c.Args, c.Ret, func(arg, base Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
switch typ := arg.Type().(type) {
case *ResourceType:
if typ.Dir() != DirIn {
@ -80,45 +80,49 @@ func (s *state) analyze(c *Call) {
}
}
func foreachSubargImpl(arg Arg, parent *[]Arg, f func(arg, base Arg, parent *[]Arg)) {
var rec func(arg, base Arg, parent *[]Arg)
rec = func(arg, base Arg, parent *[]Arg) {
f(arg, base, parent)
switch a := arg.(type) {
case *GroupArg:
for _, arg1 := range a.Inner {
parent1 := parent
if _, ok := arg.Type().(*StructType); ok {
parent1 = &a.Inner
}
rec(arg1, base, parent1)
}
case *PointerArg:
if a.Res != nil {
rec(a.Res, arg, parent)
}
case *UnionArg:
rec(a.Option, base, parent)
type ArgCtx struct {
Parent *[]Arg // GroupArg.Inner (for structs) or Call.Args containing this arg
Base *PointerArg // pointer to the base of the heap object containing this arg
Offset uint64 // offset of this arg from the base
Stop bool // if set by the callback, subargs of this arg are not visited
}
func ForeachSubArg(arg Arg, f func(Arg, *ArgCtx)) {
foreachArgImpl(arg, ArgCtx{}, f)
}
func ForeachArg(c *Call, f func(Arg, *ArgCtx)) {
ctx := ArgCtx{}
if c.Ret != nil {
foreachArgImpl(c.Ret, ctx, f)
}
ctx.Parent = &c.Args
for _, arg := range c.Args {
foreachArgImpl(arg, ctx, f)
}
}
func foreachArgImpl(arg Arg, ctx ArgCtx, f func(Arg, *ArgCtx)) {
f(arg, &ctx)
if ctx.Stop {
return
}
switch a := arg.(type) {
case *GroupArg:
if _, ok := a.Type().(*StructType); ok {
ctx.Parent = &a.Inner
}
for _, arg1 := range a.Inner {
foreachArgImpl(arg1, ctx, f)
}
case *PointerArg:
if a.Res != nil {
ctx.Base = a
foreachArgImpl(a.Res, ctx, f)
}
case *UnionArg:
foreachArgImpl(a.Option, ctx, f)
}
rec(arg, nil, parent)
}
func ForeachSubarg(arg Arg, f func(arg, base Arg, parent *[]Arg)) {
foreachSubargImpl(arg, nil, f)
}
func foreachArgArray(args *[]Arg, ret Arg, f func(arg, base Arg, parent *[]Arg)) {
for _, arg := range *args {
foreachSubargImpl(arg, args, f)
}
if ret != nil {
foreachSubargImpl(ret, nil, f)
}
}
func foreachArg(c *Call, f func(arg, base Arg, parent *[]Arg)) {
foreachArgArray(&c.Args, nil, f)
}
func foreachSubargOffset(arg Arg, f func(arg Arg, offset uint64)) {
@ -153,10 +157,12 @@ func foreachSubargOffset(arg Arg, f func(arg Arg, offset uint64)) {
rec(arg, 0)
}
// TODO(dvyukov): combine RequiresBitmasks and RequiresChecksums into a single function
// to not walk the tree twice. They are always used together anyway.
func RequiresBitmasks(p *Prog) bool {
result := false
for _, c := range p.Calls {
foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if a, ok := arg.(*ConstArg); ok {
if a.Type().BitfieldOffset() != 0 || a.Type().BitfieldLength() != 0 {
result = true
@ -170,7 +176,7 @@ func RequiresBitmasks(p *Prog) bool {
func RequiresChecksums(p *Prog) bool {
result := false
for _, c := range p.Calls {
foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if _, ok := arg.Type().(*CsumType); ok {
result = true
}

View File

@ -26,8 +26,8 @@ type CsumChunk struct {
Size uint64 // for CsumChunkConst
}
func getFieldByName(arg Arg, name string) Arg {
for _, field := range arg.(*GroupArg).Inner {
func getFieldByName(arg *GroupArg, name string) Arg {
for _, field := range arg.Inner {
if field.Type().FieldName() == name {
return field
}
@ -35,7 +35,7 @@ func getFieldByName(arg Arg, name string) Arg {
panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name()))
}
func extractHeaderParamsIPv4(arg Arg) (Arg, Arg) {
func extractHeaderParamsIPv4(arg *GroupArg) (Arg, Arg) {
srcAddr := getFieldByName(arg, "src_ip")
if srcAddr.Size() != 4 {
panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name()))
@ -47,7 +47,7 @@ func extractHeaderParamsIPv4(arg Arg) (Arg, Arg) {
return srcAddr, dstAddr
}
func extractHeaderParamsIPv6(arg Arg) (Arg, Arg) {
func extractHeaderParamsIPv6(arg *GroupArg) (Arg, Arg) {
srcAddr := getFieldByName(arg, "src_ip")
if srcAddr.Size() != 16 {
panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name()))
@ -100,7 +100,7 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo {
var pseudoCsumFields []Arg
// Find all csum fields.
foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if typ, ok := arg.Type().(*CsumType); ok {
switch typ.Kind {
case CsumInet:
@ -120,7 +120,7 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo {
// Build map of each field to its parent struct.
parentsMap := make(map[Arg]Arg)
foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if _, ok := arg.Type().(*StructType); ok {
for _, field := range arg.(*GroupArg).Inner {
parentsMap[InnerArg(field)] = arg
@ -146,18 +146,20 @@ func calcChecksumsCall(c *Call) map[Arg]CsumInfo {
}
// Extract ipv4 or ipv6 source and destination addresses.
ipv4HeaderParsed := false
ipv6HeaderParsed := false
var ipSrcAddr Arg
var ipDstAddr Arg
foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) {
ipv4HeaderParsed, ipv6HeaderParsed := false, false
var ipSrcAddr, ipDstAddr Arg
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
groupArg, ok := arg.(*GroupArg)
if !ok {
return
}
// syz_csum_* structs are used in tests
switch arg.Type().Name() {
switch groupArg.Type().Name() {
case "ipv4_header", "syz_csum_ipv4_header":
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(arg)
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(groupArg)
ipv4HeaderParsed = true
case "ipv6_packet", "syz_csum_ipv6_header":
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg)
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(groupArg)
ipv6HeaderParsed = true
}
})

View File

@ -86,7 +86,7 @@ func (p *Prog) SerializeForExec(buffer []byte) (int, error) {
}
// Calculate arg offsets within structs.
// Generate copyin instructions that fill in data into pointer arguments.
foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if a, ok := arg.(*PointerArg); ok && a.Res != nil {
foreachSubargOffset(a.Res, func(arg1 Arg, offset uint64) {
addr := p.Target.PhysicalAddr(arg) + offset
@ -167,7 +167,7 @@ func (p *Prog) SerializeForExec(buffer []byte) (int, error) {
w.writeArg(arg)
}
// Generate copyout instructions that persist interesting return values.
foreachArg(c, func(arg, base Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if !isUsed(arg) {
return
}
@ -176,9 +176,6 @@ func (p *Prog) SerializeForExec(buffer []byte) (int, error) {
// Idx is already assigned above.
case *ConstArg, *ResultArg:
// Create a separate copyout instruction that has own Idx.
if _, ok := base.(*PointerArg); !ok {
panic("arg base is not a pointer")
}
info := w.args[arg]
info.Idx = copyoutSeq
copyoutSeq++

View File

@ -77,16 +77,17 @@ func (p *Prog) MutateWithHints(callIndex int, comps CompMap, exec func(p *Prog))
}
exec(p)
}
foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
generateHints(p, comps, c, arg, execValidate)
})
}
func generateHints(p *Prog, compMap CompMap, c *Call, arg Arg, exec func()) {
if arg.Type().Dir() == DirOut {
typ := arg.Type()
if typ == nil || typ.Dir() == DirOut {
return
}
switch arg.Type().(type) {
switch typ.(type) {
case *ProcType:
// Random proc will not pass validation.
// We can mutate it, but only if the resulting value is within the legal range.

View File

@ -364,8 +364,8 @@ func TestHintsRandom(t *testing.T) {
func extractValues(c *Call) map[uint64]bool {
vals := make(map[uint64]bool)
foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
if arg.Type().Dir() == DirOut {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if typ := arg.Type(); typ == nil || typ.Dir() == DirOut {
return
}
switch a := arg.(type) {

View File

@ -67,14 +67,14 @@ outer:
retryArg := false
for stop := false; !stop || retryArg; stop = r.oneOf(3) {
retryArg = false
args, bases, parents := p.Target.mutationArgs(c)
args, ctxes := p.Target.mutationArgs(c)
if len(args) == 0 {
retry = true
continue outer
}
idx := r.Intn(len(args))
arg, base, parent := args[idx], bases[idx], parents[idx]
calls, ok := p.Target.mutateArg(r, s, arg, base, parent, &updateSizes)
arg, ctx := args[idx], ctxes[idx]
calls, ok := p.Target.mutateArg(r, s, arg, ctx, &updateSizes)
if !ok {
retryArg = true
continue
@ -106,14 +106,10 @@ outer:
}
}
func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]Arg, updateSizes *bool) (calls []*Call, ok bool) {
func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updateSizes *bool) (calls []*Call, ok bool) {
var baseSize uint64
if base != nil {
b, ok := base.(*PointerArg)
if !ok || b.Res == nil {
panic("bad base arg")
}
baseSize = b.Res.Size()
if ctx.Base != nil {
baseSize = ctx.Base.Res.Size()
}
switch t := arg.Type().(type) {
case *IntType, *FlagsType:
@ -133,7 +129,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]A
}
}
case *LenType:
if !r.mutateSize(arg.(*ConstArg), *parent) {
if !r.mutateSize(arg.(*ConstArg), *ctx.Parent) {
return nil, false
}
*updateSizes = false
@ -261,15 +257,14 @@ func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]A
}
// Update base pointer if size has increased.
if base != nil {
b := base.(*PointerArg)
if baseSize < b.Res.Size() {
newArg, newCalls := r.addr(s, b.Type(), b.Res.Size(), b.Res)
if base := ctx.Base; base != nil {
if baseSize < base.Res.Size() {
newArg, newCalls := r.addr(s, base.Type(), base.Res.Size(), base.Res)
calls = append(calls, newCalls...)
a1 := newArg.(*PointerArg)
b.PageIndex = a1.PageIndex
b.PageOffset = a1.PageOffset
b.PagesNum = a1.PagesNum
base.PageIndex = a1.PageIndex
base.PageOffset = a1.PageOffset
base.PagesNum = a1.PagesNum
}
}
for _, c := range calls {
@ -278,29 +273,27 @@ func (target *Target) mutateArg(r *randGen, s *state, arg, base Arg, parent *[]A
return calls, true
}
func (target *Target) mutationSubargs(arg0 Arg) (args, bases []Arg, parents []*[]Arg) {
ForeachSubarg(arg0, func(arg, base Arg, parent *[]Arg) {
if target.needMutateArg(arg, base, parent) {
func (target *Target) mutationSubargs(arg0 Arg) (args []Arg, ctxes []ArgCtx) {
ForeachSubArg(arg0, func(arg Arg, ctx *ArgCtx) {
if target.needMutateArg(arg, ctx) {
args = append(args, arg)
bases = append(bases, base)
parents = append(parents, parent)
ctxes = append(ctxes, *ctx)
}
})
return
}
func (target *Target) mutationArgs(c *Call) (args, bases []Arg, parents []*[]Arg) {
foreachArg(c, func(arg, base Arg, parent *[]Arg) {
if target.needMutateArg(arg, base, parent) {
func (target *Target) mutationArgs(c *Call) (args []Arg, ctxes []ArgCtx) {
ForeachArg(c, func(arg Arg, ctx *ArgCtx) {
if target.needMutateArg(arg, ctx) {
args = append(args, arg)
bases = append(bases, base)
parents = append(parents, parent)
ctxes = append(ctxes, *ctx)
}
})
return
}
func (target *Target) needMutateArg(arg, base Arg, parent *[]Arg) bool {
func (target *Target) needMutateArg(arg Arg, ctx *ArgCtx) bool {
switch typ := arg.Type().(type) {
case *StructType:
if target.SpecialTypes[typ.Name()] == nil {
@ -329,20 +322,21 @@ func (target *Target) needMutateArg(arg, base Arg, parent *[]Arg) bool {
}
}
typ := arg.Type()
if typ.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 {
if typ == nil || typ.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 {
return false
}
if base != nil {
if ctx.Base != nil {
// TODO(dvyukov): need to check parent as well.
// Say, timespec can be part of another struct and base
// will point to that other struct, not timespec.
// Strictly saying, we need to check parents all way up,
// or better bail out from recursion when we reach
// a special struct.
_, isStruct := base.Type().(*StructType)
_, isUnion := base.Type().(*UnionType)
baseType := ctx.Base.Type()
_, isStruct := baseType.(*StructType)
_, isUnion := baseType.(*UnionType)
if (isStruct || isUnion) &&
target.SpecialTypes[base.Type().Name()] != nil {
target.SpecialTypes[baseType.Name()] != nil {
// These special structs/unions are mutated as a whole.
return false
}

View File

@ -476,7 +476,7 @@ func (p *Prog) replaceArgCheck(c *Call, arg, arg1 Arg, calls []*Call) {
panic("call is already in prog")
}
}
foreachArg(c0, func(arg0, _ Arg, _ *[]Arg) {
ForeachArg(c0, func(arg0 Arg, _ *ArgCtx) {
if arg0 == arg {
if c0 != c {
panic("wrong call")
@ -501,7 +501,7 @@ func (p *Prog) replaceArgCheck(c *Call, arg, arg1 Arg, calls []*Call) {
// removeArg removes all references to/from arg0 from a program.
func removeArg(arg0 Arg) {
ForeachSubarg(arg0, func(arg, _ Arg, _ *[]Arg) {
ForeachSubArg(arg0, func(arg Arg, ctx *ArgCtx) {
if a, ok := arg.(*ResultArg); ok && a.Res != nil {
if !(*a.Res.(ArgUsed).Used())[arg] {
panic("broken tree")

View File

@ -349,7 +349,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls []
}
// Discard unsuccessful calls.
for _, c := range calls {
foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
if a, ok := arg.(*ResultArg); ok && a.Res != nil {
delete(*a.Res.(ArgUsed).Used(), arg)
}

View File

@ -94,19 +94,23 @@ func (target *Target) assignSizes(args []Arg, parentsMap map[Arg]Arg) {
func (target *Target) assignSizesArray(args []Arg) {
parentsMap := make(map[Arg]Arg)
foreachArgArray(&args, nil, func(arg, base Arg, _ *[]Arg) {
if _, ok := arg.Type().(*StructType); ok {
for _, field := range arg.(*GroupArg).Inner {
parentsMap[InnerArg(field)] = arg
for _, arg := range args {
ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) {
if _, ok := arg.Type().(*StructType); ok {
for _, field := range arg.(*GroupArg).Inner {
parentsMap[InnerArg(field)] = arg
}
}
}
})
})
}
target.assignSizes(args, parentsMap)
foreachArgArray(&args, nil, func(arg, base Arg, _ *[]Arg) {
if _, ok := arg.Type().(*StructType); ok {
target.assignSizes(arg.(*GroupArg).Inner, parentsMap)
}
})
for _, arg := range args {
ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) {
if _, ok := arg.Type().(*StructType); ok {
target.assignSizes(arg.(*GroupArg).Inner, parentsMap)
}
})
}
}
func (target *Target) assignSizesCall(c *Call) {

View File

@ -196,15 +196,15 @@ func (g *Gen) generateArg(typ Type, pcalls *[]*Call, ignoreSpecial bool) Arg {
func (g *Gen) MutateArg(arg0 Arg) (calls []*Call) {
updateSizes := true
for stop := false; !stop; stop = g.r.oneOf(3) {
args, bases, parents := g.r.target.mutationSubargs(arg0)
args, ctxes := g.r.target.mutationSubargs(arg0)
if len(args) == 0 {
// TODO(dvyukov): probably need to return this condition
// and updateSizes to caller so that Mutate can act accordingly.
return
}
idx := g.r.Intn(len(args))
arg, base, parent := args[idx], bases[idx], parents[idx]
newCalls, ok := g.r.target.mutateArg(g.r, g.s, arg, base, parent, &updateSizes)
arg, ctx := args[idx], ctxes[idx]
newCalls, ok := g.r.target.mutateArg(g.r, g.s, arg, ctx, &updateSizes)
if !ok {
continue
}

View File

@ -85,7 +85,7 @@ func (arch *arch) generateNetfilterTable(g *prog.Gen, typ prog.Type, old prog.Ar
hookArg.Val = pos
}
// Now update standard target jump offsets.
prog.ForeachSubarg(arg, func(arg, _ prog.Arg, _ *[]prog.Arg) {
prog.ForeachSubArg(arg, func(arg prog.Arg, _ *prog.ArgCtx) {
if !strings.HasPrefix(arg.Type().Name(), `xt_target_t["", `) {
return
}