diff --git a/pkg/osutil/osutil.go b/pkg/osutil/osutil.go index cd6a1cce..b50b6e45 100644 --- a/pkg/osutil/osutil.go +++ b/pkg/osutil/osutil.go @@ -130,11 +130,12 @@ func IsAccessible(name string) error { // FilesExist returns true if all files exist in dir. // Files are assumed to be relative names in slash notation. func FilesExist(dir string, files map[string]bool) bool { - for f, required := range files { + for pattern, required := range files { if !required { continue } - if !IsExist(filepath.Join(dir, filepath.FromSlash(f))) { + files, err := filepath.Glob(filepath.Join(dir, filepath.FromSlash(pattern))) + if err != nil || len(files) == 0 { return false } } @@ -154,7 +155,18 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error { if err := MkdirAll(tmpDir); err != nil { return err } + if err := foreachPatternFile(srcDir, tmpDir, files, CopyFile); err != nil { + return err + } + if err := os.RemoveAll(dstDir); err != nil { + return err + } + return os.Rename(tmpDir, dstDir) +} + +func foreachPatternFile(srcDir, dstDir string, files map[string]bool, fn func(src, dst string) error) error { srcDir = filepath.Clean(srcDir) + dstDir = filepath.Clean(dstDir) for pattern, required := range files { files, err := filepath.Glob(filepath.Join(srcDir, filepath.FromSlash(pattern))) if err != nil { @@ -170,19 +182,16 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error { if !strings.HasPrefix(file, srcDir) { return fmt.Errorf("file %q matched from %q in %q doesn't have src prefix", file, pattern, srcDir) } - dst := filepath.Join(tmpDir, strings.TrimPrefix(file, srcDir)) + dst := filepath.Join(dstDir, strings.TrimPrefix(file, srcDir)) if err := MkdirAll(filepath.Dir(dst)); err != nil { return err } - if err := CopyFile(file, dst); err != nil { + if err := fn(file, dst); err != nil { return err } } } - if err := os.RemoveAll(dstDir); err != nil { - return err - } - return os.Rename(tmpDir, dstDir) + return nil } func CopyDirRecursively(srcDir, dstDir string) error { @@ -219,20 +228,7 @@ func LinkFiles(srcDir, dstDir string, files map[string]bool) error { if err := MkdirAll(dstDir); err != nil { return err } - for f, required := range files { - src := filepath.Join(srcDir, filepath.FromSlash(f)) - if !required && !IsExist(src) { - continue - } - dst := filepath.Join(dstDir, filepath.FromSlash(f)) - if err := MkdirAll(filepath.Dir(dst)); err != nil { - return err - } - if err := os.Link(src, dst); err != nil { - return err - } - } - return nil + return foreachPatternFile(srcDir, dstDir, files, os.Link) } func MkdirAll(dir string) error { diff --git a/pkg/osutil/osutil_test.go b/pkg/osutil/osutil_test.go index 3d367a0c..168e7ba3 100644 --- a/pkg/osutil/osutil_test.go +++ b/pkg/osutil/osutil_test.go @@ -73,42 +73,53 @@ func TestCopyFiles(t *testing.T) { }, }, } - for i, test := range tests { - t.Run(fmt.Sprint(i), func(t *testing.T) { - dir, err := ioutil.TempDir("", "syz-osutil-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) - src := filepath.Join(dir, "src") - dst := filepath.Join(dir, "dst") - for _, file := range test.files { - file = filepath.Join(src, filepath.FromSlash(file)) - if err := MkdirAll(filepath.Dir(file)); err != nil { - t.Fatal(err) - } - if err := WriteFile(file, []byte{'a'}); err != nil { - t.Fatal(err) - } - } - if err := CopyFiles(src, dst, test.patterns); err != nil { - if test.err != "" { - if strings.Contains(err.Error(), test.err) { - return + for _, link := range []bool{false, true} { + fn, fnName := CopyFiles, "CopyFiles" + if link { + fn, fnName = LinkFiles, "LinkFiles" + } + t.Run(fnName, func(t *testing.T) { + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + dir, err := ioutil.TempDir("", "syz-osutil-test") + if err != nil { + t.Fatal(err) } - t.Fatalf("got err %q, want %q", err, test.err) - } - t.Fatal(err) - } else if test.err != "" { - t.Fatalf("got no err, want %q", test.err) - } - if err := os.RemoveAll(src); err != nil { - t.Fatal(err) - } - for _, file := range test.files { - if !IsExist(filepath.Join(dst, filepath.FromSlash(file))) { - t.Fatalf("%v does not exist in dst", file) - } + defer os.RemoveAll(dir) + src := filepath.Join(dir, "src") + dst := filepath.Join(dir, "dst") + for _, file := range test.files { + file = filepath.Join(src, filepath.FromSlash(file)) + if err := MkdirAll(filepath.Dir(file)); err != nil { + t.Fatal(err) + } + if err := WriteFile(file, []byte{'a'}); err != nil { + t.Fatal(err) + } + } + if err := fn(src, dst, test.patterns); err != nil { + if test.err != "" { + if strings.Contains(err.Error(), test.err) { + return + } + t.Fatalf("got err %q, want %q", err, test.err) + } + t.Fatal(err) + } else if test.err != "" { + t.Fatalf("got no err, want %q", test.err) + } + if err := os.RemoveAll(src); err != nil { + t.Fatal(err) + } + for _, file := range test.files { + if !IsExist(filepath.Join(dst, filepath.FromSlash(file))) { + t.Fatalf("%v does not exist in dst", file) + } + } + if !FilesExist(dst, test.patterns) { + t.Fatalf("dst files don't exist after copy") + } + }) } }) }