syzkaller/prog/encoding.go
Dmitry Vyukov 0913359f79 prog: increase line length limit when deserializing programs
bufio.Scanner has a default limit of 4K per line,
if a program contains a longer line, deserialization fails.
Extend the limit to 64K.
Also check scanning errors. Turns out even scanning of bytes.Buffer
can fail due to the line limit.
2017-01-09 20:19:44 +01:00

558 lines
12 KiB
Go

// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"io"
"strconv"
"github.com/google/syzkaller/sys"
)
// String returns a very compact description of the program, intended mostly
// for debug output: the names of its calls, in order, joined by '-'.
func (p *Prog) String() string {
	var out bytes.Buffer
	for idx := range p.Calls {
		if idx > 0 {
			out.WriteByte('-')
		}
		fmt.Fprintf(&out, "%v", p.Calls[idx].Meta.Name)
	}
	return out.String()
}
// Serialize encodes the program in its textual form, one call per line:
//
//	[rN = ]name(arg, arg, ...)
//
// A call whose return value is used elsewhere is prefixed with "rN = ", where
// N is the next free variable number; the same counter numbers arguments that
// are referenced later (see Arg.serialize). Padding arguments are not emitted.
// Nil arguments are handed to Arg.serialize, which writes them as "nil".
func (p *Prog) Serialize() []byte {
	/*
		if err := p.validate(); err != nil {
			panic("serializing invalid program")
		}
	*/
	buf := new(bytes.Buffer)
	vars := make(map[*Arg]int)
	varSeq := 0
	for _, c := range p.Calls {
		if len(c.Ret.Uses) != 0 {
			fmt.Fprintf(buf, "r%v = ", varSeq)
			vars[c.Ret] = varSeq
			varSeq++
		}
		fmt.Fprintf(buf, "%v(", c.Meta.Name)
		for i, a := range c.Args {
			// Check for nil before touching a.Type: validation is disabled
			// above, so a nil argument must not crash serialization. This
			// mirrors the nil guard used for group members in Arg.serialize.
			if a != nil && sys.IsPad(a.Type) {
				continue
			}
			if i != 0 {
				fmt.Fprintf(buf, ", ")
			}
			a.serialize(buf, vars, &varSeq)
		}
		fmt.Fprintf(buf, ")\n")
	}
	return buf.Bytes()
}
// serialize writes the textual form of the argument to buf.
//
// vars maps arguments that already received a variable number to that number,
// and varSeq points at the next free number. An argument with a non-empty Uses
// collection is prefixed with "<rN=>" and recorded in vars so that later
// ArgResult arguments can refer to it. A nil argument is written as "nil".
//
// It panics for an ArgResult whose target has no variable number, for a group
// whose type is neither a struct nor an array, and for an unknown kind.
func (a *Arg) serialize(buf io.Writer, vars map[*Arg]int, varSeq *int) {
	if a == nil {
		fmt.Fprintf(buf, "nil")
		return
	}
	if len(a.Uses) != 0 {
		id := *varSeq
		fmt.Fprintf(buf, "<r%v=>", id)
		vars[a] = id
		*varSeq = id + 1
	}
	switch a.Kind {
	case ArgConst:
		fmt.Fprintf(buf, "0x%x", a.Val)
	case ArgResult:
		id, found := vars[a.Res]
		if !found {
			panic("no result")
		}
		fmt.Fprintf(buf, "r%v", id)
		// Division is written before addition.
		if div := a.OpDiv; div != 0 {
			fmt.Fprintf(buf, "/%v", div)
		}
		if add := a.OpAdd; add != 0 {
			fmt.Fprintf(buf, "+%v", add)
		}
	case ArgPointer:
		// Address (including the encoding base) followed by the pointee.
		fmt.Fprintf(buf, "&%v=", serializeAddr(a, true))
		a.Res.serialize(buf, vars, varSeq)
	case ArgPageSize:
		fmt.Fprintf(buf, "%v", serializeAddr(a, false))
	case ArgData:
		fmt.Fprintf(buf, "\"%v\"", hex.EncodeToString(a.Data))
	case ArgGroup:
		var opening, closing byte
		switch a.Type.(type) {
		case *sys.StructType:
			opening, closing = '{', '}'
		case *sys.ArrayType:
			opening, closing = '[', ']'
		default:
			panic("unknown group type")
		}
		buf.Write([]byte{opening})
		for idx, elem := range a.Inner {
			// Padding members are omitted from the text; nil members are kept.
			if elem != nil && sys.IsPad(elem.Type) {
				continue
			}
			if idx > 0 {
				fmt.Fprintf(buf, ", ")
			}
			elem.serialize(buf, vars, varSeq)
		}
		buf.Write([]byte{closing})
	case ArgUnion:
		fmt.Fprintf(buf, "@%v=", a.OptionType.Name())
		a.Option.serialize(buf, vars, varSeq)
	default:
		panic("unknown arg kind")
	}
}
// Deserialize parses a program from its textual form.
//
// Empty lines and lines starting with '#' are skipped. Every other line must
// have the form "[rN = ]name(arg, ...)": name is looked up in sys.CallMap, each
// argument is parsed against the corresponding declared argument type, and the
// number of arguments must match the declaration. A leading "rN =" binds the
// call's return value to the variable rN for use by later arguments.
//
// The parsed program is validated before it is returned. On any error the
// returned program is nil.
func Deserialize(data []byte) (prog *Prog, err error) {
	prog = new(Prog)
	p := &parser{r: bufio.NewScanner(bytes.NewReader(data))}
	p.r.Buffer(nil, maxLineLen)
	vars := make(map[string]*Arg)
	for p.Scan() {
		if p.EOF() || p.Char() == '#' {
			continue
		}
		name := p.Ident()
		r := ""
		if p.Char() == '=' {
			r = name
			p.Parse('=')
			name = p.Ident()
		}
		meta := sys.CallMap[name]
		if meta == nil {
			return nil, fmt.Errorf("unknown syscall %v", name)
		}
		c := &Call{
			Meta: meta,
			Ret:  returnArg(meta.Ret),
		}
		prog.Calls = append(prog.Calls, c)
		p.Parse('(')
		for i := 0; p.Char() != ')'; i++ {
			// Char returns 0 once the parser has failed (for example when the
			// line ends before ')'). Report that root cause right away instead
			// of a misleading argument-count or argument-syntax error.
			if err := p.Err(); err != nil {
				return nil, err
			}
			if i >= len(meta.Args) {
				return nil, fmt.Errorf("wrong call arg count: %v, want %v", i+1, len(meta.Args))
			}
			typ := meta.Args[i]
			if sys.IsPad(typ) {
				return nil, fmt.Errorf("padding in syscall %v arguments", name)
			}
			arg, err := parseArg(typ, p, vars)
			if err != nil {
				return nil, err
			}
			c.Args = append(c.Args, arg)
			if p.Char() != ')' {
				p.Parse(',')
			}
		}
		p.Parse(')')
		if !p.EOF() {
			return nil, fmt.Errorf("tailing data (line #%v)", p.l)
		}
		if len(c.Args) != len(meta.Args) {
			return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args))
		}
		// The return value becomes visible only after the call's own arguments
		// have been parsed, so a call cannot refer to its own result.
		if r != "" {
			vars[r] = c.Ret
		}
	}
	// Covers both scanner errors (e.g. an over-long line) and parse errors
	// recorded on the last processed line.
	if err := p.Err(); err != nil {
		return nil, err
	}
	if err := prog.validate(); err != nil {
		return nil, err
	}
	return prog, nil
}
// parseArg parses one argument of type typ at the parser's current position.
//
// An optional "<rN=>" prefix names the argument; the parsed argument is then
// stored in vars under that name. The first character selects the form:
//
//	0...     constant, parsed with strconv.ParseUint (base 0)
//	rN       reference to a variable in vars, optionally followed by
//	         "/div" and then "+add"
//	&(...)=  pointer: address, '=', then the pointee
//	(...)    page-size argument
//	"hex"    hex-encoded data (may be empty)
//	{...}    struct; padding fields are filled in with zero constants
//	[...]    array
//	@name=   union option name followed by the option's value
//	nil      nil argument, returned as a nil *Arg; it cannot be named
//
// An error is returned when the text is malformed or its form does not match
// typ (for example '{' for a non-struct type).
func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
	r := ""
	if p.Char() == '<' {
		p.Parse('<')
		r = p.Ident()
		p.Parse('=')
		p.Parse('>')
	}
	var arg *Arg
	switch p.Char() {
	case '0':
		val := p.Ident()
		v, err := strconv.ParseUint(val, 0, 64)
		if err != nil {
			return nil, fmt.Errorf("wrong arg value '%v': %v", val, err)
		}
		arg = constArg(typ, uintptr(v))
	case 'r':
		id := p.Ident()
		v, ok := vars[id]
		if !ok || v == nil {
			return nil, fmt.Errorf("result %v references unknown variable (vars=%+v)", id, vars)
		}
		arg = resultArg(typ, v)
		if p.Char() == '/' {
			p.Parse('/')
			op := p.Ident()
			v, err := strconv.ParseUint(op, 0, 64)
			if err != nil {
				return nil, fmt.Errorf("wrong result div op: '%v'", op)
			}
			arg.OpDiv = uintptr(v)
		}
		if p.Char() == '+' {
			p.Parse('+')
			op := p.Ident()
			v, err := strconv.ParseUint(op, 0, 64)
			if err != nil {
				return nil, fmt.Errorf("wrong result add op: '%v'", op)
			}
			arg.OpAdd = uintptr(v)
		}
	case '&':
		// typ1 is the pointee type. It stays nil for VMA types, so the inner
		// argument is then parsed with a nil type.
		var typ1 sys.Type
		switch t1 := typ.(type) {
		case *sys.PtrType:
			typ1 = t1.Type
		case *sys.VmaType:
		default:
			return nil, fmt.Errorf("& arg is not a pointer: %#v", typ)
		}
		p.Parse('&')
		page, off, size, err := parseAddr(p, true)
		if err != nil {
			return nil, err
		}
		p.Parse('=')
		inner, err := parseArg(typ1, p, vars)
		if err != nil {
			return nil, err
		}
		arg = pointerArg(typ, page, off, size, inner)
	case '(':
		page, off, _, err := parseAddr(p, false)
		if err != nil {
			return nil, err
		}
		arg = pageSizeArg(typ, page, off)
	case '"':
		p.Parse('"')
		val := ""
		if p.Char() != '"' {
			val = p.Ident()
		}
		p.Parse('"')
		data, err := hex.DecodeString(val)
		if err != nil {
			return nil, fmt.Errorf("data arg has bad value '%v'", val)
		}
		arg = dataArg(typ, data)
	case '{':
		t1, ok := typ.(*sys.StructType)
		if !ok {
			return nil, fmt.Errorf("'{' arg is not a struct: %#v", typ)
		}
		p.Parse('{')
		var inner []*Arg
		for i := 0; p.Char() != '}'; i++ {
			if i >= len(t1.Fields) {
				return nil, fmt.Errorf("wrong struct arg count: %v, want %v", i+1, len(t1.Fields))
			}
			fld := t1.Fields[i]
			if sys.IsPad(fld) {
				// Padding is not present in the text; synthesize it.
				inner = append(inner, constArg(fld, 0))
			} else {
				fldArg, err := parseArg(fld, p, vars)
				if err != nil {
					return nil, err
				}
				inner = append(inner, fldArg)
				if p.Char() != '}' {
					p.Parse(',')
				}
			}
		}
		p.Parse('}')
		// The loop stops at '}' right after the last textual field, so trailing
		// padding is added here. The length check keeps a struct type without
		// fields from indexing out of range.
		if n := len(t1.Fields); n != 0 && sys.IsPad(t1.Fields[n-1]) {
			inner = append(inner, constArg(t1.Fields[n-1], 0))
		}
		arg = groupArg(typ, inner)
	case '[':
		t1, ok := typ.(*sys.ArrayType)
		if !ok {
			return nil, fmt.Errorf("'[' arg is not an array: %#v", typ)
		}
		p.Parse('[')
		var inner []*Arg
		for p.Char() != ']' {
			elem, err := parseArg(t1.Type, p, vars)
			if err != nil {
				return nil, err
			}
			inner = append(inner, elem)
			if p.Char() != ']' {
				p.Parse(',')
			}
		}
		p.Parse(']')
		arg = groupArg(typ, inner)
	case '@':
		t1, ok := typ.(*sys.UnionType)
		if !ok {
			return nil, fmt.Errorf("'@' arg is not a union: %#v", typ)
		}
		p.Parse('@')
		name := p.Ident()
		p.Parse('=')
		var optType sys.Type
		for _, t2 := range t1.Options {
			if name == t2.Name() {
				optType = t2
				break
			}
		}
		if optType == nil {
			return nil, fmt.Errorf("union arg %v has unknown option: %v", typ.Name(), name)
		}
		opt, err := parseArg(optType, p, vars)
		if err != nil {
			return nil, err
		}
		arg = unionArg(typ, opt, optType)
	case 'n':
		p.Parse('n')
		p.Parse('i')
		p.Parse('l')
		if r != "" {
			return nil, fmt.Errorf("named nil argument")
		}
	default:
		// Char returns 0 after a parser failure (e.g. unexpected end of line);
		// surface that recorded error instead of a generic message about
		// character 0.
		if err := p.Err(); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("failed to parse argument at %v (line #%v/%v: %v)", int(p.Char()), p.l, p.i, p.s)
	}
	if r != "" {
		vars[r] = arg
	}
	return arg, nil
}
const (
// encodingAddrBase is the fixed offset added to page addresses when they are
// serialized with a base (serializeAddr) and required/subtracted when such
// addresses are parsed (parseAddr).
encodingAddrBase = 0x7f0000000000
// encodingPageSize is the page size, in bytes, used by the text encoding:
// page indices and page counts are multiplied by it on serialization and
// divided by it on parsing.
encodingPageSize = 4 << 10
// maxLineLen is the maximum buffer size passed to bufio.Scanner.Buffer by
// Deserialize and CallSet, i.e. the limit on the length of one program line.
// NOTE(review): 64<<10 appears to equal bufio.MaxScanTokenSize, the Scanner
// default — confirm whether a larger limit was intended.
maxLineLen = 64 << 10
)
// serializeAddr formats the address fields of a as "(0xPAGE[+0xOFF|-0xOFF][/0xSIZE])".
//
// PAGE is a.AddrPage in bytes, plus encodingAddrBase when base is true. A zero
// offset and a zero page count are omitted. A negative offset is written as a
// positive distance back from the start of the following page, so PAGE is
// advanced by one page in that case. SIZE is a.AddrPagesNum in bytes.
func serializeAddr(a *Arg, base bool) string {
	addr := a.AddrPage * encodingPageSize
	if base {
		addr += encodingAddrBase
	}

	var offStr string
	switch off := a.AddrOffset; {
	case off > 0:
		offStr = fmt.Sprintf("%v0x%x", "+", off)
	case off < 0:
		addr += encodingPageSize
		offStr = fmt.Sprintf("%v0x%x", "-", -off)
	}

	var sizeStr string
	if pages := a.AddrPagesNum; pages != 0 {
		sizeStr = fmt.Sprintf("/0x%x", pages*encodingPageSize)
	}

	return fmt.Sprintf("(0x%x%v%v)", addr, offStr, sizeStr)
}
// parseAddr parses an address of the form "(PAGE[+OFF|-OFF][/SIZE])" and
// returns the page index, the offset within the page and the size in pages.
//
// PAGE must be a multiple of encodingPageSize. When base is true it must also
// be at least encodingAddrBase, which is subtracted. A "-OFF" offset is
// relative to the start of the following page, so the page is moved back by
// one and the returned offset is negative. SIZE is given in bytes and is
// converted to pages by integer division.
//
// Syntax errors detected by the parser itself (for example a missing
// parenthesis) are recorded in p rather than returned.
func parseAddr(p *parser, base bool) (uintptr, int, uintptr, error) {
	p.Parse('(')
	pstr := p.Ident()
	page, err := strconv.ParseUint(pstr, 0, 64)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("failed to parse addr page: '%v'", pstr)
	}
	if page%encodingPageSize != 0 {
		return 0, 0, 0, fmt.Errorf("address base is not page size aligned: '%v'", pstr)
	}
	if base {
		if page < encodingAddrBase {
			return 0, 0, 0, fmt.Errorf("address without base offset: '%v'", pstr)
		}
		page -= encodingAddrBase
	}
	var off int64
	if p.Char() == '+' || p.Char() == '-' {
		minus := false
		if p.Char() == '-' {
			minus = true
			p.Parse('-')
		} else {
			p.Parse('+')
		}
		ostr := p.Ident()
		off, err = strconv.ParseInt(ostr, 0, 64)
		if err != nil {
			return 0, 0, 0, fmt.Errorf("failed to parse addr offset: '%v'", ostr)
		}
		if minus {
			// page is unsigned: stepping back from an address below one page
			// would wrap around to a huge page index, so reject it instead.
			if page < encodingPageSize {
				return 0, 0, 0, fmt.Errorf("negative addr offset below the first page: '%v'", pstr)
			}
			page -= encodingPageSize
			off = -off
		}
	}
	var size uint64
	if p.Char() == '/' {
		p.Parse('/')
		pstr := p.Ident()
		size, err = strconv.ParseUint(pstr, 0, 64)
		if err != nil {
			return 0, 0, 0, fmt.Errorf("failed to parse addr size: '%v'", pstr)
		}
	}
	p.Parse(')')
	page /= encodingPageSize
	size /= encodingPageSize
	return uintptr(page), int(off), uintptr(size), nil
}
// parser is a cursor over program text that is consumed one line at a time.
// Errors are recorded in e instead of being returned by each method; once e is
// set, Scan, Char and Parse become no-ops.
type parser struct {
// r supplies the input lines.
r *bufio.Scanner
// s is the text of the current line.
s string
// i is the byte offset in s of the next character to parse.
i int
// l counts the lines scanned so far, i.e. the 1-based number of the current line.
l int
// e is the recorded scan or parse error, nil if none.
e error
}
// Scan advances to the next input line and resets the position to its start.
// It returns false when an error has already been recorded or when the
// underlying scanner stops; in the latter case the scanner's error (nil at a
// clean end of input) is stored and becomes available through Err.
func (p *parser) Scan() bool {
	if p.e != nil {
		return false
	}
	if p.r.Scan() {
		p.s, p.i = p.r.Text(), 0
		p.l++
		return true
	}
	p.e = p.r.Err()
	return false
}
// Err returns the error recorded by the parser so far, or nil if there is none.
func (p *parser) Err() error {
return p.e
}
// Str returns the full text of the current line.
func (p *parser) Str() string {
return p.s
}
// EOF reports whether the position is at the end of the current line.
func (p *parser) EOF() bool {
return p.i == len(p.s)
}
// Char returns the character at the current position without consuming it.
// It returns 0 if an error has already been recorded, and records an
// "unexpected eof" error (also returning 0) at the end of the line.
func (p *parser) Char() byte {
	switch {
	case p.e != nil:
		return 0
	case p.EOF():
		p.failf("unexpected eof")
		return 0
	default:
		return p.s[p.i]
	}
}
// Parse consumes the character ch at the current position and then skips any
// following whitespace. If the line has ended or a different character is
// found, an error is recorded and the position is left unchanged. It does
// nothing once an error has been recorded.
func (p *parser) Parse(ch byte) {
	if p.e != nil {
		return
	}
	switch {
	case p.EOF():
		p.failf("want %s, got EOF", string(ch))
	case p.s[p.i] != ch:
		p.failf("want '%v', got '%v'", string(ch), string(p.s[p.i]))
	default:
		p.i++
		p.SkipWs()
	}
}
// SkipWs advances the position past any run of spaces and tabs.
func (p *parser) SkipWs() {
	for ; p.i < len(p.s); p.i++ {
		if c := p.s[p.i]; c != ' ' && c != '\t' {
			return
		}
	}
}
// Ident consumes and returns the identifier at the current position: a
// non-empty run of ASCII letters, digits, '_' and '$'. Numbers such as "0x10"
// are identifiers in this sense. Whitespace following the identifier is
// skipped.
//
// If no identifier character is found, an error is recorded and "" is
// returned. Like Char and Parse, Ident does nothing (and returns "") once an
// error has been recorded, so the first error is not overwritten.
func (p *parser) Ident() string {
	if p.e != nil {
		return ""
	}
	start := p.i
	for p.i < len(p.s) {
		c := p.s[p.i]
		isIdent := c >= 'a' && c <= 'z' ||
			c >= 'A' && c <= 'Z' ||
			c >= '0' && c <= '9' ||
			c == '_' || c == '$'
		if !isIdent {
			break
		}
		p.i++
	}
	if start == p.i {
		p.failf("failed to parse identifier at pos %v", start)
		return ""
	}
	ident := p.s[start:p.i]
	p.SkipWs()
	return ident
}
// failf records a parse error. The message is formatted from msg and args and
// is followed, on a second line, by the current line number and line text.
// Any previously recorded error is replaced.
func (p *parser) failf(msg string, args ...interface{}) {
	detail := fmt.Sprintf(msg, args...)
	p.e = fmt.Errorf("%v\nline #%v: %v", detail, p.l, p.s)
}
// CallSet returns the set of names of all calls in the serialized program.
// It parses very conservatively so that it can cope with past and future
// serialization formats: for every line that is neither empty nor a '#'
// comment it takes the text before the first '(' and, if that text contains
// '=', only the part after the '=' and any spaces following it.
//
// An error is returned if a line has no '(', if a call name is empty, if
// scanning fails, or if the program contains no calls at all.
func CallSet(data []byte) (map[string]struct{}, error) {
	scanner := bufio.NewScanner(bytes.NewReader(data))
	scanner.Buffer(nil, maxLineLen)
	set := make(map[string]struct{})
	for scanner.Scan() {
		line := scanner.Bytes()
		if len(line) == 0 || line[0] == '#' {
			continue
		}
		paren := bytes.IndexByte(line, '(')
		if paren < 0 {
			return nil, fmt.Errorf("line does not contain opening bracket")
		}
		name := line[:paren]
		// Drop a leading "rN = " result assignment, if any.
		if assign := bytes.IndexByte(name, '='); assign >= 0 {
			name = bytes.TrimLeft(name[assign+1:], " ")
		}
		if len(name) == 0 {
			return nil, fmt.Errorf("call name is empty")
		}
		set[string(name)] = struct{}{}
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	if len(set) == 0 {
		return nil, fmt.Errorf("program does not contain any calls")
	}
	return set, nil
}