syzkaller/pkg/csource/csource.go
Dmitry Vyukov 306ca0571c prog, pkg/compiler: support fmt type
fmt type allows converting integers and resources
to their string representation.
2018-07-08 22:52:24 +02:00

545 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2015 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// Package csource generates [almost] equivalent C programs from syzkaller programs.
package csource
import (
	"bytes"
	"fmt"
	"regexp"
	"sort"
	"strings"

	"github.com/google/syzkaller/prog"
	"github.com/google/syzkaller/sys/targets"
)
// Write generates a C program that is [almost] equivalent to the syzkaller
// program p, configured by opts, and returns its source text.
//
// An error is returned when opts fails opts.Check for the program's target OS,
// when either p or the target's "uber mmap" program cannot be converted into
// C calls, or when createCommonHeader fails.
//
// The generated file is laid out as: banner comment, common header, syscall
// number defines, the optional r[] result array, the optional procid global,
// the test function(s), and finally main(). The text is then post-processed
// to strip helper macros that are not wanted for the given options.
func Write(p *prog.Prog, opts Options) ([]byte, error) {
	if err := opts.Check(p.Target.OS); err != nil {
		return nil, fmt.Errorf("csource: invalid opts: %v", err)
	}
	ctx := &context{
		p: p,
		opts: opts,
		target: p.Target,
		sysTarget: targets.Get(p.Target.OS, p.Target.Arch),
		w: new(bytes.Buffer),
		calls: make(map[string]uint64),
	}
	// C statements for the program itself, plus initial values for r[].
	calls, vars, err := ctx.generateProgCalls(ctx.p)
	if err != nil {
		return nil, err
	}
	// C statements for the target-provided mmap program; they are emitted at
	// the start of main() below. Its vars are not needed.
	mmapProg := p.Target.GenerateUberMmapProg()
	mmapCalls, _, err := ctx.generateProgCalls(mmapProg)
	if err != nil {
		return nil, err
	}
	// Remember the syscall number of every call used by either program;
	// generateSyscallDefines reads this map.
	for _, c := range append(mmapProg.Calls, p.Calls...) {
		ctx.calls[c.Meta.CallName] = c.Meta.NR
	}
	ctx.print("// autogenerated by syzkaller (http://github.com/google/syzkaller)\n\n")
	hdr, err := createCommonHeader(p, mmapProg, opts)
	if err != nil {
		return nil, err
	}
	ctx.w.Write(hdr)
	ctx.print("\n")
	ctx.generateSyscallDefines()
	// Global array holding call results, pre-filled with the initial values.
	if len(vars) != 0 {
		ctx.printf("uint64_t r[%v] = {", len(vars))
		for i, v := range vars {
			if i != 0 {
				ctx.printf(", ")
			}
			ctx.printf("0x%x", v)
		}
		ctx.printf("};\n")
	}
	// procid is declared for multi-process runs, when cgroups are enabled, or
	// when the program uses one of the two pseudo-syscalls listed below.
	needProcID := opts.Procs > 1 || opts.EnableCgroups
	for _, c := range p.Calls {
		if c.Meta.CallName == "syz_mount_image" ||
			c.Meta.CallName == "syz_read_part_table" {
			needProcID = true
		}
	}
	if needProcID {
		ctx.printf("unsigned long long procid;\n")
	}
	if !opts.Repeat {
		// Single-shot mode: the test function itself is named "loop" and
		// main() runs the setup once and returns.
		ctx.generateTestFunc(calls, len(vars) != 0, "loop")
		ctx.print("int main()\n{\n")
		for _, c := range mmapCalls {
			ctx.printf("%s", c)
		}
		if opts.HandleSegv {
			ctx.printf("\tinstall_segv_handler();\n")
		}
		if opts.UseTmpDir {
			ctx.printf("\tuse_temporary_dir();\n")
		}
		ctx.writeLoopCall()
		ctx.print("\treturn 0;\n}\n")
	} else {
		// Repeat mode: the test function is named "execute_one".
		// NOTE(review): loop() is presumably supplied by the common header in
		// this mode and calls execute_one — confirm against the header.
		ctx.generateTestFunc(calls, len(vars) != 0, "execute_one")
		if opts.Procs <= 1 {
			// Single process: main() repeats the test forever.
			ctx.print("int main()\n{\n")
			for _, c := range mmapCalls {
				ctx.printf("%s", c)
			}
			if opts.HandleSegv {
				ctx.print("\tinstall_segv_handler();\n")
			}
			if opts.UseTmpDir {
				ctx.print("\tchar *cwd = get_current_dir_name();\n")
			}
			ctx.print("\tfor (;;) {\n")
			if opts.UseTmpDir {
				// Return to the starting directory before creating a fresh
				// temporary directory for each iteration.
				ctx.print("\t\tif (chdir(cwd))\n")
				ctx.print("\t\t\tfail(\"failed to chdir\");\n")
				ctx.print("\t\tuse_temporary_dir();\n")
			}
			ctx.writeLoopCall()
			ctx.print("\t}\n}\n")
		} else {
			// Multiple processes: main() forks opts.Procs children, each of
			// which repeats the test forever, and then sleeps.
			ctx.print("int main()\n{\n")
			for _, c := range mmapCalls {
				ctx.printf("%s", c)
			}
			if opts.UseTmpDir {
				ctx.print("\tchar *cwd = get_current_dir_name();\n")
			}
			ctx.printf("\tfor (procid = 0; procid < %v; procid++) {\n", opts.Procs)
			ctx.print("\t\tif (fork() == 0) {\n")
			if opts.HandleSegv {
				ctx.print("\t\t\tinstall_segv_handler();\n")
			}
			ctx.print("\t\t\tfor (;;) {\n")
			if opts.UseTmpDir {
				ctx.print("\t\t\t\tif (chdir(cwd))\n")
				ctx.print("\t\t\t\t\tfail(\"failed to chdir\");\n")
				ctx.print("\t\t\t\tuse_temporary_dir();\n")
			}
			ctx.writeLoopCall()
			ctx.print("\t\t\t}\n")
			ctx.print("\t\t}\n")
			ctx.print("\t}\n")
			ctx.print("\tsleep(1000000);\n")
			ctx.print("\treturn 0;\n}\n")
		}
	}
	// Post-process the generated text: unwrap NONFAILING(...) when segv
	// handling is off, and drop debug calls when debugging is off.
	result := ctx.w.Bytes()
	if !opts.HandleSegv {
		// Replaces "NONFAILING(x);" with "x;". The leading tabs are part of
		// the match and are not restored by the replacement.
		re := regexp.MustCompile(`\t*NONFAILING\((.*)\);\n`)
		result = re.ReplaceAll(result, []byte("$1;\n"))
	}
	if !opts.Debug {
		// The patterns allow a debug call to span several lines.
		re := regexp.MustCompile(`\t*debug\((.*\n)*?.*\);\n`)
		result = re.ReplaceAll(result, nil)
		re = regexp.MustCompile(`\t*debug_dump_data\((.*\n)*?.*\);\n`)
		result = re.ReplaceAll(result, nil)
	}
	// Every occurrence of these two identifiers is deleted from the output.
	result = bytes.Replace(result, []byte("NORETURN"), nil, -1)
	result = bytes.Replace(result, []byte("PRINTF"), nil, -1)
	// Remove duplicate new lines: collapse runs of blank lines and drop blank
	// lines in front of #include. Repeat until the length stops changing.
	for {
		result1 := bytes.Replace(result, []byte{'\n', '\n', '\n'}, []byte{'\n', '\n'}, -1)
		result1 = bytes.Replace(result1, []byte("\n\n#include"), []byte("\n#include"), -1)
		if len(result1) == len(result) {
			break
		}
		result = result1
	}
	return result, nil
}
// context carries the state shared by the code-generation helpers while a
// single program is being converted to C by Write.
type context struct {
	// p is the program being converted.
	p *prog.Prog
	// opts holds the generation options that were passed to Write.
	opts Options
	// target is the program's target (p.Target).
	target *prog.Target
	// sysTarget is the result of targets.Get for the program's OS and arch.
	sysTarget *targets.Target
	// w accumulates the generated C source text.
	w *bytes.Buffer
	// calls records the syscall number of every call used by the program.
	calls map[string]uint64 // CallName -> NR
}
// print appends str verbatim to the generated source buffer.
func (ctx *context) print(str string) {
	// Writes to a bytes.Buffer never report an error, so the results are dropped.
	_, _ = ctx.w.WriteString(str)
}
// printf formats its arguments according to str (fmt.Sprintf rules) and
// appends the result to the generated source buffer.
func (ctx *context) printf(str string, args ...interface{}) {
	formatted := fmt.Sprintf(str, args...)
	ctx.print(formatted)
}
// writeLoopCall emits the C statement(s) that start the test.
//
// When a sandbox is configured, only the call to the matching do_sandbox_*()
// function is emitted. Otherwise the optional tun/netdevice initialization
// calls are emitted, followed by a direct call to loop().
func (ctx *context) writeLoopCall() {
	if sandbox := ctx.opts.Sandbox; sandbox != "" {
		ctx.printf("\tdo_sandbox_%v();\n", sandbox)
		return
	}

	opts := ctx.opts
	if opts.EnableTun {
		ctx.printf("\tinitialize_tun();\n")
	}
	if opts.EnableNetdev {
		ctx.printf("\tinitialize_netdevices();\n")
	}
	ctx.print("\tloop();\n")
}
// generateTestFunc emits the C function called name, which executes the
// program's calls.
//
// calls holds one C snippet per program call, hasVars says whether the
// snippets use the local "res" variable, and name is the name of the emitted
// entry function.
//
// In the plain (non-threaded, non-collide) mode the snippets are emitted
// inline in name(). Otherwise an execute_call(int call) dispatcher with one
// switch case per snippet is emitted first, and name() calls
// execute(<number of calls>), twice with "collide = 1" in between when
// collide mode is enabled.
func (ctx *context) generateTestFunc(calls []string, hasVars bool, name string) {
	opts := ctx.opts

	// prologue emits the statements that open the body of name() in both modes.
	prologue := func() {
		if opts.Debug {
			// Use debug to avoid: error: debug defined but not used.
			ctx.printf("\tdebug(\"%v\\n\");\n", name)
		}
		if opts.Repro {
			ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n")
		}
	}

	if !opts.Threaded && !opts.Collide {
		ctx.printf("void %v()\n{\n", name)
		if hasVars {
			ctx.printf("\tlong res = 0;\n")
		}
		prologue()
		for _, c := range calls {
			ctx.printf("%s", c)
		}
		ctx.printf("}\n\n")
		return
	}

	ctx.printf("void execute_call(int call)\n{\n")
	if hasVars {
		// Initialize res and end the line, as in the plain mode above: a
		// snippet may test res without assigning it when its call is not
		// emitted, and that must not read an indeterminate value.
		ctx.printf("\tlong res = 0;\n")
	}
	ctx.printf("\tswitch (call) {\n")
	for i, c := range calls {
		ctx.printf("\tcase %v:\n", i)
		// Snippets are generated with one level of indentation; double every
		// tab so they sit inside the case body.
		ctx.printf("%s", strings.Replace(c, "\t", "\t\t", -1))
		ctx.printf("\t\tbreak;\n")
	}
	ctx.printf("\t}\n")
	ctx.printf("}\n\n")

	ctx.printf("void %v()\n{\n", name)
	prologue()
	ctx.printf("\texecute(%v);\n", len(calls))
	if opts.Collide {
		ctx.printf("\tcollide = 1;\n")
		ctx.printf("\texecute(%v);\n", len(calls))
	}
	ctx.printf("}\n\n")
}
// generateSyscallDefines emits "#ifndef/#define/#endif" fallbacks for the
// syscall numbers of the calls recorded in ctx.calls, followed by a blank
// line.
//
// Nothing is defined when the target does not use syscall numbers. Calls whose
// name starts with "syz_" and calls for which the target reports that no
// define is needed are skipped. The defines are emitted in sorted name order
// so that the generated source is deterministic; Go map iteration order is
// not.
func (ctx *context) generateSyscallDefines() {
	if ctx.sysTarget.SyscallNumbers {
		names := make([]string, 0, len(ctx.calls))
		for name, nr := range ctx.calls {
			if strings.HasPrefix(name, "syz_") || !ctx.sysTarget.NeedSyscallDefine(nr) {
				continue
			}
			names = append(names, name)
		}
		sort.Strings(names)

		prefix := ctx.sysTarget.SyscallPrefix
		for _, name := range names {
			ctx.printf("#ifndef %v%v\n", prefix, name)
			ctx.printf("#define %v%v %v\n", prefix, name, ctx.calls[name])
			ctx.printf("#endif\n")
		}
	}
	if ctx.target.OS == "linux" && ctx.target.PtrSize == 4 {
		// This is a dirty hack.
		// On 32-bit linux mmap translated to old_mmap syscall which has a different signature.
		// mmap2 has the right signature. syz-extract translates mmap to mmap2, do the same here.
		ctx.printf("#undef __NR_mmap\n")
		ctx.printf("#define __NR_mmap __NR_mmap2\n")
	}
	ctx.printf("\n")
}
// generateProgCalls converts p into C call snippets.
//
// The program is serialized into the exec format, decoded again with the
// target's DeserializeExec, and the decoded form is handed to generateCalls.
// It returns the per-call C snippets and the values produced alongside them
// by generateCalls, or an error if serialization or decoding fails.
func (ctx *context) generateProgCalls(p *prog.Prog) ([]string, []uint64, error) {
	buf := make([]byte, prog.ExecBufferSize)
	n, err := p.SerializeForExec(buf)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to serialize program: %v", err)
	}

	execProg, err := ctx.target.DeserializeExec(buf[:n])
	if err != nil {
		return nil, nil, err
	}

	snippets, vars := ctx.generateCalls(execProg)
	return snippets, vars, nil
}
// generateCalls renders every call of the decoded exec program p as a snippet
// of C code. It returns the snippets (one string per call, in program order)
// together with p.Vars.
//
// Each snippet consists of, in order: the copyin statements for the call,
// optional fault-injection setup (when ctx.opts.Fault is set and
// ctx.opts.FaultCall equals the call's index), the call expression itself,
// and the copyout statements that store results into the r[] array.
func (ctx *context) generateCalls(p prog.ExecProg) ([]string, []uint64) {
	var calls []string
	// csumSeq is shared by all calls; copyin increments it for every checksum
	// so each generated csum_N variable gets its own number.
	csumSeq := 0
	for ci, call := range p.Calls {
		w := new(bytes.Buffer)
		// Copyin.
		for _, copyin := range call.Copyin {
			ctx.copyin(w, &csumSeq, copyin)
		}
		// Call itself.
		if ctx.opts.Fault && ctx.opts.FaultCall == ci {
			fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/failslab/ignore-gfp-wait\", \"N\");\n")
			fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/fail_futex/ignore-private\", \"N\");\n")
			fmt.Fprintf(w, "\tinject_fault(%v);\n", ctx.opts.FaultNth)
		}
		callName := call.Meta.CallName
		// resCopyout: the call's return value is stored into r[call.Index].
		// argCopyout: values are read back from memory after the call.
		resCopyout := call.Index != prog.ExecNoCopyout
		argCopyout := len(call.Copyout) != 0
		// The two tun-related pseudo-syscalls are emitted only when tun is enabled.
		emitCall := ctx.opts.EnableTun || callName != "syz_emit_ethernet" &&
			callName != "syz_extract_tcp_res"
		// TODO: if we don't emit the call we must also not emit copyin, copyout and fault injection.
		// However, simply skipping whole iteration breaks tests due to unused static functions.
		if emitCall {
			// native: invoke via syscall(<prefix><name>, ...); the syscall
			// number is then the first argument.
			native := ctx.sysTarget.SyscallNumbers && !strings.HasPrefix(callName, "syz_")
			fmt.Fprintf(w, "\t")
			if resCopyout || argCopyout {
				fmt.Fprintf(w, "res = ")
			}
			if native {
				fmt.Fprintf(w, "syscall(%v%v", ctx.sysTarget.SyscallPrefix, callName)
			} else if strings.HasPrefix(callName, "syz_") {
				// Pseudo-syscalls are called directly as C functions.
				fmt.Fprintf(w, "%v(", callName)
			} else {
				// No syscall numbers on this target: call the function by
				// name through a cast to long(*)(long, ..., long) with one
				// long per argument.
				args := strings.Repeat(",long", len(call.Args))
				if args != "" {
					// Drop the leading comma.
					args = args[1:]
				}
				fmt.Fprintf(w, "((long(*)(%v))%v)(", args, callName)
			}
			for ai, arg := range call.Args {
				// In the native form the syscall number precedes the first
				// argument, so a separator is needed even for ai == 0.
				if native || ai > 0 {
					fmt.Fprintf(w, ", ")
				}
				switch arg := arg.(type) {
				case prog.ExecArgConst:
					// NOTE(review): "sring" in the panic text looks like a typo for "string".
					if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian {
						panic("sring format in syscall argument")
					}
					fmt.Fprintf(w, "%v", ctx.constArgToStr(arg))
				case prog.ExecArgResult:
					if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian {
						panic("sring format in syscall argument")
					}
					fmt.Fprintf(w, "%v", ctx.resultArgToStr(arg))
				default:
					panic(fmt.Sprintf("unknown arg type: %+v", arg))
				}
			}
			fmt.Fprintf(w, ");\n")
		}
		// Copyout.
		if resCopyout || argCopyout {
			if ctx.sysTarget.OS == "fuchsia" {
				// On fuchsia we have real system calls that return ZX_OK on success,
				// and libc calls that are cast to a function returning long,
				// as a result int -1 is returned as 0x00000000ffffffff rather than full -1.
				if strings.HasPrefix(callName, "zx_") {
					fmt.Fprintf(w, "\tif (res == ZX_OK)")
				} else {
					fmt.Fprintf(w, "\tif ((int)res != -1)")
				}
			} else {
				fmt.Fprintf(w, "\tif (res != -1)")
			}
			// Braces are needed only when the if body has more than one statement.
			copyoutMultiple := len(call.Copyout) > 1 || resCopyout && len(call.Copyout) > 0
			if copyoutMultiple {
				fmt.Fprintf(w, " {")
			}
			fmt.Fprintf(w, "\n")
			if resCopyout {
				fmt.Fprintf(w, "\t\tr[%v] = res;\n", call.Index)
			}
			for _, copyout := range call.Copyout {
				fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v_t*)0x%x);\n",
					copyout.Index, copyout.Size*8, copyout.Addr)
			}
			if copyoutMultiple {
				fmt.Fprintf(w, "\t}\n")
			}
		}
		calls = append(calls, w.String())
	}
	return calls, p.Vars
}
// generateCsumInet writes to w the C statements that compute an inet checksum
// and store it as a uint16_t at address addr.
//
// The emitted code declares and initializes a "struct csum_inet csum_<csumSeq>"
// accumulator, feeds it each chunk of arg.Chunks in order (data chunks are
// read from memory, constant chunks go through a temporary variable), and
// finally writes the digest to addr. It panics on an unknown chunk kind.
func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.ExecArgCsum, csumSeq int) {
	// Accumulator declaration and initialization.
	fmt.Fprintf(w, "\tstruct csum_inet csum_%d;\n", csumSeq)
	fmt.Fprintf(w, "\tcsum_inet_init(&csum_%d);\n", csumSeq)

	for idx := range arg.Chunks {
		c := arg.Chunks[idx]
		if c.Kind == prog.ExecArgCsumChunkData {
			// Data chunk: c.Value is the address of the bytes to checksum.
			fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8_t*)0x%x, %d));\n",
				csumSeq, c.Value, c.Size)
			continue
		}
		if c.Kind == prog.ExecArgCsumChunkConst {
			// Constant chunk: materialize the value in a local and checksum that.
			fmt.Fprintf(w, "\tuint%d_t csum_%d_chunk_%d = 0x%x;\n",
				c.Size*8, csumSeq, idx, c.Value)
			fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8_t*)&csum_%d_chunk_%d, %d);\n",
				csumSeq, csumSeq, idx, c.Size)
			continue
		}
		panic(fmt.Sprintf("unknown checksum chunk kind %v", c.Kind))
	}

	// Store the final digest.
	fmt.Fprintf(w, "\tNONFAILING(*(uint16_t*)0x%x = csum_inet_digest(&csum_%d));\n",
		addr, csumSeq)
}
// copyin writes to w the C statement(s) that place one argument into memory
// at copyin.Addr before a call.
//
// Constants and results are stored through copyinVal; constants with a
// non-zero bitfield offset or length use STORE_BY_BITMASK instead. Raw data
// is copied with memcpy from a C string literal. Checksums increment *csumSeq
// and are delegated to generateCsumInet. It panics on an unknown argument
// type, an unknown checksum kind, or a bitfield combined with a string format.
func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin) {
	addr := copyin.Addr
	switch a := copyin.Arg.(type) {
	case prog.ExecArgConst:
		if a.BitfieldOffset != 0 || a.BitfieldLength != 0 {
			if a.Format != prog.FormatNative && a.Format != prog.FormatBigEndian {
				panic("bitfield+string format")
			}
			fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v_t, 0x%x, %v, %v, %v));\n",
				a.Size*8, addr, ctx.constArgToStr(a),
				a.BitfieldOffset, a.BitfieldLength)
		} else {
			ctx.copyinVal(w, addr, a.Size, ctx.constArgToStr(a), a.Format)
		}
	case prog.ExecArgResult:
		ctx.copyinVal(w, addr, a.Size, ctx.resultArgToStr(a), a.Format)
	case prog.ExecArgData:
		fmt.Fprintf(w, "\tNONFAILING(memcpy((void*)0x%x, \"%s\", %v));\n",
			addr, toCString(a.Data), len(a.Data))
	case prog.ExecArgCsum:
		if a.Kind != prog.ExecArgCsumInet {
			panic(fmt.Sprintf("unknown csum kind %v", a.Kind))
		}
		// Give every checksum its own sequence number.
		*csumSeq++
		ctx.generateCsumInet(w, addr, a, *csumSeq)
	default:
		panic(fmt.Sprintf("bad argument type: %+v", a))
	}
}
// copyinVal writes to w the C statement that stores the expression val at
// address addr, encoded according to bf.
//
// For the native and big-endian formats the value is assigned through a
// uint<size*8>_t pointer. For the string formats the value is rendered with
// sprintf as zero-padded text: decimal (size must be 20), hexadecimal with a
// "0x" prefix (size must be 18), or octal (size must be 23).
//
// It panics if size does not match the string format or if bf is unknown.
func (ctx *context) copyinVal(w *bytes.Buffer, addr, size uint64, val string, bf prog.BinaryFormat) {
	switch bf {
	case prog.FormatNative, prog.FormatBigEndian:
		fmt.Fprintf(w, "\tNONFAILING(*(uint%v_t*)0x%x = %v);\n", size*8, addr, val)
	case prog.FormatStrDec:
		if size != 20 {
			panic("bad strdec size")
		}
		fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%020llu\", (long long)%v));\n", addr, val)
	case prog.FormatStrHex:
		if size != 18 {
			panic("bad strhex size")
		}
		fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"0x%%016llx\", (long long)%v));\n", addr, val)
	case prog.FormatStrOct:
		if size != 23 {
			panic("bad stroct size")
		}
		fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%023llo\", (long long)%v));\n", addr, val)
	default:
		panic("unknown binary format")
	}
}
// constArgToStr renders a constant argument as a C expression.
//
// The value is truncated to arg.Size bytes. An all-ones value is printed as
// "-1", values of 10 and above in hexadecimal, and smaller values in decimal.
// When more than one process is configured and the argument has a non-zero
// PidStride, " + procid*<stride>" is appended. Big-endian arguments are
// wrapped in htobe<bits>().
func (ctx *context) constArgToStr(arg prog.ExecArgConst) string {
	bits := arg.Size * 8
	mask := (uint64(1) << bits) - 1
	v := arg.Value & mask

	var val string
	switch {
	case v == mask:
		// Every bit of the (truncated) value is set.
		val = "-1"
	case v >= 10:
		val = fmt.Sprintf("0x%x", v)
	default:
		val = fmt.Sprintf("%v", v)
	}

	if arg.PidStride != 0 && ctx.opts.Procs > 1 {
		val += fmt.Sprintf(" + procid*%v", arg.PidStride)
	}
	if arg.Format == prog.FormatBigEndian {
		val = fmt.Sprintf("htobe%v(%v)", bits, val)
	}
	return val
}
// resultArgToStr renders a reference to a previous call's result as a C
// expression: "r[<index>]", optionally followed by "/<DivOp>" and then
// "+<AddOp>" when those are non-zero, and wrapped in htobe<bits>() for
// big-endian arguments.
func (ctx *context) resultArgToStr(arg prog.ExecArgResult) string {
	expr := fmt.Sprintf("r[%v]", arg.Index)
	if div := arg.DivOp; div != 0 {
		expr = fmt.Sprintf("%v/%v", expr, div)
	}
	if add := arg.AddOp; add != 0 {
		expr = fmt.Sprintf("%v+%v", expr, add)
	}
	if arg.Format != prog.FormatBigEndian {
		return expr
	}
	return fmt.Sprintf("htobe%v(%v)", arg.Size*8, expr)
}
// toCString converts data into the contents of a C string literal (without
// the surrounding quotes). Empty input yields nil.
//
// If every byte is readable (see isReadable), the text is emitted as is with
// tab, CR, LF, backslash and double quote escaped. A single trailing zero
// byte is tolerated and dropped, since the C literal is zero-terminated
// anyway. Otherwise every byte is emitted as a "\xNN" hex escape.
func toCString(data []byte) []byte {
	n := len(data)
	if n == 0 {
		return nil
	}

	// Decide between the textual and the all-hex representation.
	hexOnly := false
	for i, b := range data {
		if isReadable(b) {
			continue
		}
		// Allow 0 only as last byte.
		if b == 0 && i == n-1 {
			continue
		}
		hexOnly = true
		break
	}

	var out bytes.Buffer
	if hexOnly {
		for _, b := range data {
			out.WriteByte('\\')
			out.WriteByte('x')
			out.WriteByte(toHex(b >> 4))
			out.WriteByte(toHex(b & 0x0f))
		}
		return out.Bytes()
	}

	// Don't serialize last 0, C strings are 0-terminated anyway.
	text := data
	if text[n-1] == 0 {
		text = text[:n-1]
	}
	for _, b := range text {
		var esc byte
		switch b {
		case '\t':
			esc = 't'
		case '\r':
			esc = 'r'
		case '\n':
			esc = 'n'
		case '\\', '"':
			esc = b
		default:
			if b < 0x20 || b >= 0x7f {
				panic("unexpected char during data serialization")
			}
			out.WriteByte(b)
			continue
		}
		out.WriteByte('\\')
		out.WriteByte(esc)
	}
	return out.Bytes()
}

// isReadable reports whether v is a printable ASCII byte (0x20..0x7e) or one
// of tab, carriage return and line feed.
func isReadable(v byte) bool {
	switch v {
	case '\t', '\r', '\n':
		return true
	}
	return 0x20 <= v && v < 0x7f
}

// toHex returns the lowercase hexadecimal digit for the nibble value v.
func toHex(v byte) byte {
	base := byte('0')
	if v >= 10 {
		base = 'a' - 10
	}
	return base + v
}