syzkaller/pkg/compiler/gen.go
// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package compiler

import (
	"fmt"
	"sort"

	"github.com/google/syzkaller/pkg/ast"
	"github.com/google/syzkaller/prog"
)

const sizeUnassigned = ^uint64(0)

func (comp *compiler) genResources() []*prog.ResourceDesc {
	var resources []*prog.ResourceDesc
	for name, n := range comp.resources {
		if !comp.used[name] {
			continue
		}
		resources = append(resources, comp.genResource(n))
	}
	sort.Slice(resources, func(i, j int) bool {
		return resources[i].Name < resources[j].Name
	})
	return resources
}

func (comp *compiler) genResource(n *ast.Resource) *prog.ResourceDesc {
	res := &prog.ResourceDesc{
		Name: n.Name.Name,
	}
	var base *ast.Type
	for n != nil {
		res.Values = append(genIntArray(n.Values), res.Values...)
		res.Kind = append([]string{n.Name.Name}, res.Kind...)
		base = n.Base
		n = comp.resources[n.Base.Ident]
	}
	if len(res.Values) == 0 {
		res.Values = []uint64{0}
	}
	res.Type = comp.genType(base, "", prog.DirIn, false)
	return res
}

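// As an illustration (hypothetical descriptions, not from this file):
//
//	resource fd[int32]: -1
//	resource sock[fd]
//
// genResource for sock walks the chain from the derived resource down to its
// base, prepending at each step, yielding roughly Kind = ["fd", "sock"],
// Values = [-1 as uint64], and Type generated from the int32 base.
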
func (comp *compiler) genSyscalls() []*prog.Syscall {
	var calls []*prog.Syscall
	for _, decl := range comp.desc.Nodes {
		if n, ok := decl.(*ast.Call); ok && n.NR != ^uint64(0) {
			calls = append(calls, comp.genSyscall(n))
		}
	}
	sort.Slice(calls, func(i, j int) bool {
		return calls[i].Name < calls[j].Name
	})
	return calls
}

func (comp *compiler) genSyscall(n *ast.Call) *prog.Syscall {
	var ret prog.Type
	if n.Ret != nil {
		ret = comp.genType(n.Ret, "ret", prog.DirOut, true)
	}
	return &prog.Syscall{
		Name:     n.Name.Name,
		CallName: n.CallName,
		NR:       n.NR,
		Args:     comp.genFieldArray(n.Args, prog.DirIn, true),
		Ret:      ret,
	}
}

func (comp *compiler) genStructDescs(syscalls []*prog.Syscall) []*prog.KeyedStruct {
	// Calculate struct/union/array sizes, add padding to structs and detach
	// StructDesc's from StructType's. StructType's can be recursive, so it's
	// not possible to write them out inline as other types. To break the
	// recursion we detach them and write StructDesc's out as a separate array
	// of KeyedStruct's. The prog package will reattach them during init.
	ctx := &structGen{
		comp:   comp,
		padded: make(map[interface{}]bool),
		detach: make(map[**prog.StructDesc]bool),
	}
	// Structs can be recursive, so iterate to a fixed point:
	// repeat until a pass pads nothing new.
	for {
		start := len(ctx.padded)
		for _, c := range syscalls {
			for _, a := range c.Args {
				ctx.walk(a)
			}
			if c.Ret != nil {
				ctx.walk(c.Ret)
			}
		}
		if start == len(ctx.padded) {
			break
		}
	}
	// Detach StructDesc's from StructType's. prog will reattach them again.
	for descp := range ctx.detach {
		*descp = nil
	}
	sort.Slice(ctx.structs, func(i, j int) bool {
		si, sj := ctx.structs[i], ctx.structs[j]
		if si.Key.Name != sj.Key.Name {
			return si.Key.Name < sj.Key.Name
		}
		return si.Key.Dir < sj.Key.Dir
	})
	return ctx.structs
}

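// More than one pass can be needed, e.g., when a struct is reachable from
// inside its own definition (say, via a pointer to an array of itself): the
// inner visit finds the struct still unsized, so the enclosing array's size
// stays unassigned until a later iteration (hypothetical scenario for
// illustration).
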
type structGen struct {
	comp    *compiler
	padded  map[interface{}]bool
	detach  map[**prog.StructDesc]bool
	structs []*prog.KeyedStruct
}

func (ctx *structGen) check(key prog.StructKey, descp **prog.StructDesc) bool {
	ctx.detach[descp] = true
	desc := *descp
	if ctx.padded[desc] {
		return false
	}
	ctx.padded[desc] = true
	for _, f := range desc.Fields {
		ctx.walk(f)
		if !f.Varlen() && f.Size() == sizeUnassigned {
			// An inner struct is not padded yet.
			// Leave this struct for the next iteration.
			delete(ctx.padded, desc)
			return false
		}
	}
	if ctx.comp.used[key.Name] {
		ctx.structs = append(ctx.structs, &prog.KeyedStruct{
			Key:  key,
			Desc: desc,
		})
	}
	return true
}

func (ctx *structGen) walk(t0 prog.Type) {
	switch t := t0.(type) {
	case *prog.PtrType:
		ctx.walk(t.Type)
	case *prog.ArrayType:
		ctx.walkArray(t)
	case *prog.StructType:
		ctx.walkStruct(t)
	case *prog.UnionType:
		ctx.walkUnion(t)
	}
}

func (ctx *structGen) walkArray(t *prog.ArrayType) {
	if ctx.padded[t] {
		return
	}
	ctx.walk(t.Type)
	if !t.Type.Varlen() && t.Type.Size() == sizeUnassigned {
		// An inner struct is not padded yet.
		// Leave this array for the next iteration.
		return
	}
	ctx.padded[t] = true
	t.TypeSize = 0
	if t.Kind == prog.ArrayRangeLen && t.RangeBegin == t.RangeEnd && !t.Type.Varlen() {
		t.TypeSize = t.RangeBegin * t.Type.Size()
	}
}

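// For example, a fixed-length array such as array[int32, 8] (RangeBegin ==
// RangeEnd == 8) gets TypeSize = 8 * 4 = 32; any other kind of array keeps
// TypeSize 0 and is handled as variable-length.
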
func (ctx *structGen) walkStruct(t *prog.StructType) {
	if !ctx.check(t.Key, &t.StructDesc) {
		return
	}
	comp := ctx.comp
	structNode := comp.structNodes[t.StructDesc]
	// Add padding, calculate size, mark bitfields.
	varlen := false
	for _, f := range t.Fields {
		if f.Varlen() {
			varlen = true
		}
	}
	comp.markBitfields(t.Fields)
	packed, sizeAttr, alignAttr := comp.parseStructAttrs(structNode)
	t.Fields = comp.addAlignment(t.Fields, varlen, packed, alignAttr)
	t.AlignAttr = alignAttr
	t.TypeSize = 0
	if !varlen {
		for _, f := range t.Fields {
			if !f.BitfieldMiddle() {
				t.TypeSize += f.Size()
			}
		}
		if sizeAttr != sizeUnassigned {
			if t.TypeSize > sizeAttr {
				comp.error(structNode.Pos, "struct %v has size attribute %v"+
					" which is less than struct size %v",
					structNode.Name.Name, sizeAttr, t.TypeSize)
			}
			if pad := sizeAttr - t.TypeSize; pad != 0 {
				t.Fields = append(t.Fields, genPad(pad))
			}
			t.TypeSize = sizeAttr
		}
	}
}

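// For instance (illustrative numbers): a fixed-size struct whose fields sum
// to 10 bytes but which declares a size attribute of 16 gets a 6-byte pad
// field appended and its TypeSize forced to 16; a size attribute smaller
// than the natural size is reported as an error above.
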
func (ctx *structGen) walkUnion(t *prog.UnionType) {
	if !ctx.check(t.Key, &t.StructDesc) {
		return
	}
	comp := ctx.comp
	structNode := comp.structNodes[t.StructDesc]
	varlen, sizeAttr := comp.parseUnionAttrs(structNode)
	t.TypeSize = 0
	if !varlen {
		for _, fld := range t.Fields {
			sz := fld.Size()
			if sizeAttr != sizeUnassigned && sz > sizeAttr {
				comp.error(structNode.Pos, "union %v has size attribute %v"+
					" which is less than field %v size %v",
					structNode.Name.Name, sizeAttr, fld.Name(), sz)
			}
			if t.TypeSize < sz {
				t.TypeSize = sz
			}
		}
		if sizeAttr != sizeUnassigned {
			t.TypeSize = sizeAttr
		}
	}
}

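// For a fixed-size union, TypeSize is the size of the largest field, e.g. a
// union with int8 and int64 fields is 8 bytes; an explicit size attribute,
// when present, overrides the computed maximum (after checking that every
// field fits).
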
func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, dir prog.Dir, varlen bool) {
	// Leave node for genStructDescs to calculate size/padding.
	comp.structNodes[res] = n
	common := genCommon(n.Name.Name, "", sizeUnassigned, dir, false)
	common.IsVarlen = varlen
	*res = prog.StructDesc{
		TypeCommon: common,
		Fields:     comp.genFieldArray(n.Fields, dir, false),
	}
}

func (comp *compiler) markBitfields(fields []prog.Type) {
	var bfOffset uint64
	for i, f := range fields {
		if f.BitfieldLength() == 0 {
			continue
		}
		off, middle := bfOffset, true
		bfOffset += f.BitfieldLength()
		if i == len(fields)-1 || // Last bitfield in a group if it's the last field of the struct...
			fields[i+1].BitfieldLength() == 0 || // or the next field is not a bitfield...
			f.Size() != fields[i+1].Size() || // or the next field has a different size...
			bfOffset+fields[i+1].BitfieldLength() > f.Size()*8 { // or the next field does not fit into the current group.
			middle, bfOffset = false, 0
		}
		setBitfieldOffset(f, off, middle)
	}
}

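// Worked example (hypothetical fields): for bitfields int32:4, int32:12,
// int32:20, the first two share one 32-bit unit (bit offsets 0 and 4; the
// first is marked middle, the second ends the group), while 4+12+20 > 32,
// so the third starts a new unit at bit offset 0.
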
func setBitfieldOffset(t0 prog.Type, offset uint64, middle bool) {
	switch t := t0.(type) {
	case *prog.IntType:
		t.BitfieldOff, t.BitfieldMdl = offset, middle
	case *prog.ConstType:
		t.BitfieldOff, t.BitfieldMdl = offset, middle
	case *prog.LenType:
		t.BitfieldOff, t.BitfieldMdl = offset, middle
	case *prog.FlagsType:
		t.BitfieldOff, t.BitfieldMdl = offset, middle
	case *prog.ProcType:
		t.BitfieldOff, t.BitfieldMdl = offset, middle
	default:
		panic(fmt.Sprintf("type %#v can't be a bitfield", t))
	}
}

func (comp *compiler) addAlignment(fields []prog.Type, varlen, packed bool, alignAttr uint64) []prog.Type {
	var newFields []prog.Type
	if packed {
		// If the struct is packed, statically sized and has an explicitly set
		// alignment, add padding at the end.
		newFields = fields
		if !varlen && alignAttr != 0 {
			size := uint64(0)
			for _, f := range fields {
				if !f.BitfieldMiddle() {
					size += f.Size()
				}
			}
			if tail := size % alignAttr; tail != 0 {
				newFields = append(newFields, genPad(alignAttr-tail))
			}
		}
		return newFields
	}
	var align, off uint64
	for i, f := range fields {
		if i == 0 || !fields[i-1].BitfieldMiddle() {
			a := comp.typeAlign(f)
			if align < a {
				align = a
			}
			// Add padding if the previous field is not a bitfield
			// or is the last bitfield in a group.
			if off%a != 0 {
				pad := a - off%a
				off += pad
				newFields = append(newFields, genPad(pad))
			}
		}
		newFields = append(newFields, f)
		if !f.BitfieldMiddle() && (i != len(fields)-1 || !f.Varlen()) {
			// Increase the offset if the current field is not a bitfield
			// or is the last bitfield in a group, except when it's
			// the last field in the struct and has variable length.
			off += f.Size()
		}
	}
	if alignAttr != 0 {
		align = alignAttr
	}
	if align != 0 && off%align != 0 && !varlen {
		pad := align - off%align
		off += pad
		newFields = append(newFields, genPad(pad))
	}
	return newFields
}

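// Example (hypothetical struct): for non-packed fields int8 followed by
// int32, the int32 requires 4-byte alignment at offset 1, so a 3-byte pad
// field is inserted before it; the resulting 8 bytes are already a multiple
// of the struct's 4-byte alignment, so no trailing pad is added.
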
func (comp *compiler) typeAlign(t0 prog.Type) uint64 {
	switch t0.(type) {
	case *prog.IntType, *prog.ConstType, *prog.LenType, *prog.FlagsType, *prog.ProcType,
		*prog.CsumType, *prog.PtrType, *prog.VmaType, *prog.ResourceType:
		return t0.Size()
	case *prog.BufferType:
		return 1
	}
	switch t := t0.(type) {
	case *prog.ArrayType:
		return comp.typeAlign(t.Type)
	case *prog.StructType:
		packed, _, alignAttr := comp.parseStructAttrs(comp.structNodes[t.StructDesc])
		if alignAttr != 0 {
			return alignAttr // overridden by the user attribute
		}
		if packed {
			return 1
		}
		align := uint64(0)
		for _, f := range t.Fields {
			if a := comp.typeAlign(f); align < a {
				align = a
			}
		}
		return align
	case *prog.UnionType:
		align := uint64(0)
		for _, f := range t.Fields {
			if a := comp.typeAlign(f); align < a {
				align = a
			}
		}
		return align
	default:
		panic(fmt.Sprintf("unknown type: %#v", t))
	}
}

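// These rules mirror typical C layout: scalars align to their own size,
// buffers to 1, arrays as their element type, and structs/unions to their
// widest member, unless packed (align 1) or overridden by an explicit align
// attribute.
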
func genPad(size uint64) prog.Type {
	return &prog.ConstType{
		IntTypeCommon: genIntCommon(genCommon("pad", "", size, prog.DirIn, false), 0, false),
		IsPad:         true,
	}
}

func (comp *compiler) genField(f *ast.Field, dir prog.Dir, isArg bool) prog.Type {
	return comp.genType(f.Type, f.Name.Name, dir, isArg)
}

func (comp *compiler) genFieldArray(fields []*ast.Field, dir prog.Dir, isArg bool) []prog.Type {
	var res []prog.Type
	for _, f := range fields {
		res = append(res, comp.genField(f, dir, isArg))
	}
	return res
}

func (comp *compiler) genType(t *ast.Type, field string, dir prog.Dir, isArg bool) prog.Type {
	desc, args, base := comp.getArgsBase(t, field, dir, isArg)
	if desc.Gen == nil {
		panic(fmt.Sprintf("no gen for %v %#v", field, t))
	}
	base.IsVarlen = desc.Varlen != nil && desc.Varlen(comp, t, args)
	return desc.Gen(comp, t, args, base)
}

func genCommon(name, field string, size uint64, dir prog.Dir, opt bool) prog.TypeCommon {
	return prog.TypeCommon{
		TypeName:   name,
		TypeSize:   size,
		FldName:    field,
		ArgDir:     dir,
		IsOptional: opt,
	}
}

func genIntCommon(com prog.TypeCommon, bitLen uint64, bigEndian bool) prog.IntTypeCommon {
	bf := prog.FormatNative
	if bigEndian {
		bf = prog.FormatBigEndian
	}
	return prog.IntTypeCommon{
		TypeCommon:  com,
		ArgFormat:   bf,
		BitfieldLen: bitLen,
	}
}

func genIntArray(a []*ast.Int) []uint64 {
	r := make([]uint64, len(a))
	for i, v := range a {
		r[i] = v.Value
	}
	return r
}

func genStrArray(a []*ast.String) []string {
	r := make([]string, len(a))
	for i, v := range a {
		r[i] = v.Value
	}
	return r
}