csource: new package

Move C source generation into a separate package.
Prog is too bloated already.
This commit is contained in:
Dmitry Vyukov 2015-12-23 13:38:31 +01:00
parent 071ad4e91f
commit e253cbc79f
7 changed files with 340 additions and 177 deletions

233
csource/csource.go Normal file
View File

@ -0,0 +1,233 @@
// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package csource
import (
"bytes"
"fmt"
"io/ioutil"
"strings"
"os"
"os/exec"
"unsafe"
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys"
)
type Options struct {
Threaded bool
Collide bool
}
func Write(p *prog.Prog, opts Options) []byte {
exec := p.SerializeForExec()
w := new(bytes.Buffer)
fmt.Fprintf(w, `// autogenerated by syzkaller (http://github.com/google/syzkaller)
#include <unistd.h>
#include <sys/syscall.h>
#include <string.h>
#include <stdint.h>
#include <pthread.h>
`)
handled := make(map[string]bool)
for _, c := range p.Calls {
name := c.Meta.CallName
nr, ok := prog.NewSyscalls[name]
if !ok || handled[name] {
continue
}
handled[name] = true
fmt.Fprintf(w, "#ifndef SYS_%v\n", name)
fmt.Fprintf(w, "#define SYS_%v %v\n", name, nr)
fmt.Fprintf(w, "#endif\n")
}
fmt.Fprintf(w, "\n")
calls,nvar := generateCalls(exec)
fmt.Fprintf(w, "long r[%v];\n\n", nvar)
if !opts.Threaded && !opts.Collide {
fmt.Fprintf(w, "int main()\n{\n")
fmt.Fprintf(w, "\tmemset(r, -1, sizeof(r));\n")
for _, c := range calls {
fmt.Fprintf(w, "%s", c)
}
fmt.Fprintf(w, "\treturn 0;\n}\n")
} else {
fmt.Fprintf(w, "void *thr(void *arg)\n{\n")
fmt.Fprintf(w, "\tswitch ((long)arg) {\n")
for i, c := range calls {
fmt.Fprintf(w, "\tcase %v:\n", i)
fmt.Fprintf(w, "%s", strings.Replace(c, "\t", "\t\t", -1))
fmt.Fprintf(w, "\t\tbreak;\n")
}
fmt.Fprintf(w, "\t}\n")
fmt.Fprintf(w, "\treturn 0;\n}\n\n")
fmt.Fprintf(w, "int main()\n{\n")
fmt.Fprintf(w, "\tlong i;\n")
fmt.Fprintf(w, "\tpthread_t th[%v];\n", len(calls))
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "\tmemset(r, -1, sizeof(r));\n")
fmt.Fprintf(w, "\tfor (i = 0; i < %v; i++) {\n", len(calls))
fmt.Fprintf(w, "\t\tpthread_create(&th[i], 0, thr, (void*)i);\n")
fmt.Fprintf(w, "\t\tusleep(10000);\n")
fmt.Fprintf(w, "\t}\n")
if opts.Collide {
fmt.Fprintf(w, "\tfor (i = 0; i < %v; i++) {\n", len(calls))
fmt.Fprintf(w, "\t\tpthread_create(&th[i], 0, thr, (void*)i);\n")
fmt.Fprintf(w, "\t\tif (i%%2==0)\n")
fmt.Fprintf(w, "\t\t\tusleep(10000);\n")
fmt.Fprintf(w, "\t}\n")
}
fmt.Fprintf(w, "\tusleep(100000);\n")
fmt.Fprintf(w, "\treturn 0;\n}\n")
}
return w.Bytes()
}
func generateCalls(exec []byte) ([]string, int) {
read := func() uintptr {
if len(exec) < 8 {
panic("exec program overflow")
}
v := *(*uint64)(unsafe.Pointer(&exec[0]))
exec = exec[8:]
return uintptr(v)
}
resultRef := func() string {
arg := read()
res := fmt.Sprintf("r[%v]", arg)
if opDiv := read(); opDiv != 0 {
res = fmt.Sprintf("%v/%v", res, opDiv)
}
if opAdd := read(); opAdd != 0 {
res = fmt.Sprintf("%v+%v", res, opAdd)
}
return res
}
lastCall := 0
seenCall := false
var calls []string
w := new(bytes.Buffer)
newCall := func() {
if seenCall {
seenCall = false
calls = append(calls, w.String())
w = new(bytes.Buffer)
}
}
n := 0
loop:
for ;; n++ {
switch instr := read(); instr {
case prog.ExecInstrEOF:
break loop
case prog.ExecInstrCopyin:
newCall()
addr := read()
typ := read()
size := read()
switch typ {
case prog.ExecArgConst:
arg := read()
fmt.Fprintf(w, "\t*(uint%v_t*)0x%x = (uint%v_t)0x%x;\n", size*8, addr, size*8, arg)
case prog.ExecArgResult:
fmt.Fprintf(w, "\t*(uint%v_t*)0x%x = %v;\n", size*8, addr, resultRef())
case prog.ExecArgData:
data := exec[:size]
exec = exec[(size+7)/8*8:]
var esc []byte
for _, v := range data {
hex := func(v byte) byte {
if v < 10 {
return '0' + v
}
return 'a' + v - 10
}
esc = append(esc, '\\', 'x', hex(v>>4), hex(v<<4>>4))
}
fmt.Fprintf(w, "\tmemcpy((void*)0x%x, \"%s\", %v);\n", addr, esc, size)
default:
panic("bad argument type")
}
case prog.ExecInstrCopyout:
addr := read()
size := read()
fmt.Fprintf(w, "\tif (r[%v] != -1)\n", lastCall)
fmt.Fprintf(w, "\t\tr[%v] = *(uint%v_t*)0x%x;\n", n, size*8, addr)
case prog.ExecInstrSetPad:
newCall()
read() // addr
read() // size
case prog.ExecInstrCheckPad:
read() // addr
read() // size
default:
// Normal syscall.
newCall()
meta := sys.Calls[instr]
fmt.Fprintf(w, "\tr[%v] = syscall(SYS_%v", n, meta.CallName)
nargs := read()
for i := uintptr(0); i < nargs; i++ {
typ := read()
size := read()
_ = size
switch typ {
case prog.ExecArgConst:
fmt.Fprintf(w, ", 0x%xul", read())
case prog.ExecArgResult:
fmt.Fprintf(w, ", %v", resultRef())
default:
panic("unknown arg type")
}
}
for i := nargs; i < 6; i++ {
fmt.Fprintf(w, ", 0")
}
fmt.Fprintf(w, ");\n")
lastCall = n
seenCall = true
}
}
newCall()
return calls, n
}
// WriteTempFile writes data to a temp file and returns its name.
func WriteTempFile(data []byte) (string, error) {
f, err := ioutil.TempFile("", "syz-prog")
if err != nil {
return "", fmt.Errorf("failed to create a temp file: %v", err)
}
if _, err := f.Write(data); err != nil {
f.Close()
os.Remove(f.Name())
return "", fmt.Errorf("failed to write temp file: %v", err)
}
f.Close()
return f.Name(), nil
}
// Build builds a C/C++ program from source file src
// and returns name of the resulting binary.
func Build(src string) (string, error) {
bin, err := ioutil.TempFile("", "syzkaller")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %v", err)
}
bin.Close()
out, err := exec.Command("gcc", "-x", "c++", "-std=gnu++11", src, "-o", bin.Name(), "-lpthread", "-static", "-O1", "-g").CombinedOutput()
if err != nil {
os.Remove(bin.Name())
data, _ := ioutil.ReadFile(src)
return "", fmt.Errorf("failed to build program::\n%s\n%s", data, out)
}
return bin.Name(), nil
}

