syzkaller/prog/encoding.go
Dmitry Vyukov ba64d006de prog: implement strict parsing mode
Add bulk of checks for strict parsing mode.
Probably not complete, but we can extend then in future as needed.
Turns out we can't easily use it for serialized programs
as they omit default args and during deserialization it looks like missing args.
2018-12-10 16:37:01 +01:00

962 lines
20 KiB
Go

// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"strconv"
"strings"
)
// String generates a very compact program description (mostly for debug output).
func (p *Prog) String() string {
buf := new(bytes.Buffer)
for i, c := range p.Calls {
if i != 0 {
fmt.Fprintf(buf, "-")
}
fmt.Fprintf(buf, "%v", c.Meta.Name)
}
return buf.String()
}
func (p *Prog) Serialize() []byte {
p.debugValidate()
ctx := &serializer{
target: p.Target,
buf: new(bytes.Buffer),
vars: make(map[*ResultArg]int),
}
for _, c := range p.Calls {
ctx.call(c)
}
return ctx.buf.Bytes()
}
type serializer struct {
target *Target
buf *bytes.Buffer
vars map[*ResultArg]int
varSeq int
}
func (ctx *serializer) printf(text string, args ...interface{}) {
fmt.Fprintf(ctx.buf, text, args...)
}
func (ctx *serializer) allocVarID(arg *ResultArg) int {
id := ctx.varSeq
ctx.varSeq++
ctx.vars[arg] = id
return id
}
func (ctx *serializer) call(c *Call) {
if c.Ret != nil && len(c.Ret.uses) != 0 {
ctx.printf("r%v = ", ctx.allocVarID(c.Ret))
}
ctx.printf("%v(", c.Meta.Name)
for i, a := range c.Args {
if IsPad(a.Type()) {
continue
}
if i != 0 {
ctx.printf(", ")
}
ctx.arg(a)
}
ctx.printf(")\n")
}
func (ctx *serializer) arg(arg Arg) {
if arg == nil {
ctx.printf("nil")
return
}
arg.serialize(ctx)
}
func (a *ConstArg) serialize(ctx *serializer) {
ctx.printf("0x%x", a.Val)
}
func (a *PointerArg) serialize(ctx *serializer) {
if a.IsSpecial() {
ctx.printf("0x%x", a.Address)
return
}
target := ctx.target
ctx.printf("&%v", target.serializeAddr(a))
if a.Res != nil && isDefault(a.Res) && !target.isAnyPtr(a.Type()) {
return
}
ctx.printf("=")
if target.isAnyPtr(a.Type()) {
ctx.printf("ANY=")
}
ctx.arg(a.Res)
}
func (a *DataArg) serialize(ctx *serializer) {
if a.Type().Dir() == DirOut {
ctx.printf("\"\"/%v", a.Size())
return
}
data := a.Data()
if !a.Type().Varlen() {
// Statically typed data will be padded with 0s during
// deserialization, so we can strip them here for readability.
for len(data) >= 2 && data[len(data)-1] == 0 && data[len(data)-2] == 0 {
data = data[:len(data)-1]
}
}
serializeData(ctx.buf, data)
}
func (a *GroupArg) serialize(ctx *serializer) {
var delims []byte
switch a.Type().(type) {
case *StructType:
delims = []byte{'{', '}'}
case *ArrayType:
delims = []byte{'[', ']'}
default:
panic("unknown group type")
}
ctx.buf.WriteByte(delims[0])
lastNonDefault := len(a.Inner) - 1
if a.fixedInnerSize() {
for ; lastNonDefault >= 0; lastNonDefault-- {
if !isDefault(a.Inner[lastNonDefault]) {
break
}
}
}
for i := 0; i <= lastNonDefault; i++ {
arg1 := a.Inner[i]
if arg1 != nil && IsPad(arg1.Type()) {
continue
}
if i != 0 {
ctx.printf(", ")
}
ctx.arg(arg1)
}
ctx.buf.WriteByte(delims[1])
}
func (a *UnionArg) serialize(ctx *serializer) {
ctx.printf("@%v", a.Option.Type().FieldName())
if isDefault(a.Option) {
return
}
ctx.printf("=")
ctx.arg(a.Option)
}
func (a *ResultArg) serialize(ctx *serializer) {
if len(a.uses) != 0 {
ctx.printf("<r%v=>", ctx.allocVarID(a))
}
if a.Res == nil {
ctx.printf("0x%x", a.Val)
return
}
id, ok := ctx.vars[a.Res]
if !ok {
panic("no result")
}
ctx.printf("r%v", id)
if a.OpDiv != 0 {
ctx.printf("/%v", a.OpDiv)
}
if a.OpAdd != 0 {
ctx.printf("+%v", a.OpAdd)
}
}
type DeserializeMode int
const (
Strict DeserializeMode = iota
NonStrict DeserializeMode = iota
)
func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, error) {
p := newParser(target, data, mode == Strict)
prog, err := p.parseProg()
if err := p.Err(); err != nil {
return nil, err
}
if err != nil {
return nil, err
}
// This validation is done even in non-debug mode because deserialization
// procedure does not catch all bugs (e.g. mismatched types).
// And we can receive bad programs from corpus and hub.
if err := prog.validate(); err != nil {
return nil, err
}
for _, c := range prog.Calls {
target.SanitizeCall(c)
}
return prog, nil
}
func (p *parser) parseProg() (*Prog, error) {
prog := &Prog{
Target: p.target,
}
for p.Scan() {
if p.EOF() {
if p.comment != "" {
prog.Comments = append(prog.Comments, p.comment)
p.comment = ""
}
continue
}
if p.Char() == '#' {
if p.comment != "" {
prog.Comments = append(prog.Comments, p.comment)
}
p.comment = strings.TrimSpace(p.s[p.i+1:])
continue
}
name := p.Ident()
r := ""
if p.Char() == '=' {
r = name
p.Parse('=')
name = p.Ident()
}
meta := p.target.SyscallMap[name]
if meta == nil {
return nil, fmt.Errorf("unknown syscall %v", name)
}
c := &Call{
Meta: meta,
Ret: MakeReturnArg(meta.Ret),
Comment: p.comment,
}
prog.Calls = append(prog.Calls, c)
p.Parse('(')
for i := 0; p.Char() != ')'; i++ {
if i >= len(meta.Args) {
p.eatExcessive(false, "excessive syscall arguments")
break
}
typ := meta.Args[i]
if IsPad(typ) {
return nil, fmt.Errorf("padding in syscall %v arguments", name)
}
arg, err := p.parseArg(typ)
if err != nil {
return nil, err
}
c.Args = append(c.Args, arg)
if p.Char() != ')' {
p.Parse(',')
}
}
p.Parse(')')
p.SkipWs()
if !p.EOF() {
if p.Char() != '#' {
return nil, fmt.Errorf("tailing data (line #%v)", p.l)
}
if c.Comment != "" {
prog.Comments = append(prog.Comments, c.Comment)
}
c.Comment = strings.TrimSpace(p.s[p.i+1:])
}
for i := len(c.Args); i < len(meta.Args); i++ {
p.strictFailf("missing syscall args")
c.Args = append(c.Args, meta.Args[i].DefaultArg())
}
if len(c.Args) != len(meta.Args) {
return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args))
}
if r != "" && c.Ret != nil {
p.vars[r] = c.Ret
}
p.comment = ""
}
if p.comment != "" {
prog.Comments = append(prog.Comments, p.comment)
}
return prog, nil
}
func (p *parser) parseArg(typ Type) (Arg, error) {
r := ""
if p.Char() == '<' {
p.Parse('<')
r = p.Ident()
p.Parse('=')
p.Parse('>')
}
arg, err := p.parseArgImpl(typ)
if err != nil {
return nil, err
}
if arg == nil {
if typ != nil {
arg = typ.DefaultArg()
} else if r != "" {
return nil, fmt.Errorf("named nil argument")
}
}
if r != "" {
if res, ok := arg.(*ResultArg); ok {
p.vars[r] = res
}
}
return arg, nil
}
func (p *parser) parseArgImpl(typ Type) (Arg, error) {
switch p.Char() {
case '0':
return p.parseArgInt(typ)
case 'r':
return p.parseArgRes(typ)
case '&':
return p.parseArgAddr(typ)
case '"', '\'':
return p.parseArgString(typ)
case '{':
return p.parseArgStruct(typ)
case '[':
return p.parseArgArray(typ)
case '@':
return p.parseArgUnion(typ)
case 'n':
p.Parse('n')
p.Parse('i')
p.Parse('l')
return nil, nil
default:
return nil, fmt.Errorf("failed to parse argument at %v (line #%v/%v: %v)",
int(p.Char()), p.l, p.i, p.s)
}
}
func (p *parser) parseArgInt(typ Type) (Arg, error) {
val := p.Ident()
v, err := strconv.ParseUint(val, 0, 64)
if err != nil {
return nil, fmt.Errorf("wrong arg value '%v': %v", val, err)
}
switch typ.(type) {
case *ConstType, *IntType, *FlagsType, *ProcType, *LenType, *CsumType:
return MakeConstArg(typ, v), nil
case *ResourceType:
return MakeResultArg(typ, nil, v), nil
case *PtrType, *VmaType:
index := -v % uint64(len(p.target.SpecialPointers))
return MakeSpecialPointerArg(typ, index), nil
default:
p.eatExcessive(true, "wrong int arg")
return typ.DefaultArg(), nil
}
}
func (p *parser) parseArgRes(typ Type) (Arg, error) {
id := p.Ident()
var div, add uint64
if p.Char() == '/' {
p.Parse('/')
op := p.Ident()
v, err := strconv.ParseUint(op, 0, 64)
if err != nil {
return nil, fmt.Errorf("wrong result div op: '%v'", op)
}
div = v
}
if p.Char() == '+' {
p.Parse('+')
op := p.Ident()
v, err := strconv.ParseUint(op, 0, 64)
if err != nil {
return nil, fmt.Errorf("wrong result add op: '%v'", op)
}
add = v
}
v := p.vars[id]
if v == nil {
p.strictFailf("undeclared variable %v", id)
return typ.DefaultArg(), nil
}
arg := MakeResultArg(typ, v, 0)
arg.OpDiv = div
arg.OpAdd = add
return arg, nil
}
func (p *parser) parseArgAddr(typ Type) (Arg, error) {
var typ1 Type
switch t1 := typ.(type) {
case *PtrType:
typ1 = t1.Type
case *VmaType:
default:
p.eatExcessive(true, "wrong addr arg")
return typ.DefaultArg(), nil
}
p.Parse('&')
addr, vmaSize, err := p.parseAddr()
if err != nil {
return nil, err
}
var inner Arg
if p.Char() == '=' {
p.Parse('=')
if p.Char() == 'A' {
p.Parse('A')
p.Parse('N')
p.Parse('Y')
p.Parse('=')
typ = p.target.makeAnyPtrType(typ.Size(), typ.FieldName())
typ1 = p.target.any.array
}
inner, err = p.parseArg(typ1)
if err != nil {
return nil, err
}
}
if typ1 == nil {
return MakeVmaPointerArg(typ, addr, vmaSize), nil
}
if inner == nil {
inner = typ1.DefaultArg()
}
return MakePointerArg(typ, addr, inner), nil
}
func (p *parser) parseArgString(typ Type) (Arg, error) {
if _, ok := typ.(*BufferType); !ok {
p.eatExcessive(true, "wrong string arg")
return typ.DefaultArg(), nil
}
data, err := p.deserializeData()
if err != nil {
return nil, err
}
size := ^uint64(0)
if p.Char() == '/' {
p.Parse('/')
sizeStr := p.Ident()
size, err = strconv.ParseUint(sizeStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse buffer size: %q", sizeStr)
}
}
if !typ.Varlen() {
size = typ.Size()
} else if size == ^uint64(0) {
size = uint64(len(data))
}
if typ.Dir() == DirOut {
return MakeOutDataArg(typ, size), nil
}
if diff := int(size) - len(data); diff > 0 {
data = append(data, make([]byte, diff)...)
}
data = data[:size]
return MakeDataArg(typ, data), nil
}
func (p *parser) parseArgStruct(typ Type) (Arg, error) {
p.Parse('{')
t1, ok := typ.(*StructType)
if !ok {
p.eatExcessive(false, "wrong struct arg")
p.Parse('}')
return typ.DefaultArg(), nil
}
var inner []Arg
for i := 0; p.Char() != '}'; i++ {
if i >= len(t1.Fields) {
p.eatExcessive(false, "excessive struct %v fields", typ.Name())
break
}
fld := t1.Fields[i]
if IsPad(fld) {
inner = append(inner, MakeConstArg(fld, 0))
} else {
arg, err := p.parseArg(fld)
if err != nil {
return nil, err
}
inner = append(inner, arg)
if p.Char() != '}' {
p.Parse(',')
}
}
}
p.Parse('}')
for len(inner) < len(t1.Fields) {
fld := t1.Fields[len(inner)]
if !IsPad(fld) {
p.strictFailf("missing struct %v fields %v/%v", typ.Name(), len(inner), len(t1.Fields))
}
inner = append(inner, fld.DefaultArg())
}
return MakeGroupArg(typ, inner), nil
}
func (p *parser) parseArgArray(typ Type) (Arg, error) {
p.Parse('[')
t1, ok := typ.(*ArrayType)
if !ok {
p.eatExcessive(false, "wrong array arg")
p.Parse(']')
return typ.DefaultArg(), nil
}
var inner []Arg
for i := 0; p.Char() != ']'; i++ {
arg, err := p.parseArg(t1.Type)
if err != nil {
return nil, err
}
inner = append(inner, arg)
if p.Char() != ']' {
p.Parse(',')
}
}
p.Parse(']')
if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd {
for uint64(len(inner)) < t1.RangeBegin {
p.strictFailf("missing array elements")
inner = append(inner, t1.Type.DefaultArg())
}
inner = inner[:t1.RangeBegin]
}
return MakeGroupArg(typ, inner), nil
}
func (p *parser) parseArgUnion(typ Type) (Arg, error) {
t1, ok := typ.(*UnionType)
if !ok {
p.eatExcessive(true, "wrong union arg")
return typ.DefaultArg(), nil
}
p.Parse('@')
name := p.Ident()
var optType Type
for _, t2 := range t1.Fields {
if name == t2.FieldName() {
optType = t2
break
}
}
if optType == nil {
p.eatExcessive(true, "wrong union option")
return typ.DefaultArg(), nil
}
var opt Arg
if p.Char() == '=' {
p.Parse('=')
var err error
opt, err = p.parseArg(optType)
if err != nil {
return nil, err
}
} else {
opt = optType.DefaultArg()
}
return MakeUnionArg(typ, opt), nil
}
// Eats excessive call arguments and struct fields to recover after description changes.
func (p *parser) eatExcessive(stopAtComma bool, what string, args ...interface{}) {
if p.strict {
p.failf(what, args...)
}
paren, brack, brace := 0, 0, 0
for !p.EOF() && p.e == nil {
ch := p.Char()
switch ch {
case '(':
paren++
case ')':
if paren == 0 {
return
}
paren--
case '[':
brack++
case ']':
if brack == 0 {
return
}
brack--
case '{':
brace++
case '}':
if brace == 0 {
return
}
brace--
case ',':
if stopAtComma && paren == 0 && brack == 0 && brace == 0 {
return
}
case '\'', '"':
p.Parse(ch)
for !p.EOF() && p.Char() != ch {
p.Parse(p.Char())
}
if p.EOF() {
return
}
}
p.Parse(ch)
}
}
const (
encodingAddrBase = 0x7f0000000000
maxLineLen = 1 << 20
)
func (target *Target) serializeAddr(arg *PointerArg) string {
ssize := ""
if arg.VmaSize != 0 {
ssize = fmt.Sprintf("/0x%x", arg.VmaSize)
}
return fmt.Sprintf("(0x%x%v)", encodingAddrBase+arg.Address, ssize)
}
func (p *parser) parseAddr() (uint64, uint64, error) {
p.Parse('(')
pstr := p.Ident()
addr, err := strconv.ParseUint(pstr, 0, 64)
if err != nil {
return 0, 0, fmt.Errorf("failed to parse addr: %q", pstr)
}
if addr < encodingAddrBase {
return 0, 0, fmt.Errorf("address without base offset: %q", pstr)
}
addr -= encodingAddrBase
// This is not used anymore, but left here to parse old programs.
if p.Char() == '+' || p.Char() == '-' {
minus := false
if p.Char() == '-' {
minus = true
p.Parse('-')
} else {
p.Parse('+')
}
ostr := p.Ident()
off, err := strconv.ParseUint(ostr, 0, 64)
if err != nil {
return 0, 0, fmt.Errorf("failed to parse addr offset: %q", ostr)
}
if minus {
off = -off
}
addr += off
}
target := p.target
maxMem := target.NumPages * target.PageSize
var vmaSize uint64
if p.Char() == '/' {
p.Parse('/')
pstr := p.Ident()
size, err := strconv.ParseUint(pstr, 0, 64)
if err != nil {
return 0, 0, fmt.Errorf("failed to parse addr size: %q", pstr)
}
addr = addr & ^(target.PageSize - 1)
vmaSize = (size + target.PageSize - 1) & ^(target.PageSize - 1)
if vmaSize == 0 {
vmaSize = target.PageSize
}
if vmaSize > maxMem {
vmaSize = maxMem
}
if addr > maxMem-vmaSize {
addr = maxMem - vmaSize
}
}
p.Parse(')')
return addr, vmaSize, nil
}
func serializeData(buf *bytes.Buffer, data []byte) {
readable := true
for _, v := range data {
if v >= 0x20 && v < 0x7f {
continue
}
switch v {
case 0, '\a', '\b', '\f', '\n', '\r', '\t', '\v':
continue
}
readable = false
break
}
if !readable || len(data) == 0 {
fmt.Fprintf(buf, "\"%v\"", hex.EncodeToString(data))
return
}
buf.WriteByte('\'')
for _, v := range data {
switch v {
case 0:
buf.Write([]byte{'\\', 'x', '0', '0'})
case '\a':
buf.Write([]byte{'\\', 'a'})
case '\b':
buf.Write([]byte{'\\', 'b'})
case '\f':
buf.Write([]byte{'\\', 'f'})
case '\n':
buf.Write([]byte{'\\', 'n'})
case '\r':
buf.Write([]byte{'\\', 'r'})
case '\t':
buf.Write([]byte{'\\', 't'})
case '\v':
buf.Write([]byte{'\\', 'v'})
case '\'':
buf.Write([]byte{'\\', '\''})
case '\\':
buf.Write([]byte{'\\', '\\'})
default:
buf.WriteByte(v)
}
}
buf.WriteByte('\'')
}
func (p *parser) deserializeData() ([]byte, error) {
var data []byte
if p.Char() == '"' {
p.Parse('"')
val := ""
if p.Char() != '"' {
val = p.Ident()
}
p.Parse('"')
var err error
data, err = hex.DecodeString(val)
if err != nil {
return nil, fmt.Errorf("data arg has bad value %q", val)
}
} else {
if p.consume() != '\'' {
return nil, fmt.Errorf("data arg does not start with \" nor with '")
}
for p.Char() != '\'' && p.Char() != 0 {
v := p.consume()
if v != '\\' {
data = append(data, v)
continue
}
v = p.consume()
switch v {
case 'x':
hi := p.consume()
lo := p.consume()
if lo != '0' || hi != '0' {
return nil, fmt.Errorf(
"invalid \\x%c%c escape sequence in data arg", hi, lo)
}
data = append(data, 0)
case 'a':
data = append(data, '\a')
case 'b':
data = append(data, '\b')
case 'f':
data = append(data, '\f')
case 'n':
data = append(data, '\n')
case 'r':
data = append(data, '\r')
case 't':
data = append(data, '\t')
case 'v':
data = append(data, '\v')
case '\'':
data = append(data, '\'')
case '\\':
data = append(data, '\\')
default:
return nil, fmt.Errorf("invalid \\%c escape sequence in data arg", v)
}
}
p.Parse('\'')
}
return data, nil
}
type parser struct {
target *Target
strict bool
vars map[string]*ResultArg
comment string
r *bufio.Scanner
s string
i int
l int
e error
}
func newParser(target *Target, data []byte, strict bool) *parser {
p := &parser{
target: target,
strict: strict,
vars: make(map[string]*ResultArg),
r: bufio.NewScanner(bytes.NewReader(data)),
}
p.r.Buffer(nil, maxLineLen)
return p
}
func (p *parser) Scan() bool {
if p.e != nil {
return false
}
if !p.r.Scan() {
p.e = p.r.Err()
return false
}
p.s = p.r.Text()
p.i = 0
p.l++
return true
}
func (p *parser) Err() error {
return p.e
}
func (p *parser) Str() string {
return p.s
}
func (p *parser) EOF() bool {
return p.i == len(p.s)
}
func (p *parser) Char() byte {
if p.e != nil {
return 0
}
if p.EOF() {
p.failf("unexpected eof")
return 0
}
return p.s[p.i]
}
func (p *parser) Parse(ch byte) {
if p.e != nil {
return
}
if p.EOF() {
p.failf("want %s, got EOF", string(ch))
return
}
if p.s[p.i] != ch {
p.failf("want '%v', got '%v'", string(ch), string(p.s[p.i]))
return
}
p.i++
p.SkipWs()
}
func (p *parser) consume() byte {
if p.e != nil {
return 0
}
if p.EOF() {
p.failf("unexpected eof")
return 0
}
v := p.s[p.i]
p.i++
return v
}
func (p *parser) SkipWs() {
for p.i < len(p.s) && (p.s[p.i] == ' ' || p.s[p.i] == '\t') {
p.i++
}
}
func (p *parser) Ident() string {
i := p.i
for p.i < len(p.s) &&
(p.s[p.i] >= 'a' && p.s[p.i] <= 'z' ||
p.s[p.i] >= 'A' && p.s[p.i] <= 'Z' ||
p.s[p.i] >= '0' && p.s[p.i] <= '9' ||
p.s[p.i] == '_' || p.s[p.i] == '$') {
p.i++
}
if i == p.i {
p.failf("failed to parse identifier at pos %v", i)
return ""
}
s := p.s[i:p.i]
p.SkipWs()
return s
}
func (p *parser) failf(msg string, args ...interface{}) {
if p.e == nil {
p.e = fmt.Errorf("%v\nline #%v:%v: %v", fmt.Sprintf(msg, args...), p.l, p.i, p.s)
}
}
func (p *parser) strictFailf(msg string, args ...interface{}) {
if p.strict {
p.failf(msg, args...)
}
}
// CallSet returns a set of all calls in the program.
// It does very conservative parsing and is intended to parse paste/future serialization formats.
func CallSet(data []byte) (map[string]struct{}, error) {
calls := make(map[string]struct{})
s := bufio.NewScanner(bytes.NewReader(data))
s.Buffer(nil, maxLineLen)
for s.Scan() {
ln := s.Bytes()
if len(ln) == 0 || ln[0] == '#' {
continue
}
bracket := bytes.IndexByte(ln, '(')
if bracket == -1 {
return nil, fmt.Errorf("line does not contain opening bracket")
}
call := ln[:bracket]
if eq := bytes.IndexByte(call, '='); eq != -1 {
eq++
for eq < len(call) && call[eq] == ' ' {
eq++
}
call = call[eq:]
}
if len(call) == 0 {
return nil, fmt.Errorf("call name is empty")
}
calls[string(call)] = struct{}{}
}
if err := s.Err(); err != nil {
return nil, err
}
if len(calls) == 0 {
return nil, fmt.Errorf("program does not contain any calls")
}
return calls, nil
}