syzkaller/prog/validation.go
Dmitry Vyukov a7e4a49fae all: spot optimizations
A bunch of spot optimizations after CPU/memory profiling:
1. Optimize hot-path coverage comparison in fuzzer.
2. Don't allocate and copy serialized program, serialize directly into shmem.
3. Reduce allocations during parsing of output shmem (encoding/binary sucks).
4. Don't allocate and copy coverage arrays, refer directly to the shmem region
   (we are not going to mutate them).
5. Don't validate programs outside of tests, validation allocates tons of memory.
6. Replace the choose primitive with simpler switches.
   Choose allocates a lot of memory (for the int, the func, and everything the func refers to).
7. Other minor optimizations.
2017-01-20 23:55:25 +01:00

228 lines
7.2 KiB
Go

// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"fmt"
"github.com/google/syzkaller/sys"
)
// debug gates expensive self-checks such as (*Prog).validate.
// It is off by default; tests switch it on.
var debug bool
// validCtx accumulates state while validating the calls of one program.
type validCtx struct {
// args is the set of every arg visited so far; visiting the same arg
// a second time is reported as an error by (*Call).validate.
args map[*Arg]bool
// uses maps each arg listed in some visited arg's Uses set to the
// visited arg that listed it.
uses map[*Arg]*Arg
}
// validate checks the internal consistency of p and returns the first
// violation found, or nil.
//
// It is a no-op unless debug is set: validation allocates a lot of memory,
// so it only runs in tests. Each call is validated with a shared validCtx,
// and afterwards every arg that uses another arg is required to have been
// visited as part of the program tree as well.
func (p *Prog) validate() error {
	if !debug {
		return nil
	}
	ctx := &validCtx{
		args: make(map[*Arg]bool),
		uses: make(map[*Arg]*Arg),
	}
	for _, c := range p.Calls {
		if err := c.validate(ctx); err != nil {
			return err
		}
	}
	// A user recorded in some arg's Uses set that was never reached while
	// walking the calls dangles outside the program.
	for u, orig := range ctx.uses {
		if !ctx.args[u] {
			return fmt.Errorf("use of %+v refers to an out-of-tree arg\narg: %#v", *orig, u)
		}
	}
	return nil
}
// validate checks the structural invariants of the single call c and
// returns the first violation found, or nil.
//
// It requires syscall metadata (c.Meta), an argument count matching
// c.Meta.Args, and a return value of kind ArgReturn. Every argument and the
// return value are then checked recursively against their declared
// sys.Type. Each visited arg is recorded in ctx.args, and every entry of an
// arg's Uses set is recorded in ctx.uses. Prog.validate uses ctx.uses to
// confirm those users are part of the program tree.
func (c *Call) validate(ctx *validCtx) error {
	if c.Meta == nil {
		return fmt.Errorf("call does not have meta information")
	}
	if len(c.Args) != len(c.Meta.Args) {
		return fmt.Errorf("syscall %v: wrong number of arguments, want %v, got %v", c.Meta.Name, len(c.Meta.Args), len(c.Args))
	}
	// checkArg validates arg against its declared type typ and recurses into
	// pointees, struct/array elements and union options.
	var checkArg func(arg *Arg, typ sys.Type) error
	checkArg = func(arg *Arg, typ sys.Type) error {
		if arg == nil {
			return fmt.Errorf("syscall %v: nil arg", c.Meta.Name)
		}
		// ctx.args holds every arg visited so far (across all calls sharing
		// ctx); a node may appear in the tree only once.
		if ctx.args[arg] {
			return fmt.Errorf("syscall %v: arg is referenced several times in the tree", c.Meta.Name)
		}
		ctx.args[arg] = true
		for u := range arg.Uses {
			ctx.uses[u] = arg
		}
		if arg.Type == nil {
			return fmt.Errorf("syscall %v: no type", c.Meta.Name)
		}
		if arg.Type.Name() != typ.Name() {
			return fmt.Errorf("syscall %v: type name mismatch: %v vs %v", c.Meta.Name, arg.Type.Name(), typ.Name())
		}
		if arg.Type.Dir() == sys.DirOut {
			if (arg.Val != 0 && arg.Val != arg.Type.Default()) || arg.AddrPage != 0 || arg.AddrOffset != 0 {
				// We generate output len arguments, which makes sense
				// since it can be a length of a variable-length array
				// which is not known otherwise.
				if _, ok := arg.Type.(*sys.LenType); !ok {
					return fmt.Errorf("syscall %v: output arg '%v' has non default value '%v'", c.Meta.Name, typ.Name(), arg.Val)
				}
			}
			for _, v := range arg.Data {
				if v != 0 {
					return fmt.Errorf("syscall %v: output arg '%v' has data", c.Meta.Name, typ.Name())
				}
			}
		}
		// Type-specific constraints on the arg's kind and value.
		switch typ1 := arg.Type.(type) {
		case *sys.ResourceType:
			switch arg.Kind {
			case ArgResult:
			case ArgReturn:
			case ArgConst:
				if arg.Type.Dir() == sys.DirOut && (arg.Val != 0 && arg.Val != arg.Type.Default()) {
					return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", c.Meta.Name, typ.Name(), arg.Val)
				}
			default:
				return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind)
			}
		case *sys.StructType, *sys.ArrayType:
			switch arg.Kind {
			case ArgGroup:
			default:
				return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind)
			}
		case *sys.UnionType:
			switch arg.Kind {
			case ArgUnion:
			default:
				return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind)
			}
		case *sys.ProcType:
			if arg.Val >= uintptr(typ1.ValuesPerProc) {
				return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", c.Meta.Name, typ.Name(), arg.Val)
			}
		case *sys.BufferType:
			switch typ1.Kind {
			case sys.BufferString:
				if typ1.Length != 0 && len(arg.Data) != int(typ1.Length) {
					return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, typ.Name(), len(arg.Data), typ1.Length)
				}
			}
		}
		// Kind-specific structure, recursing into nested args.
		switch arg.Kind {
		case ArgConst:
		case ArgResult:
			if arg.Res == nil {
				return fmt.Errorf("syscall %v: result arg '%v' has no reference", c.Meta.Name, typ.Name())
			}
			// The referenced result must already have been visited, and it
			// must list this arg among its users.
			if !ctx.args[arg.Res] {
				return fmt.Errorf("syscall %v: result arg '%v' references out-of-tree result: %p%+v -> %p%+v",
					c.Meta.Name, typ.Name(), arg, arg, arg.Res, arg.Res)
			}
			if _, ok := arg.Res.Uses[arg]; !ok {
				return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", c.Meta.Name, typ.Name(), arg.Res.Uses)
			}
		case ArgPointer:
			switch typ1 := typ.(type) {
			case *sys.VmaType:
				if arg.Res != nil {
					return fmt.Errorf("syscall %v: vma arg '%v' has data", c.Meta.Name, typ.Name())
				}
				if arg.AddrPagesNum == 0 {
					return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, typ.Name())
				}
			case *sys.PtrType:
				if arg.Type.Dir() == sys.DirOut {
					return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", c.Meta.Name, typ.Name())
				}
				if arg.Res == nil && !typ.Optional() {
					return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", c.Meta.Name, typ.Name())
				}
				if arg.Res != nil {
					if err := checkArg(arg.Res, typ1.Type); err != nil {
						return err
					}
				}
				if arg.AddrPagesNum != 0 {
					return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, typ.Name())
				}
			default:
				return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", c.Meta.Name, typ.Name(), typ)
			}
		case ArgPageSize:
		case ArgData:
			switch typ1 := typ.(type) {
			case *sys.ArrayType:
				if typ2, ok := typ1.Type.(*sys.IntType); !ok || typ2.Size() != 1 {
					return fmt.Errorf("syscall %v: data arg '%v' should be an array", c.Meta.Name, typ.Name())
				}
			}
		case ArgGroup:
			switch typ1 := typ.(type) {
			case *sys.StructType:
				if len(arg.Inner) != len(typ1.Fields) {
					return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", c.Meta.Name, typ.Name(), len(typ1.Fields), len(arg.Inner))
				}
				for i, arg1 := range arg.Inner {
					if err := checkArg(arg1, typ1.Fields[i]); err != nil {
						return err
					}
				}
			case *sys.ArrayType:
				for _, arg1 := range arg.Inner {
					if err := checkArg(arg1, typ1.Type); err != nil {
						return err
					}
				}
			default:
				return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", c.Meta.Name, typ.Name(), typ)
			}
		case ArgUnion:
			typ1, ok := typ.(*sys.UnionType)
			if !ok {
				return fmt.Errorf("syscall %v: union arg '%v' has bad type", c.Meta.Name, typ.Name())
			}
			// Guard before calling OptionType.Name(), which would panic on nil.
			if arg.OptionType == nil {
				return fmt.Errorf("syscall %v: union arg '%v' has no option type", c.Meta.Name, typ.Name())
			}
			found := false
			for _, typ2 := range typ1.Options {
				if arg.OptionType.Name() == typ2.Name() {
					found = true
					break
				}
			}
			if !found {
				return fmt.Errorf("syscall %v: union arg '%v' has bad option", c.Meta.Name, typ.Name())
			}
			if err := checkArg(arg.Option, arg.OptionType); err != nil {
				return err
			}
		case ArgReturn:
		default:
			return fmt.Errorf("syscall %v: unknown arg '%v' kind %v", c.Meta.Name, typ.Name(), arg.Kind)
		}
		return nil
	}
	for i, arg := range c.Args {
		// ArgReturn is only valid for c.Ret. A nil arg (or one without a
		// type) is reported by checkArg instead of panicking here, so the
		// declared type supplies the name in this message.
		if arg != nil && arg.Kind == ArgReturn {
			return fmt.Errorf("syscall %v: arg '%v' has wrong return kind", c.Meta.Name, c.Meta.Args[i].Name())
		}
		if err := checkArg(arg, c.Meta.Args[i]); err != nil {
			return err
		}
	}
	if c.Ret == nil {
		return fmt.Errorf("syscall %v: return value is absent", c.Meta.Name)
	}
	if c.Ret.Kind != ArgReturn {
		return fmt.Errorf("syscall %v: return value has wrong kind %v", c.Meta.Name, c.Ret.Kind)
	}
	// Syscalls without a declared return type must not carry a typed Ret.
	if c.Meta.Ret != nil {
		if err := checkArg(c.Ret, c.Meta.Ret); err != nil {
			return err
		}
	} else if c.Ret.Type != nil {
		return fmt.Errorf("syscall %v: return value has spurious type: %+v", c.Meta.Name, c.Ret.Type)
	}
	return nil
}