53
csource/csource_test.go Normal file
View File

@ -0,0 +1,53 @@
// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package csource
import (
"math/rand"
"os"
"testing"
"time"
"github.com/google/syzkaller/prog"
)
func initTest(t *testing.T) (rand.Source, int) {
iters := 1000
if testing.Short() {
iters = 10
}
seed := int64(time.Now().UnixNano())
rs := rand.NewSource(seed)
t.Logf("seed=%v", seed)
return rs, iters
}
func Test(t *testing.T) {
rs, iters := initTest(t)
options := []Options{
Options{},
Options{Threaded: true},
Options{Threaded: true, Collide: true},
}
for i := 0; i < iters; i++ {
p := prog.Generate(rs, 10, nil)
for _, opts := range options {
testOne(t, p, opts)
}
}
}
func testOne(t *testing.T, p *prog.Prog, opts Options) {
src := Write(p, opts)
srcf, err := WriteTempFile(src)
if err != nil {
t.Fatalf("%v", err)
}
defer os.Remove(srcf)
bin, err := Build(srcf)
if err != nil {
t.Fatalf("%v", err)
}
defer os.Remove(bin)
}

View File

@ -6,14 +6,13 @@ package ipc
import (
"bufio"
"bytes"
"io/ioutil"
"math/rand"
"os"
"os/exec"
"strings"
"testing"
"time"
"github.com/google/syzkaller/csource"
"github.com/google/syzkaller/prog"
)
@ -22,33 +21,20 @@ func buildExecutor(t *testing.T) string {
}
func buildSource(t *testing.T, src []byte) string {
srcf, err := ioutil.TempFile("", "syzkaller")
tmp, err := csource.WriteTempFile(src)
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
t.Fatalf("%v", err)
}
srcf.Close()
os.Remove(srcf.Name())
name := srcf.Name() + ".c"
if err := ioutil.WriteFile(name, src, 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
defer os.Remove(name)
return buildProgram(t, name)
defer os.Remove(tmp)
return buildProgram(t, tmp)
}
func buildProgram(t *testing.T, src string) string {
bin, err := ioutil.TempFile("", "syzkaller")
bin, err := csource.Build(src)
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
t.Fatalf("%v", err)
}
bin.Close()
out, err := exec.Command("gcc", src, "-o", bin.Name(), "-lpthread", "-static", "-O1", "-g").CombinedOutput()
if err != nil {
os.Remove(bin.Name())
data, _ := ioutil.ReadFile(src)
t.Fatalf("failed to build program:\n%s\n%s", data, out)
}
return bin.Name()
return bin
}
func initTest(t *testing.T) (rand.Source, int) {
@ -73,7 +59,7 @@ func TestEmptyProg(t *testing.T) {
defer env.Close()
p := new(prog.Prog)
output, strace, cov, failed, hanged, err := env.Exec(p)
output, strace, cov, _, failed, hanged, err := env.Exec(p)
if err != nil {
t.Fatalf("failed to run executor: %v", err)
}
@ -102,7 +88,7 @@ func TestStrace(t *testing.T) {
defer env.Close()
p := new(prog.Prog)
_, strace, _, failed, hanged, err := env.Exec(p)
_, strace, _, _, failed, hanged, err := env.Exec(p)
if err != nil {
t.Fatalf("failed to run executor: %v", err)
}
@ -129,7 +115,7 @@ func TestExecute(t *testing.T) {
for i := 0; i < iters/len(flags); i++ {
p := prog.Generate(rs, 10, nil)
_, _, _, _, _, err := env.Exec(p)
_, _, _, _, _, _, err := env.Exec(p)
if err != nil {
t.Fatalf("failed to run executor: %v", err)
}
@ -163,12 +149,12 @@ func TestCompare(t *testing.T) {
rs, iters := initTest(t)
for i := 0; i < iters; i++ {
p := prog.Generate(rs, 10, nil)
_, strace1, _, _, _, err := env1.Exec(p)
_, strace1, _, _, _, _, err := env1.Exec(p)
if err != nil {
t.Fatalf("failed to run executor: %v", err)
}
src := p.WriteCSource()
src := csource.Write(p, csource.Options{})
cprog := buildSource(t, src)
defer os.Remove(cprog)
@ -178,7 +164,7 @@ func TestCompare(t *testing.T) {
}
defer env2.Close() // yes, that's defer in a loop
_, strace2, _, _, _, err := env2.Exec(nil)
_, strace2, _, _, _, _, err := env2.Exec(nil)
if err != nil {
t.Fatalf("failed to run c binary: %v", err)
}

View File

@ -1,117 +0,0 @@
// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"bytes"
"fmt"
"io"
"unsafe"
"github.com/google/syzkaller/sys"
)
func (p *Prog) WriteCSource() []byte {
exec := p.SerializeForExec()
buf := new(bytes.Buffer)
writeCSource(buf, exec)
return buf.Bytes()
}
func writeCSource(w io.Writer, exec []byte) {
fmt.Fprintf(w, `// autogenerated by syzkaller (http://github.com/google/syzkaller)
#include <syscall.h>
#include <string.h>
#include <stdint.h>
int main()
{
`)
read := func() uintptr {
if len(exec) < 8 {
panic("exec program overflow")
}
v := *(*uint64)(unsafe.Pointer(&exec[0]))
exec = exec[8:]
return uintptr(v)
}
resultRef := func() string {
arg := read()
res := fmt.Sprintf("r%v", arg)
if opDiv := read(); opDiv != 0 {
res = fmt.Sprintf("%v/%v", res, opDiv)
}
if opAdd := read(); opAdd != 0 {
res = fmt.Sprintf("%v+%v", res, opAdd)
}
return res
}
lastCall := 0
for n := 0; ; n++ {
switch instr := read(); instr {
case instrEOF:
fmt.Fprintf(w, "\treturn 0;\n}\n")
return
case instrCopyin:
addr := read()
typ := read()
size := read()
switch typ {
case execArgConst:
arg := read()
fmt.Fprintf(w, "\t*(uint%v_t*)0x%x = 0x%x;\n", size*8, addr, arg)
case execArgResult:
fmt.Fprintf(w, "\t*(uint%v_t*)0x%x = %v;\n", size*8, addr, resultRef())
case execArgData:
data := exec[:size]
exec = exec[(size+7)/8*8:]
var esc []byte
for _, v := range data {
hex := func(v byte) byte {
if v < 10 {
return '0' + v
}
return 'a' + v - 10
}
esc = append(esc, '\\', 'x', hex(v>>4), hex(v<<4>>4))
}
fmt.Fprintf(w, "\tmemcpy((void*)0x%x, \"%s\", %v);\n", addr, esc, size)
default:
panic("bad argument type")
}
case instrCopyout:
addr := read()
size := read()
fmt.Fprintf(w, "\tlong r%v = -1;\n", n)
fmt.Fprintf(w, "\tif (r%v != -1)\n", lastCall)
fmt.Fprintf(w, "\t\tr%v = *(uint%v_t*)0x%x;\n", n, size*8, addr)
case instrSetPad, instrCheckPad:
read() // addr
read() // size
default:
// Normal syscall.
meta := sys.Calls[instr]
fmt.Fprintf(w, "\tlong r%v = syscall(SYS_%v", n, meta.CallName)
nargs := read()
for i := uintptr(0); i < nargs; i++ {
typ := read()
size := read()
_ = size
switch typ {
case execArgConst:
fmt.Fprintf(w, ", 0x%xul", read())
case execArgResult:
fmt.Fprintf(w, ", %v", resultRef())
default:
panic("unknown arg type")
}
}
for i := nargs; i < 6; i++ {
fmt.Fprintf(w, ", 0")
}
fmt.Fprintf(w, ");\n")
lastCall = n
}
}
}

View File

@ -6,18 +6,22 @@
package prog
const (
instrEOF = ^uintptr(iota)
instrCopyin
instrCopyout
instrSetPad
instrCheckPad
import (
"fmt"
)
const (
execArgConst = uintptr(iota)
execArgResult
execArgData
ExecInstrEOF = ^uintptr(iota)
ExecInstrCopyin
ExecInstrCopyout
ExecInstrSetPad
ExecInstrCheckPad
)
const (
ExecArgConst = uintptr(iota)
ExecArgResult
ExecArgData
)
const (
@ -28,7 +32,7 @@ const (
func (p *Prog) SerializeForExec() []byte {
if err := p.validate(); err != nil {
panic("serializing invalid program")
panic(fmt.Errorf("serializing invalid program: %v", err))
}
var instrSeq uintptr
w := &execContext{args: make(map[*Arg]*argInfo)}
@ -61,11 +65,11 @@ func (p *Prog) SerializeForExec() []byte {
pad, padSize := arg1.IsPad()
if (arg1.Dir == DirIn && !pad) || (arg1.Dir == DirOut && pad) || arg1.Dir == DirInOut {
if pad {
w.write(instrSetPad)
w.write(ExecInstrSetPad)
w.write(physicalAddr(arg) + w.args[arg1].Offset)
w.write(padSize)
} else {
w.write(instrCopyin)
w.write(ExecInstrCopyin)
w.write(physicalAddr(arg) + w.args[arg1].Offset)
w.writeArg(arg1)
}
@ -89,7 +93,7 @@ func (p *Prog) SerializeForExec() []byte {
if pad && arg.Dir != DirIn {
instrSeq++
info := w.args[arg]
w.write(instrCheckPad)
w.write(ExecInstrCheckPad)
w.write(physicalAddr(base) + info.Offset)
w.write(padSize)
return
@ -108,7 +112,7 @@ func (p *Prog) SerializeForExec() []byte {
info := w.args[arg]
info.Idx = instrSeq
instrSeq++
w.write(instrCopyout)
w.write(ExecInstrCopyout)
w.write(physicalAddr(base) + info.Offset)
w.write(arg.Size(arg.Type))
default:
@ -116,7 +120,7 @@ func (p *Prog) SerializeForExec() []byte {
}
})
}
w.write(instrEOF)
w.write(ExecInstrEOF)
return w.buf
}
@ -151,25 +155,25 @@ func (w *execContext) write(v uintptr) {
func (w *execContext) writeArg(arg *Arg) {
switch arg.Kind {
case ArgConst:
w.write(execArgConst)
w.write(ExecArgConst)
w.write(arg.Size(arg.Type))
w.write(arg.Val)
case ArgResult:
w.write(execArgResult)
w.write(ExecArgResult)
w.write(arg.Size(arg.Type))
w.write(w.args[arg.Res].Idx)
w.write(arg.OpDiv)
w.write(arg.OpAdd)
case ArgPointer:
w.write(execArgConst)
w.write(ExecArgConst)
w.write(arg.Size(arg.Type))
w.write(physicalAddr(arg))
case ArgPageSize:
w.write(execArgConst)
w.write(ExecArgConst)
w.write(arg.Size(arg.Type))
w.write(arg.AddrPage * pageSize)
case ArgData:
w.write(execArgData)
w.write(ExecArgData)
w.write(uintptr(len(arg.Data)))
for i := 0; i < len(arg.Data); i += 8 {
var v uintptr

View File

@ -54,11 +54,3 @@ func TestSerializeForExec(t *testing.T) {
p.SerializeForExec()
}
}
func TestSerializeC(t *testing.T) {
rs, iters := initTest(t)
for i := 0; i < iters; i++ {
p := Generate(rs, 10, nil)
p.WriteCSource()
}
}

View File

@ -4,19 +4,27 @@
package main
import (
"flag"
"fmt"
"io/ioutil"
"os"
"github.com/google/syzkaller/csource"
"github.com/google/syzkaller/prog"
)
var (
flagThreaded = flag.Bool("threaded", false, "create threaded program")
flagCollide = flag.Bool("collide", false, "create collide program")
)
func main() {
if len(os.Args) != 2 {
fmt.Fprintf(os.Stderr, "usage: prog2c prog_file\n")
flag.Parse()
if len(flag.Args()) != 1 {
fmt.Fprintf(os.Stderr, "usage: prog2c [-threaded [-collide]] prog_file\n")
os.Exit(1)
}
data, err := ioutil.ReadFile(os.Args[1])
data, err := ioutil.ReadFile(flag.Args()[0])
if err != nil {
fmt.Fprintf(os.Stderr, "failed to read prog file: %v\n", err)
os.Exit(1)
@ -26,6 +34,10 @@ func main() {
fmt.Fprintf(os.Stderr, "failed to deserialize the program: %v\n", err)
os.Exit(1)
}
src := p.WriteCSource()
opts := csource.Options{
Threaded: *flagThreaded,
Collide: *flagCollide,
}
src := csource.Write(p, opts)
os.Stdout.Write(src)
}