syzkaller/prog/encoding.go
Dmitry Vyukov a7e4a49fae all: spot optimizations
A bunch of spot optimizations after cpu/memory profiling:
1. Optimize hot-path coverage comparison in fuzzer.
2. Don't allocate and copy serialized program, serialize directly into shmem.
3. Reduce allocations during parsing of output shmem (encoding/binary sucks).
4. Don't allocate and copy coverage arrays, refer directly to the shmem region
   (we are not going to mutate them).
5. Don't validate programs outside of tests, validation allocates tons of memory.
6. Replace the choose primitive with simpler switches.
   Choose allocates a full load of memory (for int, func, and everything the func refers to).
7. Other minor optimizations.
2017-01-20 23:55:25 +01:00

556 lines
12 KiB
Go

// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"io"
"strconv"
"github.com/google/syzkaller/sys"
)
// String renders the program as a compact, dash-separated list of call names
// (intended mostly for debug output).
func (p *Prog) String() string {
	var out bytes.Buffer
	for idx, call := range p.Calls {
		if idx > 0 {
			out.WriteByte('-')
		}
		fmt.Fprint(&out, call.Meta.Name)
	}
	return out.String()
}
// Serialize renders the program in its textual form, one call per line:
//
//	[rN = ]name(arg, arg, ...)
//
// A call whose return value has uses is prefixed with "rN = ", where N is the
// next free result id. Padding arguments are omitted. It panics if the program
// does not pass validate.
func (p *Prog) Serialize() []byte {
	if p.validate() != nil {
		panic("serializing invalid program")
	}
	var (
		out    bytes.Buffer
		ids    = map[*Arg]int{} // result id assigned to each named argument
		nextID int              // next free result id
	)
	for _, call := range p.Calls {
		if len(call.Ret.Uses) > 0 {
			fmt.Fprintf(&out, "r%v = ", nextID)
			ids[call.Ret] = nextID
			nextID++
		}
		fmt.Fprintf(&out, "%v(", call.Meta.Name)
		for pos, arg := range call.Args {
			if sys.IsPad(arg.Type) {
				continue
			}
			if pos > 0 {
				out.WriteString(", ")
			}
			arg.serialize(&out, ids, &nextID)
		}
		out.WriteString(")\n")
	}
	return out.Bytes()
}
// serialize writes the textual form of the argument to buf.
//
// vars maps every argument that already has a result name to its numeric id,
// and varSeq points at the next free id. An argument with a non-empty Uses list
// is given a fresh id and emitted with a "<rN=>" prefix before its value. A nil
// argument is written as "nil". Padding members of a group are omitted.
//
// It panics on a result argument whose target has no id in vars, on a group
// whose type is neither a struct nor an array, and on an unknown argument kind.
func (a *Arg) serialize(buf io.Writer, vars map[*Arg]int, varSeq *int) {
	if a == nil {
		io.WriteString(buf, "nil")
		return
	}
	if len(a.Uses) > 0 {
		id := *varSeq
		*varSeq = id + 1
		fmt.Fprintf(buf, "<r%v=>", id)
		vars[a] = id
	}
	switch a.Kind {
	case ArgConst:
		fmt.Fprintf(buf, "0x%x", a.Val)
	case ArgResult:
		id, found := vars[a.Res]
		if !found {
			panic("no result")
		}
		fmt.Fprintf(buf, "r%v", id)
		if div := a.OpDiv; div != 0 {
			fmt.Fprintf(buf, "/%v", div)
		}
		if add := a.OpAdd; add != 0 {
			fmt.Fprintf(buf, "+%v", add)
		}
	case ArgPointer:
		fmt.Fprintf(buf, "&%v=", serializeAddr(a, true))
		a.Res.serialize(buf, vars, varSeq)
	case ArgPageSize:
		fmt.Fprintf(buf, "%v", serializeAddr(a, false))
	case ArgData:
		fmt.Fprintf(buf, "\"%v\"", hex.EncodeToString(a.Data))
	case ArgGroup:
		// Structs are wrapped in braces, arrays in square brackets.
		var opening, closing string
		switch a.Type.(type) {
		case *sys.StructType:
			opening, closing = "{", "}"
		case *sys.ArrayType:
			opening, closing = "[", "]"
		default:
			panic("unknown group type")
		}
		io.WriteString(buf, opening)
		for idx, elem := range a.Inner {
			if elem != nil && sys.IsPad(elem.Type) {
				continue
			}
			if idx > 0 {
				io.WriteString(buf, ", ")
			}
			elem.serialize(buf, vars, varSeq)
		}
		io.WriteString(buf, closing)
	case ArgUnion:
		fmt.Fprintf(buf, "@%v=", a.OptionType.Name())
		a.Option.serialize(buf, vars, varSeq)
	default:
		panic("unknown arg kind")
	}
}
// Deserialize parses the textual program representation in data.
//
// Empty lines and lines starting with '#' are skipped. Every other line must
// hold exactly one call of the form "[rN = ]name(arg, ...)". The call name is
// looked up in sys.CallMap and the number of arguments must match the call
// description. The resulting program is validated before it is returned; on
// any failure the returned program is nil.
func Deserialize(data []byte) (prog *Prog, err error) {
	p := &parser{r: bufio.NewScanner(bytes.NewReader(data))}
	p.r.Buffer(nil, maxLineLen)
	prog = new(Prog)
	vars := map[string]*Arg{}
	for p.Scan() {
		if p.EOF() || p.Char() == '#' {
			continue
		}
		// Optional "rN =" prefix that names the call's return value.
		retName := ""
		name := p.Ident()
		if p.Char() == '=' {
			p.Parse('=')
			retName, name = name, p.Ident()
		}
		meta := sys.CallMap[name]
		if meta == nil {
			return nil, fmt.Errorf("unknown syscall %v", name)
		}
		call := &Call{Meta: meta, Ret: returnArg(meta.Ret)}
		prog.Calls = append(prog.Calls, call)
		p.Parse('(')
		for n := 0; p.Char() != ')'; n++ {
			if n >= len(meta.Args) {
				return nil, fmt.Errorf("wrong call arg count: %v, want %v", n+1, len(meta.Args))
			}
			argType := meta.Args[n]
			if sys.IsPad(argType) {
				return nil, fmt.Errorf("padding in syscall %v arguments", name)
			}
			arg, argErr := parseArg(argType, p, vars)
			if argErr != nil {
				return nil, argErr
			}
			call.Args = append(call.Args, arg)
			if p.Char() != ')' {
				p.Parse(',')
			}
		}
		p.Parse(')')
		switch {
		case !p.EOF():
			return nil, fmt.Errorf("tailing data (line #%v)", p.l)
		case len(call.Args) != len(meta.Args):
			return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(call.Args), len(meta.Args))
		}
		if retName != "" {
			vars[retName] = call.Ret
		}
	}
	if scanErr := p.Err(); scanErr != nil {
		return nil, scanErr
	}
	if validErr := prog.validate(); validErr != nil {
		return nil, validErr
	}
	return prog, nil
}
// parseArg parses a single argument of type typ from the current position of p.
//
// vars maps result names to the arguments they denote. An argument prefixed
// with "<name=>" is stored in vars under that name once it has been parsed.
// The first character of the argument selects its form:
//
//	0   constant (a number starting with the digit 0, e.g. 0x10)
//	r   reference to a named result, optionally followed by /div and +add
//	&   pointer: address, '=', pointee
//	(   page-size address
//	"   hex-encoded data in double quotes
//	{   struct; padding fields are absent from the text and filled in here
//	[   array
//	@   union: option name, '=', option value
//	n   nil, returned as a nil *Arg; it must not be named
//
// Any other character, a form that does not match typ, or a malformed value
// yields an error.
func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
	r := ""
	if p.Char() == '<' {
		p.Parse('<')
		r = p.Ident()
		p.Parse('=')
		p.Parse('>')
	}
	var arg *Arg
	switch p.Char() {
	case '0':
		val := p.Ident()
		v, err := strconv.ParseUint(val, 0, 64)
		if err != nil {
			return nil, fmt.Errorf("wrong arg value '%v': %v", val, err)
		}
		arg = constArg(typ, uintptr(v))
	case 'r':
		id := p.Ident()
		res, ok := vars[id]
		if !ok || res == nil {
			return nil, fmt.Errorf("result %v references unknown variable (vars=%+v)", id, vars)
		}
		arg = resultArg(typ, res)
		if p.Char() == '/' {
			p.Parse('/')
			op := p.Ident()
			div, err := strconv.ParseUint(op, 0, 64)
			if err != nil {
				return nil, fmt.Errorf("wrong result div op: '%v'", op)
			}
			arg.OpDiv = uintptr(div)
		}
		if p.Char() == '+' {
			p.Parse('+')
			op := p.Ident()
			add, err := strconv.ParseUint(op, 0, 64)
			if err != nil {
				return nil, fmt.Errorf("wrong result add op: '%v'", op)
			}
			arg.OpAdd = uintptr(add)
		}
	case '&':
		var typ1 sys.Type
		switch t1 := typ.(type) {
		case *sys.PtrType:
			typ1 = t1.Type
		case *sys.VmaType:
			// No pointee type: typ1 stays nil.
		default:
			return nil, fmt.Errorf("& arg is not a pointer: %#v", typ)
		}
		p.Parse('&')
		page, off, size, err := parseAddr(p, true)
		if err != nil {
			return nil, err
		}
		p.Parse('=')
		inner, err := parseArg(typ1, p, vars)
		if err != nil {
			return nil, err
		}
		arg = pointerArg(typ, page, off, size, inner)
	case '(':
		page, off, _, err := parseAddr(p, false)
		if err != nil {
			return nil, err
		}
		arg = pageSizeArg(typ, page, off)
	case '"':
		p.Parse('"')
		val := ""
		if p.Char() != '"' {
			val = p.Ident()
		}
		p.Parse('"')
		data, err := hex.DecodeString(val)
		if err != nil {
			return nil, fmt.Errorf("data arg has bad value '%v'", val)
		}
		arg = dataArg(typ, data)
	case '{':
		t1, ok := typ.(*sys.StructType)
		if !ok {
			return nil, fmt.Errorf("'{' arg is not a struct: %#v", typ)
		}
		p.Parse('{')
		var inner []*Arg
		for i := 0; p.Char() != '}'; i++ {
			if i >= len(t1.Fields) {
				return nil, fmt.Errorf("wrong struct arg count: %v, want %v", i+1, len(t1.Fields))
			}
			fld := t1.Fields[i]
			if sys.IsPad(fld) {
				// Padding is not serialized; synthesize it without consuming input.
				inner = append(inner, constArg(fld, 0))
				continue
			}
			fieldArg, err := parseArg(fld, p, vars)
			if err != nil {
				return nil, err
			}
			inner = append(inner, fieldArg)
			if p.Char() != '}' {
				p.Parse(',')
			}
		}
		p.Parse('}')
		// The loop stops at '}' right after the last serialized field, so a
		// trailing padding field has to be appended here. The length check
		// avoids indexing into an empty field list.
		if n := len(t1.Fields); n > 0 && sys.IsPad(t1.Fields[n-1]) {
			inner = append(inner, constArg(t1.Fields[n-1], 0))
		}
		arg = groupArg(typ, inner)
	case '[':
		t1, ok := typ.(*sys.ArrayType)
		if !ok {
			return nil, fmt.Errorf("'[' arg is not an array: %#v", typ)
		}
		p.Parse('[')
		var inner []*Arg
		for p.Char() != ']' {
			elem, err := parseArg(t1.Type, p, vars)
			if err != nil {
				return nil, err
			}
			inner = append(inner, elem)
			if p.Char() != ']' {
				p.Parse(',')
			}
		}
		p.Parse(']')
		arg = groupArg(typ, inner)
	case '@':
		t1, ok := typ.(*sys.UnionType)
		if !ok {
			return nil, fmt.Errorf("'@' arg is not a union: %#v", typ)
		}
		p.Parse('@')
		name := p.Ident()
		p.Parse('=')
		var optType sys.Type
		for _, t2 := range t1.Options {
			if name == t2.Name() {
				optType = t2
				break
			}
		}
		if optType == nil {
			return nil, fmt.Errorf("union arg %v has unknown option: %v", typ.Name(), name)
		}
		opt, err := parseArg(optType, p, vars)
		if err != nil {
			return nil, err
		}
		arg = unionArg(typ, opt, optType)
	case 'n':
		p.Parse('n')
		p.Parse('i')
		p.Parse('l')
		if r != "" {
			return nil, fmt.Errorf("named nil argument")
		}
	default:
		return nil, fmt.Errorf("failed to parse argument at %v (line #%v/%v: %v)", int(p.Char()), p.l, p.i, p.s)
	}
	if r != "" {
		vars[r] = arg
	}
	return arg, nil
}
// Constants used by the textual program encoding.
const (
// encodingAddrBase is added to the page address by serializeAddr, and
// required and subtracted by parseAddr, when the base flag is set.
encodingAddrBase = 0x7f0000000000
// encodingPageSize (4096) converts between page counts and the byte
// values that appear in the serialized form.
encodingPageSize = 4 << 10
// maxLineLen (256 KiB) is the maximum token size passed to
// bufio.Scanner.Buffer when reading a program line by line.
maxLineLen = 256 << 10
)
// serializeAddr formats the address stored in a as "(0xPAGE[+0xOFF|-0xOFF][/0xSIZE])".
//
// PAGE is a.AddrPage in bytes, with encodingAddrBase added when base is true.
// A non-zero a.AddrOffset is appended with an explicit sign; for a negative
// offset PAGE is advanced by one page, which parseAddr undoes when it sees '-'.
// A non-zero a.AddrPagesNum is appended as a size in bytes.
func serializeAddr(a *Arg, base bool) string {
	addr := a.AddrPage * encodingPageSize
	if base {
		addr += encodingAddrBase
	}

	offPart := ""
	off := a.AddrOffset
	switch {
	case off < 0:
		addr += encodingPageSize
		offPart = fmt.Sprintf("%v0x%x", "-", -off)
	case off > 0:
		offPart = fmt.Sprintf("%v0x%x", "+", off)
	}

	sizePart := ""
	if pages := a.AddrPagesNum; pages != 0 {
		sizePart = fmt.Sprintf("/0x%x", pages*encodingPageSize)
	}

	return fmt.Sprintf("(0x%x%v%v)", addr, offPart, sizePart)
}
// parseAddr parses an address of the form "(PAGE[+OFF|-OFF][/SIZE])" from p.
//
// PAGE must be a multiple of encodingPageSize. When base is true it must also
// be at least encodingAddrBase, which is subtracted. A negative offset implies
// that PAGE was advanced by one page during serialization, so one page is
// subtracted again; an address that has no page to give back is rejected
// instead of wrapping around the unsigned page value.
//
// It returns the page index, the signed byte offset and the size in pages
// (both PAGE and SIZE are divided by encodingPageSize), or an error if any
// component is malformed.
func parseAddr(p *parser, base bool) (uintptr, int, uintptr, error) {
	p.Parse('(')
	pstr := p.Ident()
	page, err := strconv.ParseUint(pstr, 0, 64)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("failed to parse addr page: '%v'", pstr)
	}
	if page%encodingPageSize != 0 {
		return 0, 0, 0, fmt.Errorf("address base is not page size aligned: '%v'", pstr)
	}
	if base {
		if page < encodingAddrBase {
			return 0, 0, 0, fmt.Errorf("address without base offset: '%v'", pstr)
		}
		page -= encodingAddrBase
	}
	var off int64
	if ch := p.Char(); ch == '+' || ch == '-' {
		minus := ch == '-'
		p.Parse(ch)
		ostr := p.Ident()
		off, err = strconv.ParseInt(ostr, 0, 64)
		if err != nil {
			return 0, 0, 0, fmt.Errorf("failed to parse addr offset: '%v'", ostr)
		}
		if minus {
			// page is unsigned: subtracting from a value below one page
			// would wrap to a huge bogus page index.
			if page < encodingPageSize {
				return 0, 0, 0, fmt.Errorf("negative addr offset below the first page: '%v-%v'", pstr, ostr)
			}
			page -= encodingPageSize
			off = -off
		}
	}
	var size uint64
	if p.Char() == '/' {
		p.Parse('/')
		sstr := p.Ident()
		size, err = strconv.ParseUint(sstr, 0, 64)
		if err != nil {
			return 0, 0, 0, fmt.Errorf("failed to parse addr size: '%v'", sstr)
		}
	}
	p.Parse(')')
	page /= encodingPageSize
	size /= encodingPageSize
	return uintptr(page), int(off), uintptr(size), nil
}
// parser is a line-oriented reader for the textual program format.
//
// The first failure is stored in e. Once e is set, Scan returns false, Char
// returns 0 and Parse does nothing, so callers can issue a sequence of parsing
// steps and inspect Err afterwards.
type parser struct {
	r *bufio.Scanner // source of input lines
	s string         // current line
	i int            // read position within s
	l int            // number of the current line, starting at 1
	e error          // first error recorded, if any
}

