From b41406fbd64a0a61e6088810254a189dc828ecad Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Mon, 30 Sep 2024 16:27:59 +0300 Subject: [PATCH] moved FindUploadedFiles in RequestEvent --- apis/backup_upload.go | 2 +- apis/base.go | 29 --------------- apis/base_test.go | 73 ------------------------------------- apis/record_crud.go | 8 ++--- tools/router/event.go | 29 +++++++++++++++ tools/router/event_test.go | 74 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 108 insertions(+), 107 deletions(-) diff --git a/apis/backup_upload.go b/apis/backup_upload.go index 3117abcf..81eaf051 100644 --- a/apis/backup_upload.go +++ b/apis/backup_upload.go @@ -18,7 +18,7 @@ func backupUpload(e *core.RequestEvent) error { form := new(backupUploadForm) form.fsys = fsys - files, _ := FindUploadedFiles(e.Request, "file") + files, _ := e.FindUploadedFiles("file") if len(files) > 0 { form.File = files[0] } diff --git a/apis/base.go b/apis/base.go index 42003474..1ca1cb53 100644 --- a/apis/base.go +++ b/apis/base.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/pocketbase/pocketbase/core" - "github.com/pocketbase/pocketbase/tools/filesystem" "github.com/pocketbase/pocketbase/tools/hook" "github.com/pocketbase/pocketbase/tools/router" ) @@ -172,31 +171,3 @@ func safeRedirectPath(path string) string { } return path } - -// FindUploadedFiles extracts all form files of "key" from a http request -// and returns a slice with filesystem.File instances (if any). -func FindUploadedFiles(r *http.Request, key string) ([]*filesystem.File, error) { - if r.MultipartForm == nil { - err := r.ParseMultipartForm(router.DefaultMaxMemory) - if err != nil { - return nil, err - } - } - - if r.MultipartForm == nil || r.MultipartForm.File == nil || len(r.MultipartForm.File[key]) == 0 { - return nil, http.ErrMissingFile - } - - result := make([]*filesystem.File, 0, len(r.MultipartForm.File[key])) - - for _, fh := range r.MultipartForm.File[key] { - file, err := filesystem.NewFileFromMultipart(fh) - if err != nil { - return nil, err - } - - result = append(result, file) - } - - return result, nil -} diff --git a/apis/base_test.go b/apis/base_test.go index 79d71c1d..19d96a51 100644 --- a/apis/base_test.go +++ b/apis/base_test.go @@ -1,15 +1,11 @@ package apis_test import ( - "bytes" "fmt" - "mime/multipart" "net/http" "net/http/httptest" "os" "path/filepath" - "regexp" - "strings" "testing" "github.com/pocketbase/pocketbase/apis" @@ -237,75 +233,6 @@ func TestStatic(t *testing.T) { } } -func TestFindUploadedFiles(t *testing.T) { - scenarios := []struct { - filename string - expectedPattern string - }{ - {"ab.png", `^ab\w{10}_\w{10}\.png$`}, - {"test", `^test_\w{10}\.txt$`}, - {"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`}, - {strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`}, - } - - for _, s := range scenarios { - t.Run(s.filename, func(t *testing.T) { - // create multipart form file body - body := new(bytes.Buffer) - mp := multipart.NewWriter(body) - w, err := mp.CreateFormFile("test", s.filename) - if err != nil { - t.Fatal(err) - } - w.Write([]byte("test")) - mp.Close() - // --- - - req := httptest.NewRequest(http.MethodPost, "/", body) - req.Header.Add("Content-Type", mp.FormDataContentType()) - - result, err := apis.FindUploadedFiles(req, "test") - if err != nil { - t.Fatal(err) - } - - if len(result) != 1 { - t.Fatalf("Expected 1 file, got %d", len(result)) - } - - if result[0].Size != 4 { - t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Size) - } - - pattern, err := regexp.Compile(s.expectedPattern) - if err != nil { - t.Fatalf("Invalid filename pattern %q: %v", s.expectedPattern, err) - } - if !pattern.MatchString(result[0].Name) { - t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name) - } - }) - } -} - -func TestFindUploadedFilesMissing(t *testing.T) { - body := new(bytes.Buffer) - mp := multipart.NewWriter(body) - mp.Close() - - req := httptest.NewRequest(http.MethodPost, "/", body) - req.Header.Add("Content-Type", mp.FormDataContentType()) - - result, err := apis.FindUploadedFiles(req, "test") - if err == nil { - t.Error("Expected error, got nil") - } - - if result != nil { - t.Errorf("Expected result to be nil, got %v", result) - } -} - func TestMustSubFS(t *testing.T) { t.Parallel() diff --git a/apis/record_crud.go b/apis/record_crud.go index ece65d37..badd3a52 100644 --- a/apis/record_crud.go +++ b/apis/record_crud.go @@ -529,7 +529,7 @@ func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[strin result := record.ReplaceModifiers(info.Body) // resolve uploaded files - uploadedFiles, err := extractUploadedFiles(e.Request, record.Collection(), "") + uploadedFiles, err := extractUploadedFiles(e, record.Collection(), "") if err != nil { return nil, err } @@ -559,8 +559,8 @@ func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[strin return result, nil } -func extractUploadedFiles(request *http.Request, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) { - contentType := request.Header.Get("content-type") +func extractUploadedFiles(re *core.RequestEvent, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) { + contentType := re.Request.Header.Get("content-type") if !strings.HasPrefix(contentType, "multipart/form-data") { return nil, nil // not multipart/form-data request } @@ -585,7 +585,7 @@ func extractUploadedFiles(request *http.Request, collection *core.Collection, pr if prefix != "" { k = prefix + "." + k } - files, err := FindUploadedFiles(request, k) + files, err := re.FindUploadedFiles(k) if err != nil && !errors.Is(err, http.ErrMissingFile) { return nil, err } diff --git a/tools/router/event.go b/tools/router/event.go index d4e51b12..7f8fc143 100644 --- a/tools/router/event.go +++ b/tools/router/event.go @@ -12,6 +12,7 @@ import ( "path/filepath" "strings" + "github.com/pocketbase/pocketbase/tools/filesystem" "github.com/pocketbase/pocketbase/tools/hook" "github.com/pocketbase/pocketbase/tools/picker" "github.com/pocketbase/pocketbase/tools/store" @@ -126,6 +127,34 @@ func (e *Event) UnsafeRealIP() string { return e.RemoteIP() } +// FindUploadedFiles extracts all form files of "key" from a http request +// and returns a slice with filesystem.File instances (if any). +func (e *Event) FindUploadedFiles(key string) ([]*filesystem.File, error) { + if e.Request.MultipartForm == nil { + err := e.Request.ParseMultipartForm(DefaultMaxMemory) + if err != nil { + return nil, err + } + } + + if e.Request.MultipartForm == nil || e.Request.MultipartForm.File == nil || len(e.Request.MultipartForm.File[key]) == 0 { + return nil, http.ErrMissingFile + } + + result := make([]*filesystem.File, 0, len(e.Request.MultipartForm.File[key])) + + for _, fh := range e.Request.MultipartForm.File[key] { + file, err := filesystem.NewFileFromMultipart(fh) + if err != nil { + return nil, err + } + + result = append(result, file) + } + + return result, nil +} + // Store // ------------------------------------------------------------------- diff --git a/tools/router/event_test.go b/tools/router/event_test.go index ba464bdc..902a76fd 100644 --- a/tools/router/event_test.go +++ b/tools/router/event_test.go @@ -13,6 +13,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "regexp" "strconv" "strings" "testing" @@ -277,6 +278,79 @@ func TestEventUnsafeRealIP(t *testing.T) { } } +func TestFindUploadedFiles(t *testing.T) { + scenarios := []struct { + filename string + expectedPattern string + }{ + {"ab.png", `^ab\w{10}_\w{10}\.png$`}, + {"test", `^test_\w{10}\.txt$`}, + {"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`}, + {strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`}, + } + + for _, s := range scenarios { + t.Run(s.filename, func(t *testing.T) { + // create multipart form file body + body := new(bytes.Buffer) + mp := multipart.NewWriter(body) + w, err := mp.CreateFormFile("test", s.filename) + if err != nil { + t.Fatal(err) + } + w.Write([]byte("test")) + mp.Close() + // --- + + req := httptest.NewRequest(http.MethodPost, "/", body) + req.Header.Add("Content-Type", mp.FormDataContentType()) + + event := router.Event{Request: req} + + result, err := event.FindUploadedFiles("test") + if err != nil { + t.Fatal(err) + } + + if len(result) != 1 { + t.Fatalf("Expected 1 file, got %d", len(result)) + } + + if result[0].Size != 4 { + t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Size) + } + + pattern, err := regexp.Compile(s.expectedPattern) + if err != nil { + t.Fatalf("Invalid filename pattern %q: %v", s.expectedPattern, err) + } + if !pattern.MatchString(result[0].Name) { + t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name) + } + }) + } +} + +func TestFindUploadedFilesMissing(t *testing.T) { + body := new(bytes.Buffer) + mp := multipart.NewWriter(body) + mp.Close() + + req := httptest.NewRequest(http.MethodPost, "/", body) + req.Header.Add("Content-Type", mp.FormDataContentType()) + + event := router.Event{Request: req} + + result, err := event.FindUploadedFiles("test") + if err == nil { + t.Error("Expected error, got nil") + } + + if result != nil { + t.Errorf("Expected result to be nil, got %v", result) + } +} + func TestEventSetGet(t *testing.T) { event := router.Event{}