diff --git a/tools/archive/create.go b/tools/archive/create.go new file mode 100644 index 00000000..b6a5b76f --- /dev/null +++ b/tools/archive/create.go @@ -0,0 +1,70 @@ +package archive + +import ( + "archive/zip" + "io" + "io/fs" + "os" +) + +// Create creates a new zip archive from src dir content and saves it in dest path. +func Create(src, dest string) error { + zf, err := os.Create(dest) + if err != nil { + return err + } + defer zf.Close() + + zw := zip.NewWriter(zf) + defer zw.Close() + + if err := zipAddFS(zw, os.DirFS(src)); err != nil { + // try to cleanup the created zip file + os.Remove(dest) + + return err + } + + return nil +} + +// note remove after similar method is added in the std lib (https://github.com/golang/go/issues/54898) +func zipAddFS(w *zip.Writer, fsys fs.FS) error { + return fs.WalkDir(fsys, ".", func(name string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if d.IsDir() { + return nil + } + + info, err := d.Info() + if err != nil { + return err + } + + h, err := zip.FileInfoHeader(info) + if err != nil { + return err + } + + h.Name = name + h.Method = zip.Deflate + + fw, err := w.CreateHeader(h) + if err != nil { + return err + } + + f, err := fsys.Open(name) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(fw, f) + + return err + }) +} diff --git a/tools/archive/create_test.go b/tools/archive/create_test.go new file mode 100644 index 00000000..7a51f708 --- /dev/null +++ b/tools/archive/create_test.go @@ -0,0 +1,84 @@ +package archive_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/pocketbase/pocketbase/tools/archive" +) + +func TestCreateFailure(t *testing.T) { + testDir := createTestDir(t) + defer os.RemoveAll(testDir) + + zipPath := filepath.Join(os.TempDir(), "pb_test.zip") + defer os.RemoveAll(zipPath) + + missingDir := filepath.Join(os.TempDir(), "missing") + + if err := archive.Create(missingDir, zipPath); err == nil { + t.Fatal("Expected to fail due to missing directory or file") + } + + if _, err := os.Stat(zipPath); err == nil { + t.Fatalf("Expected the zip file not to be created") + } +} + +func TestCreateSuccess(t *testing.T) { + testDir := createTestDir(t) + defer os.RemoveAll(testDir) + + zipName := "pb_test.zip" + zipPath := filepath.Join(os.TempDir(), zipName) + defer os.RemoveAll(zipPath) + + // zip testDir content + if err := archive.Create(testDir, zipPath); err != nil { + t.Fatalf("Failed to create archive: %v", err) + } + + info, err := os.Stat(zipPath) + if err != nil { + t.Fatalf("Failed to retrieve the generated zip file: %v", err) + } + + if name := info.Name(); name != zipName { + t.Fatalf("Expected zip with name %q, got %q", zipName, name) + } + + expectedSize := int64(300) + if size := info.Size(); size != expectedSize { + t.Fatalf("Expected zip with size %d, got %d", expectedSize, size) + } +} + +// --- + +// note: make sure to call os.RemoveAll(dir) after you are done +// working with the created test dir. +func createTestDir(t *testing.T) string { + dir, err := os.MkdirTemp(os.TempDir(), "pb_zip_test") + if err != nil { + t.Fatal(err) + } + + if err := os.MkdirAll(filepath.Join(dir, "a/b/c"), os.ModePerm); err != nil { + t.Fatal(err) + } + + sub1, err := os.OpenFile(filepath.Join(dir, "a/sub1.txt"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + sub1.Close() + + sub2, err := os.OpenFile(filepath.Join(dir, "a/b/c/sub2.txt"), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + t.Fatal(err) + } + sub2.Close() + + return dir +} diff --git a/tools/archive/extract.go b/tools/archive/extract.go new file mode 100644 index 00000000..02d8202a --- /dev/null +++ b/tools/archive/extract.go @@ -0,0 +1,72 @@ +package archive + +import ( + "archive/zip" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// Extract extracts the zip archive at src to dest. +func Extract(src, dest string) error { + zr, err := zip.OpenReader(src) + if err != nil { + return err + } + defer zr.Close() + + // normalize dest path to check later for Zip Slip + dest = filepath.Clean(dest) + string(os.PathSeparator) + + for _, f := range zr.File { + err := extractFile(f, dest) + if err != nil { + return err + } + } + + return nil +} + +// extractFile extracts the provided zipFile into "basePath/zipFileName" path, +// creating all the necessary path directories. +func extractFile(zipFile *zip.File, basePath string) error { + path := filepath.Join(basePath, zipFile.Name) + + // check for Zip Slip + if !strings.HasPrefix(path, basePath) { + return fmt.Errorf("invalid file path: %s", path) + } + + r, err := zipFile.Open() + if err != nil { + return err + } + defer r.Close() + + if zipFile.FileInfo().IsDir() { + if err := os.MkdirAll(path, os.ModePerm); err != nil { + return err + } + } else { + // ensure that the file path directories are created + if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { + return err + } + + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zipFile.Mode()) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(f, r) + if err != nil { + return err + } + } + + return nil +} diff --git a/tools/archive/extract_test.go b/tools/archive/extract_test.go new file mode 100644 index 00000000..efb5a213 --- /dev/null +++ b/tools/archive/extract_test.go @@ -0,0 +1,57 @@ +package archive_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/pocketbase/pocketbase/tools/archive" +) + +func TestExtractFailure(t *testing.T) { + testDir := createTestDir(t) + defer os.RemoveAll(testDir) + + missingZipPath := filepath.Join(os.TempDir(), "pb_missing_test.zip") + extractPath := filepath.Join(os.TempDir(), "pb_zip_extract") + defer os.RemoveAll(extractPath) + + if err := archive.Extract(missingZipPath, extractPath); err == nil { + t.Fatal("Expected Extract to fail due to missing zipPath") + } + + if _, err := os.Stat(extractPath); err == nil { + t.Fatalf("Expected %q to not be created", extractPath) + } +} + +func TestExtractSuccess(t *testing.T) { + testDir := createTestDir(t) + defer os.RemoveAll(testDir) + + zipPath := filepath.Join(os.TempDir(), "pb_test.zip") + defer os.RemoveAll(zipPath) + + extractPath := filepath.Join(os.TempDir(), "pb_zip_extract") + defer os.RemoveAll(extractPath) + + // zip testDir content + if err := archive.Create(testDir, zipPath); err != nil { + t.Fatalf("Failed to create archive: %v", err) + } + + if err := archive.Extract(zipPath, extractPath); err != nil { + t.Fatalf("Failed to extract %q in %q", zipPath, extractPath) + } + + pathsToCheck := []string{ + filepath.Join(extractPath, "a/sub1.txt"), + filepath.Join(extractPath, "a/b/c/sub2.txt"), + } + + for _, p := range pathsToCheck { + if _, err := os.Stat(p); err != nil { + t.Fatalf("Failed to retrieve extracted file %q: %v", p, err) + } + } +}