// Scan advances to the next input line and resets the read position.
// It returns false when an error has been recorded or the input is exhausted;
// in the latter case the scanner's error (possibly nil) is stored.
func (p *parser) Scan() bool {
	if p.e != nil {
		return false
	}
	if !p.r.Scan() {
		p.e = p.r.Err()
		return false
	}
	p.s = p.r.Text()
	p.i = 0
	p.l++
	return true
}

// Err returns the first error recorded by the parser, or nil.
func (p *parser) Err() error {
	return p.e
}

// Str returns the current line.
func (p *parser) Str() string {
	return p.s
}

// EOF reports whether the read position is at the end of the current line.
func (p *parser) EOF() bool {
	return p.i == len(p.s)
}

// Char returns the byte at the read position without consuming it.
// It returns 0 if an error has been recorded; at the end of the line it
// records an "unexpected eof" error and returns 0.
func (p *parser) Char() byte {
	if p.e != nil {
		return 0
	}
	if p.EOF() {
		p.failf("unexpected eof")
		return 0
	}
	return p.s[p.i]
}

// Parse consumes the byte ch followed by any whitespace.
// It records an error if the line has ended or the next byte is not ch, and
// does nothing if an error has already been recorded.
func (p *parser) Parse(ch byte) {
	if p.e != nil {
		return
	}
	if p.EOF() {
		p.failf("want %s, got EOF", string(ch))
		return
	}
	if p.s[p.i] != ch {
		p.failf("want '%v', got '%v'", string(ch), string(p.s[p.i]))
		return
	}
	p.i++
	p.SkipWs()
}

