syzkaller/prog/encoding.go
Dmitry Vyukov 40723a067e prog: validate deserialized programs
The optimization change removed validation too aggressively.
We do need program validation during deserialization,
because we can get bad programs from corpus or hub.
Restore program validation after deserialization.
2017-01-24 10:53:21 +01:00

561 lines
12 KiB
Go

// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"io"
"strconv"
"github.com/google/syzkaller/sys"
)
// String generates a very compact program description (mostly for debug output).
func (p *Prog) String() string {
buf := new(bytes.Buffer)
for i, c := range p.Calls {
if i != 0 {
fmt.Fprintf(buf, "-")
}
fmt.Fprintf(buf, "%v", c.Meta.Name)
}
return buf.String()
}
func (p *Prog) Serialize() []byte {
if debug {
if err := p.validate(); err != nil {
panic("serializing invalid program")
}
}
buf := new(bytes.Buffer)
vars := make(map[*Arg]int)
varSeq := 0
for _, c := range p.Calls {
if len(c.Ret.Uses) != 0 {
fmt.Fprintf(buf, "r%v = ", varSeq)
vars[c.Ret] = varSeq
varSeq++
}
fmt.Fprintf(buf, "%v(", c.Meta.Name)
for i, a := range c.Args {
if sys.IsPad(a.Type) {
continue
}
if i != 0 {
fmt.Fprintf(buf, ", ")
}
a.serialize(buf, vars, &varSeq)
}
fmt.Fprintf(buf, ")\n")
}
return buf.Bytes()
}
func (a *Arg) serialize(buf io.Writer, vars map[*Arg]int, varSeq *int) {
if a == nil {
fmt.Fprintf(buf, "nil")
return
}
if len(a.Uses) != 0 {
fmt.Fprintf(buf, "<r%v=>", *varSeq)
vars[a] = *varSeq
*varSeq++
}
switch a.Kind {
case ArgConst:
fmt.Fprintf(buf, "0x%x", a.Val)
case ArgResult:
id, ok := vars[a.Res]
if !ok {
panic("no result")
}
fmt.Fprintf(buf, "r%v", id)
if a.OpDiv != 0 {
fmt.Fprintf(buf, "/%v", a.OpDiv)
}
if a.OpAdd != 0 {
fmt.Fprintf(buf, "+%v", a.OpAdd)
}
case ArgPointer:
fmt.Fprintf(buf, "&%v=", serializeAddr(a, true))
a.Res.serialize(buf, vars, varSeq)
case ArgPageSize:
fmt.Fprintf(buf, "%v", serializeAddr(a, false))
case ArgData:
fmt.Fprintf(buf, "\"%v\"", hex.EncodeToString(a.Data))
case ArgGroup:
var delims []byte
switch a.Type.(type) {
case *sys.StructType:
delims = []byte{'{', '}'}
case *sys.ArrayType:
delims = []byte{'[', ']'}
default:
panic("unknown group type")
}
buf.Write([]byte{delims[0]})
for i, a1 := range a.Inner {
if a1 != nil && sys.IsPad(a1.Type) {
continue
}
if i != 0 {
fmt.Fprintf(buf, ", ")
}
a1.serialize(buf, vars, varSeq)
}
buf.Write([]byte{delims[1]})
case ArgUnion:
fmt.Fprintf(buf, "@%v=", a.OptionType.FieldName())
a.Option.serialize(buf, vars, varSeq)
default:
panic("unknown arg kind")
}
}
func Deserialize(data []byte) (prog *Prog, err error) {
prog = new(Prog)
p := &parser{r: bufio.NewScanner(bytes.NewReader(data))}
p.r.Buffer(nil, maxLineLen)
vars := make(map[string]*Arg)
for p.Scan() {
if p.EOF() || p.Char() == '#' {
continue
}
name := p.Ident()
r := ""
if p.Char() == '=' {
r = name
p.Parse('=')
name = p.Ident()
}
meta := sys.CallMap[name]
if meta == nil {
return nil, fmt.Errorf("unknown syscall %v", name)
}
c := &Call{
Meta: meta,
Ret: returnArg(meta.Ret),
}
prog.Calls = append(prog.Calls, c)
p.Parse('(')
for i := 0; p.Char() != ')'; i++ {
if i >= len(meta.Args) {
return nil, fmt.Errorf("wrong call arg count: %v, want %v", i+1, len(meta.Args))
}
typ := meta.Args[i]
if sys.IsPad(typ) {
return nil, fmt.Errorf("padding in syscall %v arguments", name)
}
arg, err := parseArg(typ, p, vars)
if err != nil {
return nil, err
}
c.Args = append(c.Args, arg)
if p.Char() != ')' {
p.Parse(',')
}
}
p.Parse(')')
if !p.EOF() {
return nil, fmt.Errorf("tailing data (line #%v)", p.l)
}
if len(c.Args) != len(meta.Args) {
return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args))
}
if r != "" {
vars[r] = c.Ret
}
}
if err := p.Err(); err != nil {
return nil, err
}
// This validation is done even in non-debug mode because deserialization
// procedure does not catch all bugs (e.g. mismatched types).
// And we can receive bad programs from corpus and hub.
if err := prog.validate(); err != nil {
return nil, err
}
return
}
func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
r := ""
if p.Char() == '<' {
p.Parse('<')
r = p.Ident()
p.Parse('=')
p.Parse('>')
}
var arg *Arg
switch p.Char() {
case '0':
val := p.Ident()
v, err := strconv.ParseUint(val, 0, 64)
if err != nil {
return nil, fmt.Errorf("wrong arg value '%v': %v", val, err)
}
arg = constArg(typ, uintptr(v))
case 'r':
id := p.Ident()
v, ok := vars[id]
if !ok || v == nil {
return nil, fmt.Errorf("result %v references unknown variable (vars=%+v)", id, vars)
}
arg = resultArg(typ, v)
if p.Char() == '/' {
p.Parse('/')
op := p.Ident()
v, err := strconv.ParseUint(op, 0, 64)
if err != nil {
return nil, fmt.Errorf("wrong result div op: '%v'", op)
}
arg.OpDiv = uintptr(v)
}
if p.Char() == '+' {
p.Parse('+')
op := p.Ident()
v, err := strconv.ParseUint(op, 0, 64)
if err != nil {
return nil, fmt.Errorf("wrong result add op: '%v'", op)
}
arg.OpAdd = uintptr(v)
}
case '&':
var typ1 sys.Type
switch t1 := typ.(type) {
case *sys.PtrType:
typ1 = t1.Type
case *sys.VmaType:
default:
return nil, fmt.Errorf("& arg is not a pointer: %#v", typ)
}
p.Parse('&')
page, off, size, err := parseAddr(p, true)
if err != nil {
return nil, err
}
p.Parse('=')
inner, err := parseArg(typ1, p, vars)
if err != nil {
return nil, err
}
arg = pointerArg(typ, page, off, size, inner)
case '(':
page, off, _, err := parseAddr(p, false)
if err != nil {
return nil, err
}
arg = pageSizeArg(typ, page, off)
case '"':
p.Parse('"')
val := ""
if p.Char() != '"' {
val = p.Ident()
}
p.Parse('"')
data, err := hex.DecodeString(val)
if err != nil {
return nil, fmt.Errorf("data arg has bad value '%v'", val)
}
arg = dataArg(typ, data)
case '{':
t1, ok := typ.(*sys.StructType)
if !ok {
return nil, fmt.Errorf("'{' arg is not a struct: %#v", typ)
}
p.Parse('{')
var inner []*Arg
for i := 0; p.Char() != '}'; i++ {
if i >= len(t1.Fields) {
return nil, fmt.Errorf("wrong struct arg count: %v, want %v", i+1, len(t1.Fields))
}
fld := t1.Fields[i]
if sys.IsPad(fld) {
inner = append(inner, constArg(fld, 0))
} else {
arg, err := parseArg(fld, p, vars)
if err != nil {
return nil, err
}
inner = append(inner, arg)
if p.Char() != '}' {
p.Parse(',')
}
}
}
p.Parse('}')
if last := t1.Fields[len(t1.Fields)-1]; sys.IsPad(last) {
inner = append(inner, constArg(last, 0))
}
arg = groupArg(typ, inner)
case '[':
t1, ok := typ.(*sys.ArrayType)
if !ok {
return nil, fmt.Errorf("'[' arg is not an array: %#v", typ)
}
p.Parse('[')
var inner []*Arg
for i := 0; p.Char() != ']'; i++ {
arg, err := parseArg(t1.Type, p, vars)
if err != nil {
return nil, err
}
inner = append(inner, arg)
if p.Char() != ']' {
p.Parse(',')
}
}
p.Parse(']')
arg = groupArg(typ, inner)
case '@':
t1, ok := typ.(*sys.UnionType)
if !ok {
return nil, fmt.Errorf("'@' arg is not a union: %#v", typ)
}
p.Parse('@')
name := p.Ident()
p.Parse('=')
var optType sys.Type
for _, t2 := range t1.Options {
if name == t2.FieldName() {
optType = t2
break
}
}
if optType == nil {
return nil, fmt.Errorf("union arg %v has unknown option: %v", typ.Name(), name)
}
opt, err := parseArg(optType, p, vars)
if err != nil {
return nil, err
}
arg = unionArg(typ, opt, optType)
case 'n':
p.Parse('n')
p.Parse('i')
p.Parse('l')
if r != "" {
return nil, fmt.Errorf("named nil argument")
}
default:
return nil, fmt.Errorf("failed to parse argument at %v (line #%v/%v: %v)", int(p.Char()), p.l, p.i, p.s)
}
if r != "" {
vars[r] = arg
}
return arg, nil
}
const (
encodingAddrBase = 0x7f0000000000
encodingPageSize = 4 << 10
maxLineLen = 256 << 10
)
func serializeAddr(a *Arg, base bool) string {
page := a.AddrPage * encodingPageSize
if base {
page += encodingAddrBase
}
soff := ""
if off := a.AddrOffset; off != 0 {
sign := "+"
if off < 0 {
sign = "-"
off = -off
page += encodingPageSize
}
soff = fmt.Sprintf("%v0x%x", sign, off)
}
ssize := ""
if size := a.AddrPagesNum; size != 0 {
size *= encodingPageSize
ssize = fmt.Sprintf("/0x%x", size)
}
return fmt.Sprintf("(0x%x%v%v)", page, soff, ssize)
}
func parseAddr(p *parser, base bool) (uintptr, int, uintptr, error) {
p.Parse('(')
pstr := p.Ident()
page, err := strconv.ParseUint(pstr, 0, 64)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse addr page: '%v'", pstr)
}
if page%encodingPageSize != 0 {
return 0, 0, 0, fmt.Errorf("address base is not page size aligned: '%v'", pstr)
}
if base {
if page < encodingAddrBase {
return 0, 0, 0, fmt.Errorf("address without base offset: '%v'", pstr)
}
page -= encodingAddrBase
}
var off int64
if p.Char() == '+' || p.Char() == '-' {
minus := false
if p.Char() == '-' {
minus = true
p.Parse('-')
} else {
p.Parse('+')
}
ostr := p.Ident()
off, err = strconv.ParseInt(ostr, 0, 64)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse addr offset: '%v'", ostr)
}
if minus {
page -= encodingPageSize
off = -off
}
}
var size uint64
if p.Char() == '/' {
p.Parse('/')
pstr := p.Ident()
size, err = strconv.ParseUint(pstr, 0, 64)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse addr size: '%v'", pstr)
}
}
p.Parse(')')
page /= encodingPageSize
size /= encodingPageSize
return uintptr(page), int(off), uintptr(size), nil
}
type parser struct {
r *bufio.Scanner
s string
i int
l int
e error
}
func (p *parser) Scan() bool {
if p.e != nil {
return false
}
if !p.r.Scan() {
p.e = p.r.Err()
return false
}
p.s = p.r.Text()
p.i = 0
p.l++
return true
}
func (p *parser) Err() error {
return p.e
}
func (p *parser) Str() string {
return p.s
}
func (p *parser) EOF() bool {
return p.i == len(p.s)
}
func (p *parser) Char() byte {
if p.e != nil {
return 0
}
if p.EOF() {
p.failf("unexpected eof")
return 0
}
return p.s[p.i]
}
func (p *parser) Parse(ch byte) {
if p.e != nil {
return
}
if p.EOF() {
p.failf("want %s, got EOF", string(ch))
return
}
if p.s[p.i] != ch {
p.failf("want '%v', got '%v'", string(ch), string(p.s[p.i]))
return
}
p.i++
p.SkipWs()
}
func (p *parser) SkipWs() {
for p.i < len(p.s) && (p.s[p.i] == ' ' || p.s[p.i] == '\t') {
p.i++
}
}
func (p *parser) Ident() string {
i := p.i
for p.i < len(p.s) &&
(p.s[p.i] >= 'a' && p.s[p.i] <= 'z' ||
p.s[p.i] >= 'A' && p.s[p.i] <= 'Z' ||
p.s[p.i] >= '0' && p.s[p.i] <= '9' ||
p.s[p.i] == '_' || p.s[p.i] == '$') {
p.i++
}
if i == p.i {
p.failf("failed to parse identifier at pos %v", i)
return ""
}
if ch := p.s[i]; ch >= '0' && ch <= '9' {
}
s := p.s[i:p.i]
p.SkipWs()
return s
}
func (p *parser) failf(msg string, args ...interface{}) {
p.e = fmt.Errorf("%v\nline #%v: %v", fmt.Sprintf(msg, args...), p.l, p.s)
}
// CallSet returns a set of all calls in the program.
// It does very conservative parsing and is intended to parse paste/future serialization formats.
func CallSet(data []byte) (map[string]struct{}, error) {
calls := make(map[string]struct{})
s := bufio.NewScanner(bytes.NewReader(data))
s.Buffer(nil, maxLineLen)
for s.Scan() {
ln := s.Bytes()
if len(ln) == 0 || ln[0] == '#' {
continue
}
bracket := bytes.IndexByte(ln, '(')
if bracket == -1 {
return nil, fmt.Errorf("line does not contain opening bracket")
}
call := ln[:bracket]
if eq := bytes.IndexByte(call, '='); eq != -1 {
eq++
for eq < len(call) && call[eq] == ' ' {
eq++
}
call = call[eq:]
}
if len(call) == 0 {
return nil, fmt.Errorf("call name is empty")
}
calls[string(call)] = struct{}{}
}
if err := s.Err(); err != nil {
return nil, err
}
if len(calls) == 0 {
return nil, fmt.Errorf("program does not contain any calls")
}
return calls, nil
}