pkg/osutil: fix LinkFiles/FilesExist for the new pattern format

This commit is contained in:
Dmitry Vyukov 2020-09-15 09:22:26 +02:00
parent a2360d0742
commit 9e681632f5
2 changed files with 64 additions and 57 deletions

View File

@ -130,11 +130,12 @@ func IsAccessible(name string) error {
// FilesExist returns true if all files exist in dir.
// Files are assumed to be relative names in slash notation.
func FilesExist(dir string, files map[string]bool) bool {
for f, required := range files {
for pattern, required := range files {
if !required {
continue
}
if !IsExist(filepath.Join(dir, filepath.FromSlash(f))) {
files, err := filepath.Glob(filepath.Join(dir, filepath.FromSlash(pattern)))
if err != nil || len(files) == 0 {
return false
}
}
@ -154,7 +155,18 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error {
if err := MkdirAll(tmpDir); err != nil {
return err
}
if err := foreachPatternFile(srcDir, tmpDir, files, CopyFile); err != nil {
return err
}
if err := os.RemoveAll(dstDir); err != nil {
return err
}
return os.Rename(tmpDir, dstDir)
}
func foreachPatternFile(srcDir, dstDir string, files map[string]bool, fn func(src, dst string) error) error {
srcDir = filepath.Clean(srcDir)
dstDir = filepath.Clean(dstDir)
for pattern, required := range files {
files, err := filepath.Glob(filepath.Join(srcDir, filepath.FromSlash(pattern)))
if err != nil {
@ -170,19 +182,16 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error {
if !strings.HasPrefix(file, srcDir) {
return fmt.Errorf("file %q matched from %q in %q doesn't have src prefix", file, pattern, srcDir)
}
dst := filepath.Join(tmpDir, strings.TrimPrefix(file, srcDir))
dst := filepath.Join(dstDir, strings.TrimPrefix(file, srcDir))
if err := MkdirAll(filepath.Dir(dst)); err != nil {
return err
}
if err := CopyFile(file, dst); err != nil {
if err := fn(file, dst); err != nil {
return err
}
}
}
if err := os.RemoveAll(dstDir); err != nil {
return err
}
return os.Rename(tmpDir, dstDir)
return nil
}
func CopyDirRecursively(srcDir, dstDir string) error {
@ -219,20 +228,7 @@ func LinkFiles(srcDir, dstDir string, files map[string]bool) error {
if err := MkdirAll(dstDir); err != nil {
return err
}
for f, required := range files {
src := filepath.Join(srcDir, filepath.FromSlash(f))
if !required && !IsExist(src) {
continue
}
dst := filepath.Join(dstDir, filepath.FromSlash(f))
if err := MkdirAll(filepath.Dir(dst)); err != nil {
return err
}
if err := os.Link(src, dst); err != nil {
return err
}
}
return nil
return foreachPatternFile(srcDir, dstDir, files, os.Link)
}
func MkdirAll(dir string) error {

View File

@ -73,6 +73,12 @@ func TestCopyFiles(t *testing.T) {
},
},
}
for _, link := range []bool{false, true} {
fn, fnName := CopyFiles, "CopyFiles"
if link {
fn, fnName = LinkFiles, "LinkFiles"
}
t.Run(fnName, func(t *testing.T) {
for i, test := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
dir, err := ioutil.TempDir("", "syz-osutil-test")
@ -91,7 +97,7 @@ func TestCopyFiles(t *testing.T) {
t.Fatal(err)
}
}
if err := CopyFiles(src, dst, test.patterns); err != nil {
if err := fn(src, dst, test.patterns); err != nil {
if test.err != "" {
if strings.Contains(err.Error(), test.err) {
return
@ -110,6 +116,11 @@ func TestCopyFiles(t *testing.T) {
t.Fatalf("%v does not exist in dst", file)
}
}
if !FilesExist(dst, test.patterns) {
t.Fatalf("dst files don't exist after copy")
}
})
}
})
}
}