// SkipWs advances the read position past spaces and tabs.
func (p *parser) SkipWs() {
	for p.i < len(p.s) && (p.s[p.i] == ' ' || p.s[p.i] == '\t') {
		p.i++
	}
}

// isIdentChar reports whether c may appear in an identifier:
// an ASCII letter, an ASCII digit, '_' or '$'.
func isIdentChar(c byte) bool {
	switch {
	case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9':
		return true
	}
	return c == '_' || c == '$'
}

// Ident consumes and returns the longest run of identifier characters at the
// read position, then skips trailing whitespace. Numbers such as "0x10" are
// identifiers in this sense. If there is no identifier character at the read
// position it records an error and returns "".
func (p *parser) Ident() string {
	start := p.i
	for p.i < len(p.s) && isIdentChar(p.s[p.i]) {
		p.i++
	}
	if start == p.i {
		p.failf("failed to parse identifier at pos %v", start)
		return ""
	}
	s := p.s[start:p.i]
	p.SkipWs()
	return s
}

// failf records a formatted error annotated with the current line number and
// the text of the current line.
func (p *parser) failf(msg string, args ...interface{}) {
	p.e = fmt.Errorf("%v\nline #%v: %v", fmt.Sprintf(msg, args...), p.l, p.s)
}
// CallSet returns the set of all call names in the serialized program data.
// It does very conservative parsing and is intended to parse past/future
// serialization formats.
//
// Empty lines and lines starting with '#' are skipped. On every other line the
// call name is the text before the first '(' and, if present, after the first
// '=' preceding it, with surrounding whitespace removed. This matches the full
// parser, which skips spaces and tabs around these tokens. An error is returned
// if a line has no '(' or an empty call name, if scanning fails, or if the
// program contains no calls at all.
func CallSet(data []byte) (map[string]struct{}, error) {
	calls := make(map[string]struct{})
	s := bufio.NewScanner(bytes.NewReader(data))
	s.Buffer(nil, maxLineLen)
	for s.Scan() {
		ln := s.Bytes()
		if len(ln) == 0 || ln[0] == '#' {
			continue
		}
		bracket := bytes.IndexByte(ln, '(')
		if bracket == -1 {
			return nil, fmt.Errorf("line does not contain opening bracket")
		}
		call := ln[:bracket]
		if eq := bytes.IndexByte(call, '='); eq != -1 {
			// Drop the "rN =" result assignment.
			call = call[eq+1:]
		}
		call = bytes.TrimSpace(call)
		if len(call) == 0 {
			return nil, fmt.Errorf("call name is empty")
		}
		calls[string(call)] = struct{}{}
	}
	if err := s.Err(); err != nil {
		return nil, err
	}
	if len(calls) == 0 {
		return nil, fmt.Errorf("program does not contain any calls")
	}
	return calls, nil
}