diff --git a/core/base_backup.go b/core/base_backup.go index 1836d30f..aca7c6af 100644 --- a/core/base_backup.go +++ b/core/base_backup.go @@ -17,6 +17,7 @@ import ( "github.com/pocketbase/pocketbase/tools/archive" "github.com/pocketbase/pocketbase/tools/cron" "github.com/pocketbase/pocketbase/tools/filesystem" + "github.com/pocketbase/pocketbase/tools/osutils" "github.com/pocketbase/pocketbase/tools/security" ) @@ -160,9 +161,13 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error { return err } - parentDataDir := filepath.Dir(app.DataDir()) + // make sure that the special temp directory + if err := os.MkdirAll(filepath.Join(app.DataDir(), LocalTempDirName), os.ModePerm); err != nil { + return fmt.Errorf("failed to create a temp dir: %w", err) + } - extractedDataDir := filepath.Join(parentDataDir, "pb_restore_"+security.PseudorandomString(4)) + // note: it needs to be inside the current pb_data to avoid "cross-device link" errors + extractedDataDir := filepath.Join(app.DataDir(), LocalTempDirName, "pb_restore_"+security.PseudorandomString(4)) defer os.RemoveAll(extractedDataDir) if err := archive.Extract(tempZip.Name(), extractedDataDir); err != nil { return err @@ -180,64 +185,37 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error { log.Println(err) } - // make sure that a special temp directory exists in the extracted one - if err := os.MkdirAll(filepath.Join(extractedDataDir, LocalTempDirName), os.ModePerm); err != nil { - return fmt.Errorf("failed to create a temp dir: %w", err) - } + // root dir entries to exclude from the backup restore + exclude := []string{LocalBackupsDirName, LocalTempDirName} - // move the current pb_data to a special temp location that will - // hold the old data between dirs replace + // move the current pb_data content to a special temp location + // that will hold the old data between dirs replace // (the temp dir will be automatically removed on the next app start) - oldTempDataDir := filepath.Join(extractedDataDir, LocalTempDirName, "old_pb_data") - if err := os.Rename(app.DataDir(), oldTempDataDir); err != nil { - return fmt.Errorf("failed to move the current pb_data to a temp location: %w", err) + oldTempDataDir := filepath.Join(app.DataDir(), LocalTempDirName, "old_pb_data_" + security.PseudorandomString(4)) + if err := osutils.MoveDirContent(app.DataDir(), oldTempDataDir, exclude...); err != nil { + return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err) } - // "restore", aka. set the extracted backup as the new pb_data directory - if err := os.Rename(extractedDataDir, app.DataDir()); err != nil { - return fmt.Errorf("failed to set the extracted backup as pb_data dir: %w", err) + // move the extracted archive content to the app's pb_data + if err := osutils.MoveDirContent(extractedDataDir, app.DataDir(), exclude...); err != nil { + return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err) } - // update the old temp data dir path after the restore - oldTempDataDir = filepath.Join(app.DataDir(), LocalTempDirName, "old_pb_data") - - oldLocalBackupsDir := filepath.Join(oldTempDataDir, LocalBackupsDirName) - newLocalBackupsDir := filepath.Join(app.DataDir(), LocalBackupsDirName) - - revertDataDirChanges := func(revertLocalBackupsDir bool) error { - if revertLocalBackupsDir { - if _, err := os.Stat(newLocalBackupsDir); err == nil { - if err := os.Rename(newLocalBackupsDir, oldLocalBackupsDir); err != nil { - return fmt.Errorf("failed to revert the backups dir change: %w", err) - } - } - } - - if err := os.Rename(app.DataDir(), extractedDataDir); err != nil { + revertDataDirChanges := func() error { + if err := osutils.MoveDirContent(app.DataDir(), extractedDataDir, exclude...); err != nil { return fmt.Errorf("failed to revert the extracted dir change: %w", err) } - if err := os.Rename(oldTempDataDir, app.DataDir()); err != nil { + if err := osutils.MoveDirContent(oldTempDataDir, app.DataDir(), exclude...); err != nil { return fmt.Errorf("failed to revert old pb_data dir change: %w", err) } return nil } - // restore the local pb_data/backups dir (if any) - if _, err := os.Stat(oldLocalBackupsDir); err == nil { - if err := os.Rename(oldLocalBackupsDir, newLocalBackupsDir); err != nil { - if err := revertDataDirChanges(false); err != nil && app.IsDebug() { - log.Println(err) - } - - return fmt.Errorf("failed to move the local pb_data/backups dir: %w", err) - } - } - // restart the app if err := app.Restart(); err != nil { - if err := revertDataDirChanges(true); err != nil { + if err := revertDataDirChanges(); err != nil { panic(err) } diff --git a/tools/archive/create_test.go b/tools/archive/create_test.go index e22849ac..8f8aa32f 100644 --- a/tools/archive/create_test.go +++ b/tools/archive/create_test.go @@ -54,7 +54,7 @@ func TestCreateSuccess(t *testing.T) { } } -// --- +// ------------------------------------------------------------------- // note: make sure to call os.RemoveAll(dir) after you are done // working with the created test dir. diff --git a/tools/osutils/dir.go b/tools/osutils/dir.go new file mode 100644 index 00000000..500bbc79 --- /dev/null +++ b/tools/osutils/dir.go @@ -0,0 +1,80 @@ +package osutils + +import ( + "log" + "os" + "path/filepath" + + "github.com/pocketbase/pocketbase/tools/list" +) + +// MoveDirContent moves the src dir content, that is not listed in the exclide list, +// to dest dir (it will be created if missing). +// +// The rootExclude argument is used to specify a list of src root entries to exclude. +// +// Note that this method doesn't delete the old src dir. +// +// It is an alternative to os.Rename() for the cases where we can't +// rename/delete the src dir (see https://github.com/pocketbase/pocketbase/issues/2519). +func MoveDirContent(src string, dest string, rootExclude ...string) error { + entries, err := os.ReadDir(src) + if err != nil { + return err + } + + // make sure that the dest dir exist + manuallyCreatedDestDir := false + if _, err := os.Stat(dest); err != nil { + if err := os.Mkdir(dest, os.ModePerm); err != nil { + return err + } + manuallyCreatedDestDir = true + } + + moved := map[string]string{} + + tryRollback := func() ([]error) { + errs := []error{} + + for old, new := range moved { + if err := os.Rename(new, old); err != nil{ + errs = append(errs, err) + } + } + + // try to delete manually the created dest dir if all moved files were restored + if manuallyCreatedDestDir && len(errs) == 0 { + if err := os.Remove(dest); err != nil { + errs = append(errs, err) + } + } + + return errs + } + + for _, entry := range entries { + basename := entry.Name() + + if list.ExistInSlice(basename, rootExclude) { + continue + } + + old := filepath.Join(src, basename) + new := filepath.Join(dest, basename) + + if err := os.Rename(old, new); err != nil { + if errs := tryRollback(); len(errs) > 0 { + // currently just log the rollback errors + // in the future we may require go 1.20+ to use errors.Join() + log.Println(errs) + } + + return err + } + + moved[old] = new + } + + return nil +} diff --git a/tools/osutils/dir_test.go b/tools/osutils/dir_test.go new file mode 100644 index 00000000..a99fa20f --- /dev/null +++ b/tools/osutils/dir_test.go @@ -0,0 +1,144 @@ +package osutils_test + +import ( + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/pocketbase/pocketbase/tools/list" + "github.com/pocketbase/pocketbase/tools/osutils" + "github.com/pocketbase/pocketbase/tools/security" +) + +func TestMoveDirContent(t *testing.T) { + testDir := createTestDir(t) + defer os.RemoveAll(testDir) + + exclude := []string{ + "missing", + "test2", + "b", + } + + // missing dest path + // --- + dir1 := filepath.Join(filepath.Dir(testDir), "a", "b", "c", "d", "_pb_move_dir_content_test_" + security.PseudorandomString(4)) + defer os.RemoveAll(dir1) + + if err := osutils.MoveDirContent(testDir, dir1, exclude...); err == nil { + t.Fatal("Expected path error, got nil") + } + + // existing parent dir + // --- + dir2 := filepath.Join(filepath.Dir(testDir), "_pb_move_dir_content_test_" + security.PseudorandomString(4)) + defer os.RemoveAll(dir2) + + if err := osutils.MoveDirContent(testDir, dir2, exclude...); err != nil { + t.Fatalf("Expected dir2 to be created, got error: %v", err) + } + + + // find all files + files := []string{} + filepath.WalkDir(dir2, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() { + return nil + } + + files = append(files, path) + + return nil + }) + + expectedFiles := []string{ + filepath.Join(dir2, "test1"), + filepath.Join(dir2, "a", "a1"), + filepath.Join(dir2, "a", "a2"), + }; + + if len(files) != len(expectedFiles) { + t.Fatalf("Expected %d files, got %d: \n%v", len(expectedFiles), len(files), files) + } + + for _, expected := range expectedFiles { + if !list.ExistInSlice(expected, files) { + t.Fatalf("Missing expected file %q in \n%v", expected, files) + } + } +} + +// ------------------------------------------------------------------- + +// note: make sure to call os.RemoveAll(dir) after you are done +// working with the created test dir. +func createTestDir(t *testing.T) string { + dir, err := os.MkdirTemp(os.TempDir(), "test_dir") + if err != nil { + t.Fatal(err) + } + + // create sub directories + if err := os.MkdirAll(filepath.Join(dir, "a"), os.ModePerm); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(dir, "b"), os.ModePerm); err != nil { + t.Fatal(err) + } + + + { + f, err := os.OpenFile(filepath.Join(dir, "test1"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + { + f, err := os.OpenFile(filepath.Join(dir, "test2"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + { + f, err := os.OpenFile(filepath.Join(dir, "a/a1"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + { + f, err := os.OpenFile(filepath.Join(dir, "a/a2"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + { + f, err := os.OpenFile(filepath.Join(dir, "b/b2"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + { + f, err := os.OpenFile(filepath.Join(dir, "b/b2"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + f.Close() + } + + return dir +}