merge v0.23.0-rc changes

This commit is contained in:
Gani Georgiev 2024-09-29 19:23:19 +03:00
parent ad92992324
commit 844f18cac3
753 changed files with 85141 additions and 63396 deletions

2
.github/SECURITY.md vendored
View File

@ -2,4 +2,4 @@
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**. If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
All reports will be promptly addressed, and you'll be credited accordingly. All reports will be promptly addressed and you'll be credited in the fix release notes.

View File

@ -7,7 +7,14 @@ on:
jobs: jobs:
goreleaser: goreleaser:
runs-on: ubuntu-latest runs-on: ubuntu-latest
env:
flags: ""
steps: steps:
# re-enable auto-snapshot from goreleaser-action@v3
# (https://github.com/goreleaser/goreleaser-action-v4-auto-snapshot-example)
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
@ -16,12 +23,12 @@ jobs:
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@v4 uses: actions/setup-node@v4
with: with:
node-version: 20.11.0 node-version: 20.17.0
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: '>=1.22.5' go-version: '>=1.23.0'
# This step usually is not needed because the /ui/dist is pregenerated locally # This step usually is not needed because the /ui/dist is pregenerated locally
# but its here to ensure that each release embeds the latest admin ui artifacts. # but its here to ensure that each release embeds the latest admin ui artifacts.
@ -36,19 +43,14 @@ jobs:
# - name: Generate jsvm types # - name: Generate jsvm types
# run: go run ./plugins/jsvm/internal/types/types.go # run: go run ./plugins/jsvm/internal/types/types.go
# The prebuilt golangci-lint doesn't support go 1.18+ yet
# https://github.com/golangci/golangci-lint/issues/2649
# - name: Run linter
# uses: golangci/golangci-lint-action@v3
- name: Run tests - name: Run tests
run: go test ./... run: go test ./...
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v3 uses: goreleaser/goreleaser-action@v6
with: with:
distribution: goreleaser distribution: goreleaser
version: latest version: '~> v2'
args: release --clean args: release --clean ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,3 +1,5 @@
version: 2
project_name: pocketbase project_name: pocketbase
dist: .builds dist: .builds
@ -58,7 +60,7 @@ checksum:
name_template: 'checksums.txt' name_template: 'checksums.txt'
snapshot: snapshot:
name_template: '{{ incpatch .Version }}-next' version_template: '{{ incpatch .Version }}-next'
changelog: changelog:
sort: asc sort: asc

View File

@ -1,3 +1,8 @@
## v0.23.0-RC (WIP)
...
## v0.22.21 ## v0.22.21
- Lock the logs database during backup to prevent `database disk image is malformed` errors in case there is a log write running in the background ([#5541](https://github.com/pocketbase/pocketbase/discussions/5541)). - Lock the logs database during backup to prevent `database disk image is malformed` errors in case there is a log write running in the background ([#5541](https://github.com/pocketbase/pocketbase/discussions/5541)).

View File

@ -10,7 +10,7 @@
<a href="https://pkg.go.dev/github.com/pocketbase/pocketbase" target="_blank" rel="noopener"><img src="https://godoc.org/github.com/pocketbase/pocketbase?status.svg" alt="Go package documentation" /></a> <a href="https://pkg.go.dev/github.com/pocketbase/pocketbase" target="_blank" rel="noopener"><img src="https://godoc.org/github.com/pocketbase/pocketbase?status.svg" alt="Go package documentation" /></a>
</p> </p>
[PocketBase](https://pocketbase.io) is an open source Go backend, consisting of: [PocketBase](https://pocketbase.io) is an open source Go backend that includes:
- embedded database (_SQLite_) with **realtime subscriptions** - embedded database (_SQLite_) with **realtime subscriptions**
- built-in **files and users management** - built-in **files and users management**
@ -46,7 +46,7 @@ your own custom app specific business logic and still have a single portable exe
Here is a minimal example: Here is a minimal example:
0. [Install Go 1.21+](https://go.dev/doc/install) (_if you haven't already_) 0. [Install Go 1.23+](https://go.dev/doc/install) (_if you haven't already_)
1. Create a new project directory with the following `main.go` file inside it: 1. Create a new project directory with the following `main.go` file inside it:
```go ```go
@ -56,29 +56,20 @@ Here is a minimal example:
"log" "log"
"net/http" "net/http"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase" "github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
) )
func main() { func main() {
app := pocketbase.New() app := pocketbase.New()
app.OnBeforeServe().Add(func(e *core.ServeEvent) error { app.OnServe().BindFunc(func(se *core.ServeEvent) error {
// add new "GET /hello" route to the app router (echo) // registers new "GET /hello" route
e.Router.AddRoute(echo.Route{ se.Router.Get("/hello", func(re *core.RequestEvent) error {
Method: http.MethodGet, return re.String(200, "Hello world!")
Path: "/hello",
Handler: func(c echo.Context) error {
return c.String(200, "Hello world!")
},
Middlewares: []echo.MiddlewareFunc{
apis.ActivityLogger(app),
},
}) })
return nil return se.Next()
}) })
if err := app.Start(); err != nil { if err := app.Start(); err != nil {
@ -145,7 +136,7 @@ Check also the [Testing guide](http://pocketbase.io/docs/testing) to learn how t
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**. If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
All reports will be promptly addressed, and you'll be credited accordingly. All reports will be promptly addressed and you'll be credited in the fix release notes.
## Contributing ## Contributing

View File

@ -1,353 +0,0 @@
package apis
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
)
// bindAdminApi registers the admin api endpoints and the corresponding handlers.
func bindAdminApi(app core.App, rg *echo.Group) {
api := adminApi{app: app}
subGroup := rg.Group("/admins", ActivityLogger(app))
subGroup.POST("/auth-with-password", api.authWithPassword)
subGroup.POST("/request-password-reset", api.requestPasswordReset)
subGroup.POST("/confirm-password-reset", api.confirmPasswordReset)
subGroup.POST("/auth-refresh", api.authRefresh, RequireAdminAuth())
subGroup.GET("", api.list, RequireAdminAuth())
subGroup.POST("", api.create, RequireAdminAuthOnlyIfAny(app))
subGroup.GET("/:id", api.view, RequireAdminAuth())
subGroup.PATCH("/:id", api.update, RequireAdminAuth())
subGroup.DELETE("/:id", api.delete, RequireAdminAuth())
}
type adminApi struct {
app core.App
}
func (api *adminApi) authResponse(c echo.Context, admin *models.Admin, finalizers ...func(token string) error) error {
token, tokenErr := tokens.NewAdminAuthToken(api.app, admin)
if tokenErr != nil {
return NewBadRequestError("Failed to create auth token.", tokenErr)
}
for _, f := range finalizers {
if err := f(token); err != nil {
return err
}
}
event := new(core.AdminAuthEvent)
event.HttpContext = c
event.Admin = admin
event.Token = token
return api.app.OnAdminAuthRequest().Trigger(event, func(e *core.AdminAuthEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(200, map[string]any{
"token": e.Token,
"admin": e.Admin,
})
})
}
func (api *adminApi) authRefresh(c echo.Context) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin == nil {
return NewNotFoundError("Missing auth admin context.", nil)
}
event := new(core.AdminAuthRefreshEvent)
event.HttpContext = c
event.Admin = admin
return api.app.OnAdminBeforeAuthRefreshRequest().Trigger(event, func(e *core.AdminAuthRefreshEvent) error {
return api.app.OnAdminAfterAuthRefreshRequest().Trigger(event, func(e *core.AdminAuthRefreshEvent) error {
return api.authResponse(e.HttpContext, e.Admin)
})
})
}
func (api *adminApi) authWithPassword(c echo.Context) error {
form := forms.NewAdminLogin(api.app)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
event := new(core.AdminAuthWithPasswordEvent)
event.HttpContext = c
event.Password = form.Password
event.Identity = form.Identity
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(admin *models.Admin) error {
event.Admin = admin
return api.app.OnAdminBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.AdminAuthWithPasswordEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
return api.app.OnAdminAfterAuthWithPasswordRequest().Trigger(event, func(e *core.AdminAuthWithPasswordEvent) error {
return api.authResponse(e.HttpContext, e.Admin)
})
})
}
})
return submitErr
}
func (api *adminApi) requestPasswordReset(c echo.Context) error {
form := forms.NewAdminPasswordResetRequest(api.app)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
if err := form.Validate(); err != nil {
return NewBadRequestError("An error occurred while validating the form.", err)
}
event := new(core.AdminRequestPasswordResetEvent)
event.HttpContext = c
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(Admin *models.Admin) error {
event.Admin = Admin
return api.app.OnAdminBeforeRequestPasswordResetRequest().Trigger(event, func(e *core.AdminRequestPasswordResetEvent) error {
// run in background because we don't need to show the result to the client
routine.FireAndForget(func() {
if err := next(e.Admin); err != nil {
api.app.Logger().Error("Failed to send admin password reset request.", "error", err)
}
})
return api.app.OnAdminAfterRequestPasswordResetRequest().Trigger(event, func(e *core.AdminRequestPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
// eagerly write 204 response and skip submit errors
// as a measure against admins enumeration
if !c.Response().Committed {
c.NoContent(http.StatusNoContent)
}
return submitErr
}
func (api *adminApi) confirmPasswordReset(c echo.Context) error {
form := forms.NewAdminPasswordResetConfirm(api.app)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.AdminConfirmPasswordResetEvent)
event.HttpContext = c
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(admin *models.Admin) error {
event.Admin = admin
return api.app.OnAdminBeforeConfirmPasswordResetRequest().Trigger(event, func(e *core.AdminConfirmPasswordResetEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to set new password.", err)
}
return api.app.OnAdminAfterConfirmPasswordResetRequest().Trigger(event, func(e *core.AdminConfirmPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *adminApi) list(c echo.Context) error {
fieldResolver := search.NewSimpleFieldResolver(
"id", "created", "updated", "name", "email",
)
admins := []*models.Admin{}
result, err := search.NewProvider(fieldResolver).
Query(api.app.Dao().AdminQuery()).
ParseAndExec(c.QueryParams().Encode(), &admins)
if err != nil {
return NewBadRequestError("", err)
}
event := new(core.AdminsListEvent)
event.HttpContext = c
event.Admins = admins
event.Result = result
return api.app.OnAdminsListRequest().Trigger(event, func(e *core.AdminsListEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Result)
})
}
func (api *adminApi) view(c echo.Context) error {
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
admin, err := api.app.Dao().FindAdminById(id)
if err != nil || admin == nil {
return NewNotFoundError("", err)
}
event := new(core.AdminViewEvent)
event.HttpContext = c
event.Admin = admin
return api.app.OnAdminViewRequest().Trigger(event, func(e *core.AdminViewEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Admin)
})
}
func (api *adminApi) create(c echo.Context) error {
admin := &models.Admin{}
form := forms.NewAdminUpsert(api.app, admin)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.AdminCreateEvent)
event.HttpContext = c
event.Admin = admin
// create the admin
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(m *models.Admin) error {
event.Admin = m
return api.app.OnAdminBeforeCreateRequest().Trigger(event, func(e *core.AdminCreateEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to create admin.", err)
}
return api.app.OnAdminAfterCreateRequest().Trigger(event, func(e *core.AdminCreateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Admin)
})
})
}
})
return submitErr
}
func (api *adminApi) update(c echo.Context) error {
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
admin, err := api.app.Dao().FindAdminById(id)
if err != nil || admin == nil {
return NewNotFoundError("", err)
}
form := forms.NewAdminUpsert(api.app, admin)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.AdminUpdateEvent)
event.HttpContext = c
event.Admin = admin
// update the admin
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(m *models.Admin) error {
event.Admin = m
return api.app.OnAdminBeforeUpdateRequest().Trigger(event, func(e *core.AdminUpdateEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to update admin.", err)
}
return api.app.OnAdminAfterUpdateRequest().Trigger(event, func(e *core.AdminUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Admin)
})
})
}
})
return submitErr
}
func (api *adminApi) delete(c echo.Context) error {
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
admin, err := api.app.Dao().FindAdminById(id)
if err != nil || admin == nil {
return NewNotFoundError("", err)
}
event := new(core.AdminDeleteEvent)
event.HttpContext = c
event.Admin = admin
return api.app.OnAdminBeforeDeleteRequest().Trigger(event, func(e *core.AdminDeleteEvent) error {
if err := api.app.Dao().DeleteAdmin(e.Admin); err != nil {
return NewBadRequestError("Failed to delete admin.", err)
}
return api.app.OnAdminAfterDeleteRequest().Trigger(event, func(e *core.AdminDeleteEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}

View File

@ -1,925 +0,0 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestAdminAuthWithPassword(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"identity":{"code":"validation_required","message":"Cannot be blank."},"password":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "invalid data",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "wrong email",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"missing@example.com","password":"1234567890"}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
},
},
{
Name: "wrong password",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"invalid"}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
},
},
{
Name: "valid email/password (guest)",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"admin":{"id":"sywbhecnh46rhm0"`,
`"token":`,
},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
"OnAdminAfterAuthWithPasswordRequest": 1,
"OnAdminAuthRequest": 1,
},
},
{
Name: "valid email/password (already authorized)",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4MTYwMH0.han3_sG65zLddpcX2ic78qgy7FKecuPfOpFa8Dvi5Bg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"admin":{"id":"sywbhecnh46rhm0"`,
`"token":`,
},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
"OnAdminAfterAuthWithPasswordRequest": 1,
"OnAdminAuthRequest": 1,
},
},
{
Name: "OnAdminAfterAuthWithPasswordRequest error response",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4MTYwMH0.han3_sG65zLddpcX2ic78qgy7FKecuPfOpFa8Dvi5Bg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterAuthWithPasswordRequest().Add(func(e *core.AdminAuthWithPasswordEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
"OnAdminAfterAuthWithPasswordRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminRequestPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "invalid data",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "missing admin",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
},
{
Name: "existing admin",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnMailerBeforeAdminResetPasswordSend": 1,
"OnMailerAfterAdminResetPasswordSend": 1,
"OnAdminBeforeRequestPasswordResetRequest": 1,
"OnAdminAfterRequestPasswordResetRequest": 1,
},
},
{
Name: "existing admin (after already sent)",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
// simulate recent password request
admin, err := app.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
admin.LastResetSentAt = types.NowDateTime()
dao := daos.New(app.Dao().DB()) // new dao to ignore hooks
if err := dao.Save(admin); err != nil {
t.Fatal(err)
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminConfirmPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"password":{"code":"validation_required","message":"Cannot be blank."},"passwordConfirm":{"code":"validation_required","message":"Cannot be blank."},"token":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "invalid data",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "expired token",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MTY0MDk5MTY2MX0.GLwCOsgWTTEKXTK-AyGW838de1OeZGIjfHH0FoRLqZg",
"password":"1234567890",
"passwordConfirm":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"token":{"code":"validation_invalid_token","message":"Invalid or expired token."}}}`},
},
{
Name: "valid token + invalid password",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
"password":"123456",
"passwordConfirm":"123456"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"password":{"code":"validation_length_out_of_range"`},
},
{
Name: "valid token + valid password",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
"password":"1234567891",
"passwordConfirm":"1234567891"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeConfirmPasswordResetRequest": 1,
"OnAdminAfterConfirmPasswordResetRequest": 1,
},
},
{
Name: "OnAdminAfterConfirmPasswordResetRequest error response",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
"password":"1234567891",
"passwordConfirm":"1234567891"
}`),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterConfirmPasswordResetRequest().Add(func(e *core.AdminConfirmPasswordResetEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeConfirmPasswordResetRequest": 1,
"OnAdminAfterConfirmPasswordResetRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminRefresh(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin (expired token)",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MTY0MDk5MTY2MX0.I7w8iktkleQvC7_UIRpD7rNzcU4OnF7i7SFIUu6lD_4",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin (valid token)",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"admin":{"id":"sywbhecnh46rhm0"`,
`"token":`,
},
ExpectedEvents: map[string]int{
"OnAdminAuthRequest": 1,
"OnAdminBeforeAuthRefreshRequest": 1,
"OnAdminAfterAuthRefreshRequest": 1,
},
},
{
Name: "OnAdminAfterAuthRefreshRequest error response",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterAuthRefreshRequest().Add(func(e *core.AdminAuthRefreshEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthRefreshRequest": 1,
"OnAdminAfterAuthRefreshRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/admins",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodGet,
Url: "/api/admins",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin",
Method: http.MethodGet,
Url: "/api/admins",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":3`,
`"items":[{`,
`"id":"sywbhecnh46rhm0"`,
`"id":"sbmbsdb40jyxf7h"`,
`"id":"9q2trqumvlyr3bd"`,
},
ExpectedEvents: map[string]int{
"OnAdminsListRequest": 1,
},
},
{
Name: "authorized as admin + paging and sorting",
Method: http.MethodGet,
Url: "/api/admins?page=2&perPage=1&sort=-created",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":2`,
`"perPage":1`,
`"totalItems":3`,
`"items":[{`,
`"id":"sbmbsdb40jyxf7h"`,
},
NotExpectedContent: []string{
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnAdminsListRequest": 1,
},
},
{
Name: "authorized as admin + invalid filter",
Method: http.MethodGet,
Url: "/api/admins?filter=invalidfield~'test2'",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + valid filter",
Method: http.MethodGet,
Url: "/api/admins?filter=email~'test3'",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"items":[{`,
`"id":"9q2trqumvlyr3bd"`,
},
NotExpectedContent: []string{
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnAdminsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/admins/sbmbsdb40jyxf7h",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodGet,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + nonexisting admin id",
Method: http.MethodGet,
Url: "/api/admins/nonexisting",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + existing admin id",
Method: http.MethodGet,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sbmbsdb40jyxf7h"`,
},
NotExpectedContent: []string{
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnAdminViewRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + missing admin id",
Method: http.MethodDelete,
Url: "/api/admins/missing",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + existing admin id",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnModelBeforeDelete": 1,
"OnModelAfterDelete": 1,
"OnAdminBeforeDeleteRequest": 1,
"OnAdminAfterDeleteRequest": 1,
},
},
{
Name: "authorized as admin - try to delete the only remaining admin",
Method: http.MethodDelete,
Url: "/api/admins/sywbhecnh46rhm0",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
// delete all admins except the authorized one
adminModel := &models.Admin{}
_, err := app.Dao().DB().Delete(adminModel.TableName(), dbx.Not(dbx.HashExp{
"id": "sywbhecnh46rhm0",
})).Execute()
if err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeDeleteRequest": 1,
},
},
{
Name: "OnAdminAfterDeleteRequest error response",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterDeleteRequest().Add(func(e *core.AdminDeleteEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeDelete": 1,
"OnModelAfterDelete": 1,
"OnAdminBeforeDeleteRequest": 1,
"OnAdminAfterDeleteRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminCreate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized (while having at least 1 existing admin)",
Method: http.MethodPost,
Url: "/api/admins",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "unauthorized (while having 0 existing admins)",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{"email":"testnew@example.com","password":"1234567890","passwordConfirm":"1234567890","avatar":3}`),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
// delete all admins
_, err := app.Dao().DB().NewQuery("DELETE FROM {{_admins}}").Execute()
if err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":`,
`"email":"testnew@example.com"`,
`"avatar":3`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
"OnAdminBeforeCreateRequest": 1,
"OnAdminAfterCreateRequest": 1,
},
},
{
Name: "authorized as user",
Method: http.MethodPost,
Url: "/api/admins",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + empty data",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(``),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."},"password":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "authorized as admin + invalid data format",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + invalid data",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{
"email":"test@example.com",
"password":"1234",
"passwordConfirm":"4321",
"avatar":99
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"avatar":{"code":"validation_max_less_equal_than_required"`,
`"email":{"code":"validation_admin_email_exists"`,
`"password":{"code":"validation_length_out_of_range"`,
`"passwordConfirm":{"code":"validation_values_mismatch"`,
},
},
{
Name: "authorized as admin + valid data",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567890",
"passwordConfirm":"1234567890",
"avatar":3
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":`,
`"email":"testnew@example.com"`,
`"avatar":3`,
},
NotExpectedContent: []string{
`"password"`,
`"passwordConfirm"`,
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
"OnAdminBeforeCreateRequest": 1,
"OnAdminAfterCreateRequest": 1,
},
},
{
Name: "OnAdminAfterCreateRequest error response",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567890",
"passwordConfirm":"1234567890",
"avatar":3
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterCreateRequest().Add(func(e *core.AdminCreateEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
"OnAdminBeforeCreateRequest": 1,
"OnAdminAfterCreateRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminUpdate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + missing admin",
Method: http.MethodPatch,
Url: "/api/admins/missing",
Body: strings.NewReader(``),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + empty data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(``),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sbmbsdb40jyxf7h"`,
`"email":"test2@example.com"`,
`"avatar":2`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeUpdateRequest": 1,
"OnAdminAfterUpdateRequest": 1,
},
},
{
Name: "authorized as admin + invalid formatted data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + invalid data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{
"email":"test@example.com",
"password":"1234",
"passwordConfirm":"4321",
"avatar":99
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"avatar":{"code":"validation_max_less_equal_than_required"`,
`"email":{"code":"validation_admin_email_exists"`,
`"password":{"code":"validation_length_out_of_range"`,
`"passwordConfirm":{"code":"validation_values_mismatch"`,
},
},
{
Name: "authorized as admin + valid data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567891",
"passwordConfirm":"1234567891",
"avatar":5
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sbmbsdb40jyxf7h"`,
`"email":"testnew@example.com"`,
`"avatar":5`,
},
NotExpectedContent: []string{
`"password"`,
`"passwordConfirm"`,
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeUpdateRequest": 1,
"OnAdminAfterUpdateRequest": 1,
},
},
{
Name: "OnAdminAfterUpdateRequest error response",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567891",
"passwordConfirm":"1234567891",
"avatar":5
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterUpdateRequest().Add(func(e *core.AdminUpdateEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeUpdateRequest": 1,
"OnAdminAfterUpdateRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -1,132 +0,0 @@
package apis
import (
"net/http"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/inflector"
)
// ApiError defines the struct for a basic api error response.
type ApiError struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]any `json:"data"`
// stores unformatted error data (could be an internal error, text, etc.)
rawData any
}
// Error makes it compatible with the `error` interface.
func (e *ApiError) Error() string {
return e.Message
}
// RawData returns the unformatted error data (could be an internal error, text, etc.)
func (e *ApiError) RawData() any {
return e.rawData
}
// NewNotFoundError creates and returns 404 `ApiError`.
func NewNotFoundError(message string, data any) *ApiError {
if message == "" {
message = "The requested resource wasn't found."
}
return NewApiError(http.StatusNotFound, message, data)
}
// NewBadRequestError creates and returns 400 `ApiError`.
func NewBadRequestError(message string, data any) *ApiError {
if message == "" {
message = "Something went wrong while processing your request."
}
return NewApiError(http.StatusBadRequest, message, data)
}
// NewForbiddenError creates and returns 403 `ApiError`.
func NewForbiddenError(message string, data any) *ApiError {
if message == "" {
message = "You are not allowed to perform this request."
}
return NewApiError(http.StatusForbidden, message, data)
}
// NewUnauthorizedError creates and returns 401 `ApiError`.
func NewUnauthorizedError(message string, data any) *ApiError {
if message == "" {
message = "Missing or invalid authentication token."
}
return NewApiError(http.StatusUnauthorized, message, data)
}
// NewApiError creates and returns new normalized `ApiError` instance.
func NewApiError(status int, message string, data any) *ApiError {
return &ApiError{
rawData: data,
Data: safeErrorsData(data),
Code: status,
Message: strings.TrimSpace(inflector.Sentenize(message)),
}
}
func safeErrorsData(data any) map[string]any {
switch v := data.(type) {
case validation.Errors:
return resolveSafeErrorsData[error](v)
case map[string]validation.Error:
return resolveSafeErrorsData[validation.Error](v)
case map[string]error:
return resolveSafeErrorsData[error](v)
case map[string]any:
return resolveSafeErrorsData[any](v)
default:
return map[string]any{} // not nil to ensure that is json serialized as object
}
}
func resolveSafeErrorsData[T any](data map[string]T) map[string]any {
result := map[string]any{}
for name, err := range data {
if isNestedError(err) {
result[name] = safeErrorsData(err)
continue
}
result[name] = resolveSafeErrorItem(err)
}
return result
}
func isNestedError(err any) bool {
switch err.(type) {
case validation.Errors, map[string]validation.Error, map[string]error, map[string]any:
return true
}
return false
}
// resolveSafeErrorItem extracts from each validation error its
// public safe error code and message.
func resolveSafeErrorItem(err any) map[string]string {
// default public safe error values
code := "validation_invalid_value"
msg := "Invalid value."
// only validation errors are public safe
if obj, ok := err.(validation.Error); ok {
code = obj.Code()
msg = inflector.Sentenize(obj.Error())
}
return map[string]string{
"code": code,
"message": msg,
}
}

42
apis/api_error_aliases.go Normal file
View File

@ -0,0 +1,42 @@
package apis
import "github.com/pocketbase/pocketbase/tools/router"
// ApiError aliases to minimize the breaking changes with earlier versions
// and for consistency with the JSVM binds.
// -------------------------------------------------------------------
// NewApiError is an alias for [router.NewApiError].
func NewApiError(status int, message string, errData any) *router.ApiError {
return router.NewApiError(status, message, errData)
}
// NewBadRequestError is an alias for [router.NewBadRequestError].
func NewBadRequestError(message string, errData any) *router.ApiError {
return router.NewBadRequestError(message, errData)
}
// NewNotFoundError is an alias for [router.NewNotFoundError].
func NewNotFoundError(message string, errData any) *router.ApiError {
return router.NewNotFoundError(message, errData)
}
// NewForbiddenError is an alias for [router.NewForbiddenError].
func NewForbiddenError(message string, errData any) *router.ApiError {
return router.NewForbiddenError(message, errData)
}
// NewUnauthorizedError is an alias for [router.NewUnauthorizedError].
func NewUnauthorizedError(message string, errData any) *router.ApiError {
return router.NewUnauthorizedError(message, errData)
}
// NewTooManyRequestsError is an alias for [router.NewTooManyRequestsError].
func NewTooManyRequestsError(message string, errData any) *router.ApiError {
return router.NewTooManyRequestsError(message, errData)
}
// NewInternalServerError is an alias for [router.NewInternalServerError].
func NewInternalServerError(message string, errData any) *router.ApiError {
return router.NewInternalServerError(message, errData)
}

View File

@ -1,162 +0,0 @@
package apis_test
import (
"encoding/json"
"errors"
"testing"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/apis"
)
func TestNewApiErrorWithRawData(t *testing.T) {
t.Parallel()
e := apis.NewApiError(
300,
"message_test",
"rawData_test",
)
result, _ := json.Marshal(e)
expected := `{"code":300,"message":"Message_test.","data":{}}`
if string(result) != expected {
t.Errorf("Expected %v, got %v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() != "rawData_test" {
t.Errorf("Expected rawData %v, got %v", "rawData_test", e.RawData())
}
}
func TestNewApiErrorWithValidationData(t *testing.T) {
t.Parallel()
e := apis.NewApiError(
300,
"message_test",
validation.Errors{
"err1": errors.New("test error"), // should be normalized
"err2": validation.ErrRequired,
"err3": validation.Errors{
"sub1": errors.New("test error"), // should be normalized
"sub2": validation.ErrRequired,
"sub3": validation.Errors{
"sub11": validation.ErrRequired,
},
},
},
)
result, _ := json.Marshal(e)
expected := `{"code":300,"message":"Message_test.","data":{"err1":{"code":"validation_invalid_value","message":"Invalid value."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"sub1":{"code":"validation_invalid_value","message":"Invalid value."},"sub2":{"code":"validation_required","message":"Cannot be blank."},"sub3":{"sub11":{"code":"validation_required","message":"Cannot be blank."}}}}}`
if string(result) != expected {
t.Errorf("Expected \n%v, \ngot \n%v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() == nil {
t.Error("Expected non-nil rawData")
}
}
func TestNewNotFoundError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":404,"message":"The requested resource wasn't found.","data":{}}`},
{"demo", "rawData_test", `{"code":404,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":404,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewNotFoundError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}
func TestNewBadRequestError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
{"demo", "rawData_test", `{"code":400,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":400,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewBadRequestError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}
func TestNewForbiddenError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":403,"message":"You are not allowed to perform this request.","data":{}}`},
{"demo", "rawData_test", `{"code":403,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":403,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewForbiddenError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}
func TestNewUnauthorizedError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":401,"message":"Missing or invalid authentication token.","data":{}}`},
{"demo", "rawData_test", `{"code":401,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":401,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewUnauthorizedError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}

View File

@ -6,42 +6,37 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms" "github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/types" "github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast" "github.com/spf13/cast"
) )
// bindBackupApi registers the file api endpoints and the corresponding handlers. // bindBackupApi registers the file api endpoints and the corresponding handlers.
// func bindBackupApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
// @todo add hooks once the app hooks api restructuring is finalized sub := rg.Group("/backups")
func bindBackupApi(app core.App, rg *echo.Group) { sub.GET("", backupsList).Bind(RequireSuperuserAuth())
api := backupApi{app: app} sub.POST("", backupCreate).Bind(RequireSuperuserAuth())
sub.POST("/upload", backupUpload).Bind(RequireSuperuserAuthOnlyIfAny())
subGroup := rg.Group("/backups", ActivityLogger(app)) sub.GET("/{key}", backupDownload) // relies on superuser file token
subGroup.GET("", api.list, RequireAdminAuth()) sub.DELETE("/{key}", backupDelete).Bind(RequireSuperuserAuth())
subGroup.POST("", api.create, RequireAdminAuth()) sub.POST("/{key}/restore", backupRestore).Bind(RequireSuperuserAuthOnlyIfAny())
subGroup.POST("/upload", api.upload, RequireAdminAuth())
subGroup.GET("/:key", api.download)
subGroup.DELETE("/:key", api.delete, RequireAdminAuth())
subGroup.POST("/:key/restore", api.restore, RequireAdminAuth())
} }
type backupApi struct { type backupFileInfo struct {
app core.App Modified types.DateTime `json:"modified"`
Key string `json:"key"`
Size int64 `json:"size"`
} }
func (api *backupApi) list(c echo.Context) error { func backupsList(e *core.RequestEvent) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
fsys, err := api.app.NewBackupsFilesystem() fsys, err := e.App.NewBackupsFilesystem()
if err != nil { if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err) return e.BadRequestError("Failed to load backups filesystem.", err)
} }
defer fsys.Close() defer fsys.Close()
@ -49,166 +44,112 @@ func (api *backupApi) list(c echo.Context) error {
backups, err := fsys.List("") backups, err := fsys.List("")
if err != nil { if err != nil {
return NewBadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil) return e.BadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
} }
result := make([]models.BackupFileInfo, len(backups)) result := make([]backupFileInfo, len(backups))
for i, obj := range backups { for i, obj := range backups {
modified, _ := types.ParseDateTime(obj.ModTime) modified, _ := types.ParseDateTime(obj.ModTime)
result[i] = models.BackupFileInfo{ result[i] = backupFileInfo{
Key: obj.Key, Key: obj.Key,
Size: obj.Size, Size: obj.Size,
Modified: modified, Modified: modified,
} }
} }
return c.JSON(http.StatusOK, result) return e.JSON(http.StatusOK, result)
} }
func (api *backupApi) create(c echo.Context) error { func backupDownload(e *core.RequestEvent) error {
if api.app.Store().Has(core.StoreKeyActiveBackup) { fileToken := e.Request.URL.Query().Get("token")
return NewBadRequestError("Try again later - another backup/restore process has already been started", nil)
}
form := forms.NewBackupCreate(api.app) authRecord, err := e.App.FindAuthRecordByToken(fileToken, core.TokenTypeFile)
if err := c.Bind(form); err != nil { if err != nil || !authRecord.IsSuperuser() {
return NewBadRequestError("An error occurred while loading the submitted data.", err) return e.ForbiddenError("Insufficient permissions to access the resource.", err)
}
return form.Submit(func(next forms.InterceptorNextFunc[string]) forms.InterceptorNextFunc[string] {
return func(name string) error {
if err := next(name); err != nil {
return NewBadRequestError("Failed to create backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return c.NoContent(http.StatusNoContent)
}
})
}
func (api *backupApi) upload(c echo.Context) error {
files, err := rest.FindUploadedFiles(c.Request(), "file")
if err != nil {
return NewBadRequestError("Missing or invalid uploaded file.", err)
}
form := forms.NewBackupUpload(api.app)
form.File = files[0]
return form.Submit(func(next forms.InterceptorNextFunc[*filesystem.File]) forms.InterceptorNextFunc[*filesystem.File] {
return func(file *filesystem.File) error {
if err := next(file); err != nil {
return NewBadRequestError("Failed to upload backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return c.NoContent(http.StatusNoContent)
}
})
}
func (api *backupApi) download(c echo.Context) error {
fileToken := c.QueryParam("token")
_, err := api.app.Dao().FindAdminByToken(
fileToken,
api.app.Settings().AdminFileToken.Secret,
)
if err != nil {
return NewForbiddenError("Insufficient permissions to access the resource.", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel() defer cancel()
fsys, err := api.app.NewBackupsFilesystem() fsys, err := e.App.NewBackupsFilesystem()
if err != nil { if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err) return e.InternalServerError("Failed to load backups filesystem.", err)
} }
defer fsys.Close() defer fsys.Close()
fsys.SetContext(ctx) fsys.SetContext(ctx)
key := c.PathParam("key") key := e.Request.PathValue("key")
br, err := fsys.GetFile(key)
if err != nil {
return NewBadRequestError("Failed to retrieve backup item. Raw error: \n"+err.Error(), nil)
}
defer br.Close()
return fsys.Serve( return fsys.Serve(
c.Response(), e.Response,
c.Request(), e.Request,
key, key,
filepath.Base(key), // without the path prefix (if any) filepath.Base(key), // without the path prefix (if any)
) )
} }
func (api *backupApi) restore(c echo.Context) error { func backupDelete(e *core.RequestEvent) error {
if api.app.Store().Has(core.StoreKeyActiveBackup) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
return NewBadRequestError("Try again later - another backup/restore process has already been started.", nil) defer cancel()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return e.InternalServerError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(ctx)
key := e.Request.PathValue("key")
if key != "" && cast.ToString(e.App.Store().Get(core.StoreKeyActiveBackup)) == key {
return e.BadRequestError("The backup is currently being used and cannot be deleted.", nil)
} }
key := c.PathParam("key") if err := fsys.Delete(key); err != nil {
return e.BadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
}
return e.NoContent(http.StatusNoContent)
}
func backupRestore(e *core.RequestEvent) error {
if e.App.Store().Has(core.StoreKeyActiveBackup) {
return e.BadRequestError("Try again later - another backup/restore process has already been started.", nil)
}
key := e.Request.PathValue("key")
existsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) existsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
fsys, err := api.app.NewBackupsFilesystem() fsys, err := e.App.NewBackupsFilesystem()
if err != nil { if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err) return e.InternalServerError("Failed to load backups filesystem.", err)
} }
defer fsys.Close() defer fsys.Close()
fsys.SetContext(existsCtx) fsys.SetContext(existsCtx)
if exists, err := fsys.Exists(key); !exists { if exists, err := fsys.Exists(key); !exists {
return NewBadRequestError("Missing or invalid backup file.", err) return e.BadRequestError("Missing or invalid backup file.", err)
} }
go func() { routine.FireAndForget(func() {
// wait max 15 minutes to fetch the backup // give some optimistic time to write the response before restarting the app
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
// give some optimistic time to write the response
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if err := api.app.RestoreBackup(ctx, key); err != nil { // wait max 10 minutes to fetch the backup
api.app.Logger().Error("Failed to restore backup", "key", key, "error", err.Error()) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
if err := e.App.RestoreBackup(ctx, key); err != nil {
e.App.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
} }
}() })
return c.NoContent(http.StatusNoContent) return e.NoContent(http.StatusNoContent)
}
func (api *backupApi) delete(c echo.Context) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fsys, err := api.app.NewBackupsFilesystem()
if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(ctx)
key := c.PathParam("key")
if key != "" && cast.ToString(api.app.Store().Get(core.StoreKeyActiveBackup)) == key {
return NewBadRequestError("The backup is currently being used and cannot be deleted.", nil)
}
if err := fsys.Delete(key); err != nil {
return NewBadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
}
return c.NoContent(http.StatusNoContent)
} }

78
apis/backup_create.go Normal file
View File

@ -0,0 +1,78 @@
package apis
import (
"context"
"net/http"
"regexp"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func backupCreate(e *core.RequestEvent) error {
if e.App.Store().Has(core.StoreKeyActiveBackup) {
return e.BadRequestError("Try again later - another backup/restore process has already been started", nil)
}
form := new(backupCreateForm)
form.app = e.App
err := e.BindBody(form)
if err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
err = form.validate()
if err != nil {
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
err = e.App.CreateBackup(context.Background(), form.Name)
if err != nil {
return e.BadRequestError("Failed to create backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return e.NoContent(http.StatusNoContent)
}
// -------------------------------------------------------------------
var backupNameRegex = regexp.MustCompile(`^[a-z0-9_-]+\.zip$`)
type backupCreateForm struct {
app core.App
Name string `form:"name" json:"name"`
}
func (form *backupCreateForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(
&form.Name,
validation.Length(1, 150),
validation.Match(backupNameRegex),
validation.By(form.checkUniqueName),
),
)
}
func (form *backupCreateForm) checkUniqueName(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
fsys, err := form.app.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
if exists, err := fsys.Exists(v); err != nil || exists {
return validation.NewError("validation_backup_name_exists", "The backup file name is invalid or already exists.")
}
return nil
}

View File

@ -10,7 +10,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
"gocloud.dev/blob" "gocloud.dev/blob"
@ -23,50 +22,51 @@ func TestBackupsList(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups", URL: "/api/backups",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups", URL: "/api/backups",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (empty list)", Name: "authorized as superuser (empty list)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups", URL: "/api/backups",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`[]`,
}, },
ExpectedStatus: 200,
ExpectedContent: []string{`[]`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin", Name: "authorized as superuser",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups", URL: "/api/backups",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -77,6 +77,7 @@ func TestBackupsList(t *testing.T) {
`"test2.zip"`, `"test2.zip"`,
`"test3.zip"`, `"test3.zip"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -92,50 +93,53 @@ func TestBackupsCreate(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups", URL: "/api/backups",
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app) ensureNoBackups(t, app)
}, },
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups", URL: "/api/backups",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app) ensureNoBackups(t, app)
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (pending backup)", Name: "authorized as superuser (pending backup)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups", URL: "/api/backups",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Store().Set(core.StoreKeyActiveBackup, "") app.Store().Set(core.StoreKeyActiveBackup, "")
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app) ensureNoBackups(t, app)
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (autogenerated name)", Name: "authorized as superuser (autogenerated name)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups", URL: "/api/backups",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app) files, err := getBackupFiles(app)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -151,16 +155,20 @@ func TestBackupsCreate(t *testing.T) {
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnBackupCreate": 1,
},
}, },
{ {
Name: "authorized as admin (invalid name)", Name: "authorized as superuser (invalid name)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups", URL: "/api/backups",
Body: strings.NewReader(`{"name":"!test.zip"}`), Body: strings.NewReader(`{"name":"!test.zip"}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app) ensureNoBackups(t, app)
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
@ -168,16 +176,17 @@ func TestBackupsCreate(t *testing.T) {
`"data":{`, `"data":{`,
`"name":{"code":"validation_match_invalid"`, `"name":{"code":"validation_match_invalid"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (valid name)", Name: "authorized as superuser (valid name)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups", URL: "/api/backups",
Body: strings.NewReader(`{"name":"test.zip"}`), Body: strings.NewReader(`{"name":"test.zip"}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app) files, err := getBackupFiles(app)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -193,6 +202,10 @@ func TestBackupsCreate(t *testing.T) {
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnBackupCreate": 1,
},
}, },
} }
@ -201,7 +214,7 @@ func TestBackupsCreate(t *testing.T) {
} }
} }
func TestBackupsUpload(t *testing.T) { func TestBackupUpload(t *testing.T) {
t.Parallel() t.Parallel()
// create dummy form data bodies // create dummy form data bodies
@ -243,55 +256,58 @@ func TestBackupsUpload(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/upload", URL: "/api/backups/upload",
Body: bodies[0].buffer, Body: bodies[0].buffer,
RequestHeaders: map[string]string{ Headers: map[string]string{
"Content-Type": bodies[0].contentType, "Content-Type": bodies[0].contentType,
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app) ensureNoBackups(t, app)
}, },
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/upload", URL: "/api/backups/upload",
Body: bodies[1].buffer, Body: bodies[1].buffer,
RequestHeaders: map[string]string{ Headers: map[string]string{
"Content-Type": bodies[1].contentType, "Content-Type": bodies[1].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app) ensureNoBackups(t, app)
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (missing file)", Name: "authorized as superuser (missing file)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/upload", URL: "/api/backups/upload",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app) ensureNoBackups(t, app)
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{`}, ExpectedContent: []string{`"data":{`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (existing backup name)", Name: "authorized as superuser (existing backup name)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/upload", URL: "/api/backups/upload",
Body: bodies[3].buffer, Body: bodies[3].buffer,
RequestHeaders: map[string]string{ Headers: map[string]string{
"Content-Type": bodies[3].contentType, "Content-Type": bodies[3].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
fsys, err := app.NewBackupsFilesystem() fsys, err := app.NewBackupsFilesystem()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -302,7 +318,7 @@ func TestBackupsUpload(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app) files, _ := getBackupFiles(app)
if total := len(files); total != 1 { if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total) t.Fatalf("Expected %d backup file, got %d", 1, total)
@ -310,23 +326,49 @@ func TestBackupsUpload(t *testing.T) {
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"file":{`}, ExpectedContent: []string{`"data":{"file":{`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (valid file)", Name: "authorized as superuser (valid file)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/upload", URL: "/api/backups/upload",
Body: bodies[4].buffer, Body: bodies[4].buffer,
RequestHeaders: map[string]string{ Headers: map[string]string{
"Content-Type": bodies[4].contentType, "Content-Type": bodies[4].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app) files, _ := getBackupFiles(app)
if total := len(files); total != 1 { if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total) t.Fatalf("Expected %d backup file, got %d", 1, total)
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unauthorized with 0 superusers (valid file)",
Method: http.MethodPost,
URL: "/api/backups/upload",
Body: bodies[5].buffer,
Headers: map[string]string{
"Content-Type": bodies[5].contentType,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all superusers
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
if err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app)
if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -342,148 +384,159 @@ func TestBackupsDownload(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip", URL: "/api/backups/test1.zip",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with record auth header", Name: "with record auth header",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip", URL: "/api/backups/test1.zip",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with admin auth header", Name: "with superuser auth header",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip", URL: "/api/backups/test1.zip",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with empty or invalid token", Name: "with empty or invalid token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=", URL: "/api/backups/test1.zip?token=",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with valid record auth token", Name: "with valid record auth token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with valid record file token", Name: "with valid record file token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk", URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with valid admin auth token", Name: "with valid superuser auth token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with expired admin file token", Name: "with expired superuser file token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImFkbWluIn0.g7Q_3UX6H--JWJ7yt1Hoe-1ugTX1KpbKzdt0zjGSe-E", URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.hTNDzikwJdcoWrLnRnp7xbaifZ2vuYZ0oOYRHtJfnk4",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with valid admin file token but missing backup name", Name: "with valid superuser file token but missing backup name",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU", URL: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with valid admin file token", Name: "with valid superuser file token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU", URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
`storage/`, "storage/",
`data.db`, "data.db",
`logs.db`, "aux.db",
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "with valid admin file token and backup name with escaped char", Name: "with valid superuser file token and backup name with escaped char",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU", URL: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
`storage/`, "storage/",
`data.db`, "data.db",
`logs.db`, "aux.db",
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -495,7 +548,7 @@ func TestBackupsDownload(t *testing.T) {
func TestBackupsDelete(t *testing.T) { func TestBackupsDelete(t *testing.T) {
t.Parallel() t.Parallel()
noTestBackupFilesChanges := func(t *testing.T, app *tests.TestApp) { noTestBackupFilesChanges := func(t testing.TB, app *tests.TestApp) {
files, err := getBackupFiles(app) files, err := getBackupFiles(app)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -511,62 +564,65 @@ func TestBackupsDelete(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodDelete, Method: http.MethodDelete,
Url: "/api/backups/test1.zip", URL: "/api/backups/test1.zip",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app) noTestBackupFilesChanges(t, app)
}, },
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodDelete, Method: http.MethodDelete,
Url: "/api/backups/test1.zip", URL: "/api/backups/test1.zip",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app) noTestBackupFilesChanges(t, app)
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (missing file)", Name: "authorized as superuser (missing file)",
Method: http.MethodDelete, Method: http.MethodDelete,
Url: "/api/backups/missing.zip", URL: "/api/backups/missing.zip",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app) noTestBackupFilesChanges(t, app)
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (existing file with matching active backup)", Name: "authorized as superuser (existing file with matching active backup)",
Method: http.MethodDelete, Method: http.MethodDelete,
Url: "/api/backups/test1.zip", URL: "/api/backups/test1.zip",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -574,20 +630,21 @@ func TestBackupsDelete(t *testing.T) {
// mock active backup with the same name to delete // mock active backup with the same name to delete
app.Store().Set(core.StoreKeyActiveBackup, "test1.zip") app.Store().Set(core.StoreKeyActiveBackup, "test1.zip")
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app) noTestBackupFilesChanges(t, app)
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (existing file and no matching active backup)", Name: "authorized as superuser (existing file and no matching active backup)",
Method: http.MethodDelete, Method: http.MethodDelete,
Url: "/api/backups/test1.zip", URL: "/api/backups/test1.zip",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -595,7 +652,7 @@ func TestBackupsDelete(t *testing.T) {
// mock active backup with different name // mock active backup with different name
app.Store().Set(core.StoreKeyActiveBackup, "new.zip") app.Store().Set(core.StoreKeyActiveBackup, "new.zip")
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app) files, err := getBackupFiles(app)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -614,20 +671,21 @@ func TestBackupsDelete(t *testing.T) {
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (backup with escaped character)", Name: "authorized as superuser (backup with escaped character)",
Method: http.MethodDelete, Method: http.MethodDelete,
Url: "/api/backups/%40test4.zip", URL: "/api/backups/%40test4.zip",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app) files, err := getBackupFiles(app)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -646,6 +704,7 @@ func TestBackupsDelete(t *testing.T) {
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -661,53 +720,56 @@ func TestBackupsRestore(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/test1.zip/restore", URL: "/api/backups/test1.zip/restore",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/test1.zip/restore", URL: "/api/backups/test1.zip/restore",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (missing file)", Name: "authorized as superuser (missing file)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/missing.zip/restore", URL: "/api/backups/missing.zip/restore",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (active backup process)", Name: "authorized as superuser (active backup process)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/backups/test1.zip/restore", URL: "/api/backups/test1.zip/restore",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil { if err := createTestBackups(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -716,6 +778,26 @@ func TestBackupsRestore(t *testing.T) {
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unauthorized with no superusers (checks only access)",
Method: http.MethodPost,
URL: "/api/backups/missing.zip/restore",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all superusers
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
if err != nil {
t.Fatal(err)
}
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -758,7 +840,7 @@ func getBackupFiles(app core.App) ([]*blob.ListObject, error) {
return fsys.List("") return fsys.List("")
} }
func ensureNoBackups(t *testing.T, app *tests.TestApp) { func ensureNoBackups(t testing.TB, app *tests.TestApp) {
files, err := getBackupFiles(app) files, err := getBackupFiles(app)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

72
apis/backup_upload.go Normal file
View File

@ -0,0 +1,72 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/filesystem"
)
func backupUpload(e *core.RequestEvent) error {
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
form := new(backupUploadForm)
form.fsys = fsys
files, _ := FindUploadedFiles(e.Request, "file")
if len(files) > 0 {
form.File = files[0]
}
err = form.validate()
if err != nil {
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
err = fsys.UploadFile(form.File, form.File.OriginalName)
if err != nil {
return e.BadRequestError("Failed to upload backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return e.NoContent(http.StatusNoContent)
}
// -------------------------------------------------------------------
type backupUploadForm struct {
fsys *filesystem.System
File *filesystem.File `json:"file"`
}
func (form *backupUploadForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(
&form.File,
validation.Required,
validation.By(validators.UploadedFileMimeType([]string{"application/zip"})),
validation.By(form.checkUniqueName),
),
)
}
func (form *backupUploadForm) checkUniqueName(value any) error {
v, _ := value.(*filesystem.File)
if v == nil {
return nil // nothing to check
}
// note: we use the original name because that is what we upload
if exists, err := form.fsys.Exists(v.OriginalName); err != nil || exists {
return validation.NewError("validation_backup_name_exists", "Backup file with the specified name already exists.")
}
return nil
}

View File

@ -1,266 +1,202 @@
// Package apis implements the default PocketBase api services and middlewares.
package apis package apis
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"log/slog"
"net/http" "net/http"
"net/url"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/rest" "github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/ui" "github.com/pocketbase/pocketbase/tools/hook"
"github.com/spf13/cast" "github.com/pocketbase/pocketbase/tools/router"
) )
const trailedAdminPath = "/_/" // StaticWildcardParam is the name of Static handler wildcard parameter.
const StaticWildcardParam = "path"
// InitApi creates a configured echo instance with registered // NewRouter returns a new router instance loaded with the default app middlewares and api routes.
// system and app specific routes and middlewares. func NewRouter(app core.App) (*router.Router[*core.RequestEvent], error) {
func InitApi(app core.App) (*echo.Echo, error) { pbRouter := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*core.RequestEvent, router.EventCleanupFunc) {
e := echo.New() event := new(core.RequestEvent)
e.Debug = false event.Response = w
e.Binder = &rest.MultiBinder{} event.Request = r
e.JSONSerializer = &rest.Serializer{ event.App = app
FieldsParam: fieldsQueryParam,
}
// configure a custom router return event, nil
e.ResetRouterCreator(func(ec *echo.Echo) echo.Router {
return echo.NewRouter(echo.RouterConfig{
UnescapePathParamValues: true,
AllowOverwritingRoute: true,
})
}) })
// default middlewares // register default middlewares
e.Pre(middleware.RemoveTrailingSlashWithConfig(middleware.RemoveTrailingSlashConfig{ pbRouter.Bind(activityLogger())
Skipper: func(c echo.Context) bool { pbRouter.Bind(loadAuthToken())
// enable by default only for the API routes pbRouter.Bind(securityHeaders())
return !strings.HasPrefix(c.Request().URL.Path, "/api/") pbRouter.Bind(rateLimit())
}, pbRouter.Bind(BodyLimit(DefaultMaxBodySize))
}))
e.Pre(LoadAuthContext(app))
e.Use(middleware.Recover())
e.Use(middleware.Secure())
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set(ContextExecStartKey, time.Now())
return next(c) apiGroup := pbRouter.Group("/api")
} bindSettingsApi(app, apiGroup)
}) bindCollectionApi(app, apiGroup)
bindRecordCrudApi(app, apiGroup)
bindRecordAuthApi(app, apiGroup)
bindLogsApi(app, apiGroup)
bindBackupApi(app, apiGroup)
bindFileApi(app, apiGroup)
bindBatchApi(app, apiGroup)
bindRealtimeApi(app, apiGroup)
bindHealthApi(app, apiGroup)
// custom error handler return pbRouter, nil
e.HTTPErrorHandler = func(c echo.Context, err error) {
if err == nil {
return // no error
}
var apiErr *ApiError
if errors.As(err, &apiErr) {
// already an api error...
} else if v := new(echo.HTTPError); errors.As(err, &v) {
msg := fmt.Sprintf("%v", v.Message)
apiErr = NewApiError(v.Code, msg, v)
} else {
if errors.Is(err, sql.ErrNoRows) {
apiErr = NewNotFoundError("", err)
} else {
apiErr = NewBadRequestError("", err)
}
}
logRequest(app, c, apiErr)
if c.Response().Committed {
return // already committed
}
event := new(core.ApiErrorEvent)
event.HttpContext = c
event.Error = apiErr
// send error response
hookErr := app.OnBeforeApiError().Trigger(event, func(e *core.ApiErrorEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
// @see https://github.com/labstack/echo/issues/608
if e.HttpContext.Request().Method == http.MethodHead {
return e.HttpContext.NoContent(apiErr.Code)
}
return e.HttpContext.JSON(apiErr.Code, apiErr)
})
if hookErr == nil {
if err := app.OnAfterApiError().Trigger(event); err != nil {
app.Logger().Debug("OnAfterApiError failure", slog.String("error", err.Error()))
}
} else {
app.Logger().Debug("OnBeforeApiError error (truly rare case, eg. client already disconnected)", slog.String("error", hookErr.Error()))
}
}
// admin ui routes
bindStaticAdminUI(app, e)
// default routes
api := e.Group("/api", eagerRequestInfoCache(app))
bindSettingsApi(app, api)
bindAdminApi(app, api)
bindCollectionApi(app, api)
bindRecordCrudApi(app, api)
bindRecordAuthApi(app, api)
bindFileApi(app, api)
bindRealtimeApi(app, api)
bindLogsApi(app, api)
bindHealthApi(app, api)
bindBackupApi(app, api)
// catch all any route
api.Any("/*", func(c echo.Context) error {
return echo.ErrNotFound
}, ActivityLogger(app))
return e, nil
} }
// StaticDirectoryHandler is similar to `echo.StaticDirectoryHandler` // WrapStdHandler wraps Go [http.Handler] into a PocketBase handler func.
// but without the directory redirect which conflicts with RemoveTrailingSlash middleware. func WrapStdHandler(h http.Handler) hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) error {
h.ServeHTTP(e.Response, e.Request)
return nil
}
}
// WrapStdMiddleware wraps Go [func(http.Handler) http.Handle] into a PocketBase middleware func.
func WrapStdMiddleware(m func(http.Handler) http.Handler) hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) (err error) {
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
e.Response = w
e.Request = r
err = e.Next()
})).ServeHTTP(e.Response, e.Request)
return err
}
}
// MustSubFS returns an [fs.FS] corresponding to the subtree rooted at fsys's dir.
//
// This is similar to [fs.Sub] but panics on failure.
func MustSubFS(fsys fs.FS, dir string) fs.FS {
dir = filepath.ToSlash(filepath.Clean(dir)) // ToSlash in case of Windows path
sub, err := fs.Sub(fsys, dir)
if err != nil {
panic(fmt.Errorf("failed to create sub FS: %w", err))
}
return sub
}
// Static is a handler function to serve static directory content from fsys.
// //
// If a file resource is missing and indexFallback is set, the request // If a file resource is missing and indexFallback is set, the request
// will be forwarded to the base index.html (useful also for SPA). // will be forwarded to the base index.html (useful for SPA with pretty urls).
// //
// @see https://github.com/labstack/echo/issues/2211 // NB! Expects the route to have a "{path...}" wildcard parameter.
func StaticDirectoryHandler(fileSystem fs.FS, indexFallback bool) echo.HandlerFunc { //
return func(c echo.Context) error { // Special redirects:
p := c.PathParam("*") // - if "path" is a file that ends in index.html, it is redirected to its non-index.html version (eg. /test/index.html -> /test/)
// - if "path" is a directory that has index.html, the index.html file is rendered,
// otherwise if missing - returns 404 or fallback to the root index.html if indexFallback is set
//
// Example:
//
// fsys := os.DirFS("./pb_public")
// router.GET("/files/{path...}", apis.Static(fsys, false))
func Static(fsys fs.FS, indexFallback bool) hook.HandlerFunc[*core.RequestEvent] {
if fsys == nil {
panic("Static: the provided fs.FS argument is nil")
}
// escape url path return func(e *core.RequestEvent) error {
tmpPath, err := url.PathUnescape(p) // disable the activity logger to avoid flooding with messages
if err != nil { //
return fmt.Errorf("failed to unescape path variable: %w", err) // note: errors are still logged
if e.Get(requestEventKeySkipSuccessActivityLog) == nil {
e.Set(requestEventKeySkipSuccessActivityLog, true)
} }
p = tmpPath
// fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid filename := e.Request.PathValue(StaticWildcardParam)
name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) filename = filepath.ToSlash(filepath.Clean(strings.TrimPrefix(filename, "/")))
fileErr := c.FileFS(name, fileSystem) // eagerly check for directory traversal
//
// note: this is just out of an abundance of caution because the fs.FS implementation could be non-std,
// but usually shouldn't be necessary since os.DirFS.Open is expected to fail if the filename starts with dots
if len(filename) > 2 && filename[0] == '.' && filename[1] == '.' && (filename[2] == '/' || filename[2] == '\\') {
if indexFallback && filename != router.IndexPage {
return e.FileFS(fsys, router.IndexPage)
}
return router.ErrFileNotFound
}
if fileErr != nil && indexFallback && errors.Is(fileErr, echo.ErrNotFound) { fi, err := fs.Stat(fsys, filename)
return c.FileFS("index.html", fileSystem) if err != nil {
if indexFallback && filename != router.IndexPage {
return e.FileFS(fsys, router.IndexPage)
}
return router.ErrFileNotFound
}
if fi.IsDir() {
// redirect to a canonical dir url, aka. with trailing slash
if !strings.HasSuffix(e.Request.URL.Path, "/") {
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(e.Request.URL.Path+"/"))
}
} else {
urlPath := e.Request.URL.Path
if strings.HasSuffix(urlPath, "/") {
// redirect to a non-trailing slash file route
urlPath = strings.TrimRight(urlPath, "/")
if len(urlPath) > 0 {
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(urlPath))
}
} else if stripped, ok := strings.CutSuffix(urlPath, router.IndexPage); ok {
// redirect without the index.html
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(stripped))
}
}
fileErr := e.FileFS(fsys, filename)
if fileErr != nil && indexFallback && filename != router.IndexPage && errors.Is(fileErr, router.ErrFileNotFound) {
return e.FileFS(fsys, router.IndexPage)
} }
return fileErr return fileErr
} }
} }
// bindStaticAdminUI registers the endpoints that serves the static admin UI. // safeRedirectPath normalizes the path string by replacing all beginning slashes
func bindStaticAdminUI(app core.App, e *echo.Echo) error { // (`\\`, `//`, `\/`) with a single forward slash to prevent open redirect attacks
// redirect to trailing slash to ensure that relative urls will still work properly func safeRedirectPath(path string) string {
e.GET( if len(path) > 1 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
strings.TrimRight(trailedAdminPath, "/"), path = "/" + strings.TrimLeft(path, `/\`)
func(c echo.Context) error { }
return c.Redirect(http.StatusTemporaryRedirect, strings.TrimLeft(trailedAdminPath, "/")) return path
},
)
// serves static files from the /ui/dist directory
// (similar to echo.StaticFS but with gzip middleware enabled)
e.GET(
trailedAdminPath+"*",
echo.StaticDirectoryHandler(ui.DistDirFS, false),
installerRedirect(app),
uiCacheControl(),
middleware.Gzip(),
)
return nil
} }
func uiCacheControl() echo.MiddlewareFunc { // FindUploadedFiles extracts all form files of "key" from a http request
return func(next echo.HandlerFunc) echo.HandlerFunc { // and returns a slice with filesystem.File instances (if any).
return func(c echo.Context) error { func FindUploadedFiles(r *http.Request, key string) ([]*filesystem.File, error) {
// add default Cache-Control header for all Admin UI resources if r.MultipartForm == nil {
// (ignoring the root admin path) err := r.ParseMultipartForm(router.DefaultMaxMemory)
if c.Request().URL.Path != trailedAdminPath { if err != nil {
c.Response().Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400") return nil, err
}
return next(c)
} }
} }
}
const hasAdminsCacheKey = "@hasAdmins" if r.MultipartForm == nil || r.MultipartForm.File == nil || len(r.MultipartForm.File[key]) == 0 {
return nil, http.ErrMissingFile
func updateHasAdminsCache(app core.App) error {
total, err := app.Dao().TotalAdmins()
if err != nil {
return err
} }
app.Store().Set(hasAdminsCacheKey, total > 0) result := make([]*filesystem.File, 0, len(r.MultipartForm.File[key]))
return nil for _, fh := range r.MultipartForm.File[key] {
} file, err := filesystem.NewFileFromMultipart(fh)
if err != nil {
// installerRedirect redirects the user to the installer admin UI page return nil, err
// when the application needs some preliminary configurations to be done.
func installerRedirect(app core.App) echo.MiddlewareFunc {
// keep hasAdminsCacheKey value up-to-date
app.OnAdminAfterCreateRequest().Add(func(data *core.AdminCreateEvent) error {
return updateHasAdminsCache(app)
})
app.OnAdminAfterDeleteRequest().Add(func(data *core.AdminDeleteEvent) error {
return updateHasAdminsCache(app)
})
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// skip redirect checks for non-root level index.html requests
path := c.Request().URL.Path
if path != trailedAdminPath && path != trailedAdminPath+"index.html" {
return next(c)
}
hasAdmins := cast.ToBool(app.Store().Get(hasAdminsCacheKey))
if !hasAdmins {
// update the cache to make sure that the admin wasn't created by another process
if err := updateHasAdminsCache(app); err != nil {
return err
}
hasAdmins = cast.ToBool(app.Store().Get(hasAdminsCacheKey))
}
_, hasInstallerParam := c.Request().URL.Query()["installer"]
if !hasAdmins && !hasInstallerParam {
// redirect to the installer page
return c.Redirect(http.StatusTemporaryRedirect, "?installer#")
}
if hasAdmins && hasInstallerParam {
// clear the installer param
return c.Redirect(http.StatusTemporaryRedirect, "?")
}
return next(c)
} }
result = append(result, file)
} }
return result, nil
} }

View File

@ -1,422 +1,386 @@
package apis_test package apis_test
import ( import (
"database/sql" "bytes"
"errors"
"fmt" "fmt"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"strings" "strings"
"testing" "testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/rest" "github.com/pocketbase/pocketbase/tools/router"
"github.com/spf13/cast"
) )
func Test404(t *testing.T) { func TestWrapStdHandler(t *testing.T) {
t.Parallel() t.Parallel()
scenarios := []tests.ApiScenario{ app, _ := tests.NewTestApp()
{ defer app.Cleanup()
Method: http.MethodGet,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodPost,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodPatch,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodDelete,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodHead,
Url: "/api/missing",
ExpectedStatus: 404,
},
}
for _, scenario := range scenarios { req := httptest.NewRequest(http.MethodGet, "/", nil)
scenario.Test(t) rec := httptest.NewRecorder()
}
}
func TestCustomRoutesAndErrorsHandling(t *testing.T) { e := new(core.RequestEvent)
t.Parallel() e.App = app
e.Request = req
e.Response = rec
scenarios := []tests.ApiScenario{ err := apis.WrapStdHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
{ w.Write([]byte("test"))
Name: "custom route", }))(e)
Method: http.MethodGet,
Url: "/custom",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "custom route with url encoded parameter",
Method: http.MethodGet,
Url: "/a%2Bb%2Bc",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/:param",
Handler: func(c echo.Context) error {
return c.String(200, c.PathParam("param"))
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"a+b+c"},
},
{
Name: "route with HTTPError",
Method: http.MethodGet,
Url: "/http-error",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/http-error",
Handler: func(c echo.Context) error {
return echo.ErrBadRequest
},
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`{"code":400,"message":"Bad Request.","data":{}}`},
},
{
Name: "route with api error",
Method: http.MethodGet,
Url: "/api-error",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/api-error",
Handler: func(c echo.Context) error {
return apis.NewApiError(500, "test message", errors.New("internal_test"))
},
})
},
ExpectedStatus: 500,
ExpectedContent: []string{`{"code":500,"message":"Test message.","data":{}}`},
},
{
Name: "route with plain error",
Method: http.MethodGet,
Url: "/plain-error",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/plain-error",
Handler: func(c echo.Context) error {
return errors.New("Test error")
},
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRemoveTrailingSlashMiddleware(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "non /api/* route (exact match)",
Method: http.MethodGet,
Url: "/custom",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "non /api/* route (with trailing slash)",
Method: http.MethodGet,
Url: "/custom/",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "/api/* route (exact match)",
Method: http.MethodGet,
Url: "/api/custom",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/api/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "/api/* route (with trailing slash)",
Method: http.MethodGet,
Url: "/api/custom/",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/api/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestMultiBinder(t *testing.T) {
t.Parallel()
rawJson := `{"name":"test123"}`
formData, mp, err := tests.MockMultipartData(map[string]string{
rest.MultipartJsonKey: rawJson,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
scenarios := []tests.ApiScenario{ if body := rec.Body.String(); body != "test" {
{ t.Fatalf("Expected body %q, got %q", "test", body)
Name: "non-api group route",
Method: "POST",
Url: "/custom",
Body: strings.NewReader(rawJson),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: "POST",
Path: "/custom",
Handler: func(c echo.Context) error {
data := &struct {
Name string `json:"name"`
}{}
if err := c.Bind(data); err != nil {
return err
}
// try to read the body again
r := apis.RequestInfo(c)
if v := cast.ToString(r.Data["name"]); v != "test123" {
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
}
return c.NoContent(200)
},
})
},
ExpectedStatus: 200,
},
{
Name: "api group route",
Method: "GET",
Url: "/api/admins",
Body: strings.NewReader(rawJson),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// it is not important whether the route handler return an error since
// we just need to ensure that the eagerRequestInfoCache was registered
next(c)
// ensure that the body was read at least once
data := &struct {
Name string `json:"name"`
}{}
c.Bind(data)
// try to read the body again
r := apis.RequestInfo(c)
if v := cast.ToString(r.Data["name"]); v != "test123" {
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
}
return nil
}
})
},
ExpectedStatus: 200,
},
{
Name: "custom route with @jsonPayload as multipart body",
Method: "POST",
Url: "/custom",
Body: formData,
RequestHeaders: map[string]string{
"Content-Type": mp.FormDataContentType(),
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: "POST",
Path: "/custom",
Handler: func(c echo.Context) error {
data := &struct {
Name string `json:"name"`
}{}
if err := c.Bind(data); err != nil {
return err
}
// try to read the body again
r := apis.RequestInfo(c)
if v := cast.ToString(r.Data["name"]); v != "test123" {
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
}
return c.NoContent(200)
},
})
},
ExpectedStatus: 200,
},
}
for _, scenario := range scenarios {
scenario.Test(t)
} }
} }
func TestErrorHandler(t *testing.T) { func TestWrapStdMiddleware(t *testing.T) {
t.Parallel() t.Parallel()
scenarios := []tests.ApiScenario{ app, _ := tests.NewTestApp()
defer app.Cleanup()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
err := apis.WrapStdMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
})
})(e)
if err != nil {
t.Fatal(err)
}
if body := rec.Body.String(); body != "test" {
t.Fatalf("Expected body %q, got %q", "test", body)
}
}
func TestStatic(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
dir := createTestDir(t)
defer os.RemoveAll(dir)
fsys := os.DirFS(filepath.Join(dir, "sub"))
type staticScenario struct {
path string
indexFallback bool
expectedStatus int
expectBody string
expectError bool
}
scenarios := []staticScenario{
{ {
Name: "apis.ApiError", path: "",
Method: http.MethodGet, indexFallback: false,
Url: "/test", expectedStatus: 200,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { expectBody: "sub index.html",
e.GET("/test", func(c echo.Context) error { expectError: false,
return apis.NewApiError(418, "test", nil)
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
}, },
{ {
Name: "wrapped apis.ApiError", path: "missing/a/b/c",
Method: http.MethodGet, indexFallback: false,
Url: "/test", expectedStatus: 404,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { expectBody: "",
e.GET("/test", func(c echo.Context) error { expectError: true,
return fmt.Errorf("example 123: %w", apis.NewApiError(418, "test", nil))
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
NotExpectedContent: []string{"example", "123"},
}, },
{ {
Name: "echo.HTTPError", path: "missing/a/b/c",
Method: http.MethodGet, indexFallback: true,
Url: "/test", expectedStatus: 200,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { expectBody: "sub index.html",
e.GET("/test", func(c echo.Context) error { expectError: false,
return echo.NewHTTPError(418, "test")
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
}, },
{ {
Name: "wrapped echo.HTTPError", path: "testroot", // parent directory file
Method: http.MethodGet, indexFallback: false,
Url: "/test", expectedStatus: 404,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { expectBody: "",
e.GET("/test", func(c echo.Context) error { expectError: true,
return fmt.Errorf("example 123: %w", echo.NewHTTPError(418, "test"))
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
NotExpectedContent: []string{"example", "123"},
}, },
{ {
Name: "wrapped sql.ErrNoRows", path: "test",
Method: http.MethodGet, indexFallback: false,
Url: "/test", expectedStatus: 200,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { expectBody: "sub test",
e.GET("/test", func(c echo.Context) error { expectError: false,
return fmt.Errorf("example 123: %w", sql.ErrNoRows)
})
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
NotExpectedContent: []string{"example", "123"},
}, },
{ {
Name: "custom error", path: "sub2",
Method: http.MethodGet, indexFallback: false,
Url: "/test", expectedStatus: 301,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { expectBody: "",
e.GET("/test", func(c echo.Context) error { expectError: false,
return fmt.Errorf("example 123") },
}) {
}, path: "sub2/",
ExpectedStatus: 400, indexFallback: false,
ExpectedContent: []string{`"data":{}`}, expectedStatus: 200,
NotExpectedContent: []string{"example", "123"}, expectBody: "sub2 index.html",
expectError: false,
},
{
path: "sub2/test",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub2 test",
expectError: false,
},
{
path: "sub2/test/",
indexFallback: false,
expectedStatus: 301,
expectBody: "",
expectError: false,
}, },
} }
for _, scenario := range scenarios { // extra directory traversal checks
scenario.Test(t) dtp := []string{
"/../",
"\\../",
"../",
"../../",
"..\\",
"..\\..\\",
"../..\\",
"..\\..//",
`%2e%2e%2f`,
`%2e%2e%2f%2e%2e%2f`,
`%2e%2e/`,
`%2e%2e/%2e%2e/`,
`..%2f`,
`..%2f..%2f`,
`%2e%2e%5c`,
`%2e%2e%5c%2e%2e%5c`,
`%2e%2e\`,
`%2e%2e\%2e%2e\`,
`..%5c`,
`..%5c..%5c`,
`%252e%252e%255c`,
`%252e%252e%255c%252e%252e%255c`,
`..%255c`,
`..%255c..%255c`,
}
for _, p := range dtp {
scenarios = append(scenarios,
staticScenario{
path: p + "testroot",
indexFallback: false,
expectedStatus: 404,
expectBody: "",
expectError: true,
},
staticScenario{
path: p + "testroot",
indexFallback: true,
expectedStatus: 200,
expectBody: "sub index.html",
expectError: false,
},
)
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%v", i, s.path, s.indexFallback), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/"+s.path, nil)
req.SetPathValue(apis.StaticWildcardParam, s.path)
rec := httptest.NewRecorder()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
err := apis.Static(fsys, s.indexFallback)(e)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
body := rec.Body.String()
if body != s.expectBody {
t.Fatalf("Expected body %q, got %q", s.expectBody, body)
}
if hasErr {
apiErr := router.ToApiError(err)
if apiErr.Status != s.expectedStatus {
t.Fatalf("Expected status code %d, got %d", s.expectedStatus, apiErr.Status)
}
}
})
} }
} }
func TestFindUploadedFiles(t *testing.T) {
scenarios := []struct {
filename string
expectedPattern string
}{
{"ab.png", `^ab\w{10}_\w{10}\.png$`},
{"test", `^test_\w{10}\.txt$`},
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
}
for _, s := range scenarios {
t.Run(s.filename, func(t *testing.T) {
// create multipart form file body
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
w, err := mp.CreateFormFile("test", s.filename)
if err != nil {
t.Fatal(err)
}
w.Write([]byte("test"))
mp.Close()
// ---
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
result, err := apis.FindUploadedFiles(req, "test")
if err != nil {
t.Fatal(err)
}
if len(result) != 1 {
t.Fatalf("Expected 1 file, got %d", len(result))
}
if result[0].Size != 4 {
t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Size)
}
pattern, err := regexp.Compile(s.expectedPattern)
if err != nil {
t.Fatalf("Invalid filename pattern %q: %v", s.expectedPattern, err)
}
if !pattern.MatchString(result[0].Name) {
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name)
}
})
}
}
func TestFindUploadedFilesMissing(t *testing.T) {
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
mp.Close()
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
result, err := apis.FindUploadedFiles(req, "test")
if err == nil {
t.Error("Expected error, got nil")
}
if result != nil {
t.Errorf("Expected result to be nil, got %v", result)
}
}
func TestMustSubFS(t *testing.T) {
t.Parallel()
dir := createTestDir(t)
defer os.RemoveAll(dir)
// invalid path (no beginning and ending slashes)
if !hasPanicked(func() {
apis.MustSubFS(os.DirFS(dir), "/test/")
}) {
t.Fatalf("Expected to panic")
}
// valid path
if hasPanicked(func() {
apis.MustSubFS(os.DirFS(dir), "./////a/b/c") // checks if ToSlash was called
}) {
t.Fatalf("Didn't expect to panic")
}
// check sub content
sub := apis.MustSubFS(os.DirFS(dir), "sub")
_, err := sub.Open("test")
if err != nil {
t.Fatalf("Missing expected file sub/test")
}
}
// -------------------------------------------------------------------
func hasPanicked(f func()) (didPanic bool) {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
f()
return
}
// note: make sure to call os.RemoveAll(dir) after you are done
// working with the created test dir.
func createTestDir(t *testing.T) string {
dir, err := os.MkdirTemp(os.TempDir(), "test_dir")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("root index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "testroot"), []byte("root test"), 0644); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "sub"), os.ModePerm); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/index.html"), []byte("sub index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/test"), []byte("sub test"), 0644); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "sub", "sub2"), os.ModePerm); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/index.html"), []byte("sub2 index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/test"), []byte("sub2 test"), 0644); err != nil {
t.Fatal(err)
}
return dir
}

542
apis/batch.go Normal file
View File

@ -0,0 +1,542 @@
package apis
import (
"bytes"
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"regexp"
"slices"
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/batch")
sub.POST("", batchTransaction).Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined
}
type HandleFunc func(e *core.RequestEvent) error
type BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc
// ValidBatchActions defines a map with the supported batch InternalRequest actions.
//
// Note: when adding new routes make sure that their middlewares are inlined!
var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
// "upsert" handler
regexp.MustCompile(`^PUT /api/collections/(?P<collection>[^\/\?]+)/records(?P<query>\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
var id string
if len(ir.Body) > 0 && ir.Body["id"] != "" {
id = cast.ToString(ir.Body["id"])
}
if id != "" {
_, err := app.FindRecordById(params["collection"], id)
if err == nil {
// update
// ---
params["id"] = id // required for the path value
ir.Method = "PATCH"
ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"]
return recordUpdate(next)
}
}
// create
// ---
ir.Method = "POST"
ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"]
return recordCreate(next)
},
regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
return recordCreate(next)
},
regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
return recordUpdate(next)
},
regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
return recordDelete(next)
},
}
type BatchRequestResult struct {
Body any `json:"body"`
Status int `json:"status"`
}
type batchRequestsForm struct {
Requests []*core.InternalRequest `form:"requests" json:"requests"`
max int
}
func (brs batchRequestsForm) validate() error {
return validation.ValidateStruct(&brs,
validation.Field(&brs.Requests, validation.Required, validation.Length(0, brs.max)),
)
}
// NB! When the request is submitted as multipart/form-data,
// the regular fields data is expected to be submitted as serailized
// json under the @jsonPayload field and file keys need to follow the
// pattern "requests.N.fileField" or requests[N].fileField.
func batchTransaction(e *core.RequestEvent) error {
maxRequests := e.App.Settings().Batch.MaxRequests
if !e.App.Settings().Batch.Enabled || maxRequests <= 0 {
return e.ForbiddenError("Batch requests are not allowed.", nil)
}
txTimeout := time.Duration(e.App.Settings().Batch.Timeout) * time.Second
if txTimeout <= 0 {
txTimeout = 3 * time.Second // for now always limit
}
maxBodySize := e.App.Settings().Batch.MaxBodySize
if maxBodySize <= 0 {
maxBodySize = 128 << 20
}
err := applyBodyLimit(e, maxBodySize)
if err != nil {
return err
}
form := &batchRequestsForm{max: maxRequests}
// load base requests data
err = e.BindBody(form)
if err != nil {
return e.BadRequestError("Failed to read the submitted batch data.", err)
}
// load uploaded files into each request item
// note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format
// (the other regular fields must be put under `@jsonPayload` as serialized json)
if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") {
for i, ir := range form.Requests {
iStr := strconv.Itoa(i)
files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].")
if err != nil {
return e.BadRequestError("Failed to read the submitted batch files data.", err)
}
for key, files := range files {
if ir.Body == nil {
ir.Body = map[string]any{}
}
ir.Body[key] = files
}
}
}
// validate batch request form
err = form.validate()
if err != nil {
return e.BadRequestError("Invalid batch request data.", err)
}
event := new(core.BatchRequestEvent)
event.RequestEvent = e
event.Batch = form.Requests
return e.App.OnBatchRequest().Trigger(event, func(e *core.BatchRequestEvent) error {
bp := batchProcessor{
app: e.App,
baseEvent: e.RequestEvent,
infoContext: core.RequestInfoContextBatch,
}
if err := bp.Process(e.Batch, txTimeout); err != nil {
return firstApiError(err, e.BadRequestError("Batch transaction failed.", err))
}
return e.JSON(http.StatusOK, bp.results)
})
}
type batchProcessor struct {
app core.App
baseEvent *core.RequestEvent
infoContext string
results []*BatchRequestResult
failedIndex int
errCh chan error
stopCh chan struct{}
}
func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error {
p.results = make([]*BatchRequestResult, 0, len(batch))
if p.stopCh != nil {
close(p.stopCh)
}
p.stopCh = make(chan struct{}, 1)
if p.errCh != nil {
close(p.errCh)
}
p.errCh = make(chan error, 1)
return p.app.RunInTransaction(func(txApp core.App) error {
// used to interupts the recursive processing calls in case of a timeout or connection close
defer func() {
p.stopCh <- struct{}{}
}()
go func() {
err := p.process(txApp, batch, 0)
if err != nil {
err = validation.Errors{
"requests": validation.Errors{
strconv.Itoa(p.failedIndex): &BatchResponseError{
code: "batch_request_failed",
message: "Batch request failed.",
err: router.ToApiError(err),
},
},
}
}
// note: to avoid copying and due to the process recursion the final results order is reversed
if err == nil {
slices.Reverse(p.results)
}
p.errCh <- err
}()
select {
case responseErr := <-p.errCh:
return responseErr
case <-time.After(timeout):
// note: we don't return 408 Reques Timeout error because
// some browsers perform automatic retry behind the scenes
// which are hard to debug and unnecessary
return errors.New("batch transaction timeout")
case <-p.baseEvent.Request.Context().Done():
return errors.New("batch request interrupted")
}
})
}
func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error {
select {
case <-p.stopCh:
return nil
default:
if len(batch) == 0 {
return nil
}
result, err := processInternalRequest(
activeApp,
p.baseEvent,
batch[0],
p.infoContext,
func() error {
if len(batch) == 1 {
return nil
}
err := p.process(activeApp, batch[1:], i+1)
// update the failed batch index (if not already)
if err != nil && p.failedIndex == 0 {
p.failedIndex = i + 1
}
return err
},
)
if err != nil {
return err
}
p.results = append(p.results, result)
return nil
}
}
func processInternalRequest(
activeApp core.App,
baseEvent *core.RequestEvent,
ir *core.InternalRequest,
infoContext string,
optNext func() error,
) (*BatchRequestResult, error) {
handle, params, ok := prepareInternalAction(activeApp, ir, optNext)
if !ok {
return nil, errors.New("unknown batch request action")
}
// construct a new http.Request
// ---------------------------------------------------------------
buf, mw, err := multipartDataFromInternalRequest(ir)
if err != nil {
return nil, err
}
r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf)
if err != nil {
return nil, err
}
// cleanup multipart temp files
defer func() {
if r.MultipartForm != nil {
if err := r.MultipartForm.RemoveAll(); err != nil {
activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err)
}
}
}()
// load batch request path params
// ---
for k, v := range params {
r.SetPathValue(k, v)
}
// clone original request
// ---
r.RequestURI = r.URL.RequestURI()
r.Proto = baseEvent.Request.Proto
r.ProtoMajor = baseEvent.Request.ProtoMajor
r.ProtoMinor = baseEvent.Request.ProtoMinor
r.Host = baseEvent.Request.Host
r.RemoteAddr = baseEvent.Request.RemoteAddr
r.TLS = baseEvent.Request.TLS
if s := baseEvent.Request.TransferEncoding; s != nil {
s2 := make([]string, len(s))
copy(s2, s)
r.TransferEncoding = s2
}
if baseEvent.Request.Trailer != nil {
r.Trailer = baseEvent.Request.Trailer.Clone()
}
if baseEvent.Request.Header != nil {
r.Header = baseEvent.Request.Header.Clone()
}
// apply batch request specific headers
// ---
for k, v := range ir.Headers {
r.Header.Set(k, v)
}
r.Header.Set("Content-Type", mw.FormDataContentType())
// construct a new RequestEvent
// ---------------------------------------------------------------
event := &core.RequestEvent{}
event.App = activeApp
event.Auth = baseEvent.Auth
event.SetAll(baseEvent.GetAll())
// load RequestInfo context
if infoContext == "" {
infoContext = core.RequestInfoContextDefault
}
event.Set(core.RequestEventKeyInfoContext, infoContext)
// assign request
event.Request = r
event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads
// assign response
rec := httptest.NewRecorder()
event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking
// execute
// ---------------------------------------------------------------
if err := handle(event); err != nil {
return nil, err
}
result := rec.Result()
defer result.Body.Close()
body, _ := types.ParseJSONRaw(rec.Body.Bytes())
return &BatchRequestResult{
Status: result.StatusCode,
Body: body,
}, nil
}
func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) {
buf := &bytes.Buffer{}
mw := multipart.NewWriter(buf)
regularFields := map[string]any{}
fileFields := map[string][]*filesystem.File{}
// separate regular fields from files
// ---
for k, rawV := range ir.Body {
switch v := rawV.(type) {
case *filesystem.File:
fileFields[k] = append(fileFields[k], v)
case []*filesystem.File:
fileFields[k] = append(fileFields[k], v...)
default:
regularFields[k] = v
}
}
// submit regularFields as @jsonPayload
// ---
rawBody, err := json.Marshal(regularFields)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
jsonPayload, err := mw.CreateFormField("@jsonPayload")
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = jsonPayload.Write(rawBody)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
// submit fileFields as multipart files
// ---
for key, files := range fileFields {
for _, file := range files {
part, err := mw.CreateFormFile(key, file.Name)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
fr, err := file.Reader.Open()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = io.Copy(part, fr)
if err != nil {
return nil, nil, errors.Join(err, fr.Close(), mw.Close())
}
err = fr.Close()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
}
}
return buf, mw, mw.Close()
}
func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) {
if request.MultipartForm == nil {
if err := request.ParseMultipartForm(router.DefaultMaxMemory); err != nil {
return nil, err
}
}
result := make(map[string][]*filesystem.File)
for k, fhs := range request.MultipartForm.File {
for _, p := range prefixes {
if strings.HasPrefix(k, p) {
resultKey := strings.TrimPrefix(k, p)
for _, fh := range fhs {
file, err := filesystem.NewFileFromMultipart(fh)
if err != nil {
return nil, err
}
result[resultKey] = append(result[resultKey], file)
}
}
}
}
return result, nil
}
func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func() error) (HandleFunc, map[string]string, bool) {
full := strings.ToUpper(ir.Method) + " " + ir.URL
for re, actionFactory := range ValidBatchActions {
params, ok := findNamedMatches(re, full)
if ok {
return actionFactory(activeApp, ir, params, optNext), params, true
}
}
return nil, nil, false
}
func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) {
match := re.FindStringSubmatch(str)
if match == nil {
return nil, false
}
result := map[string]string{}
names := re.SubexpNames()
for i, m := range match {
if names[i] != "" {
result[names[i]] = m
}
}
return result, true
}
// -------------------------------------------------------------------
var (
_ router.SafeErrorItem = (*BatchResponseError)(nil)
_ router.SafeErrorResolver = (*BatchResponseError)(nil)
)
type BatchResponseError struct {
err *router.ApiError
code string
message string
}
func (e *BatchResponseError) Error() string {
return e.message
}
func (e *BatchResponseError) Code() string {
return e.code
}
func (e *BatchResponseError) Resolve(errData map[string]any) any {
errData["response"] = e.err
return errData
}
func (e BatchResponseError) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"message": e.message,
"code": e.code,
"response": e.err,
})
}

691
apis/batch_test.go Normal file
View File

@ -0,0 +1,691 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestBatchRequest(t *testing.T) {
t.Parallel()
formData, mp, err := tests.MockMultipartData(
map[string]string{
router.JSONPayloadKey: `{
"requests":[
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch2"}},
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch3"}},
{"method":"PATCH", "url":"/api/collections/demo3/records/lcl9d87w22ml6jy", "body": {"files-": "test_FLurQTgrY8.txt"}}
]
}`,
},
"requests.0.files",
"requests.0.files",
"requests.0.files",
"requests[2].files",
)
if err != nil {
t.Fatal(err)
}
scenarios := []tests.ApiScenario{
{
Name: "disabled batch requets",
Method: http.MethodPost,
URL: "/api/batch",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = false
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "max request limits reached",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"GET", "url":"/test1"},
{"method":"GET", "url":"/test2"},
{"method":"GET", "url":"/test3"}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = true
app.Settings().Batch.MaxRequests = 2
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{"code":"validation_length_too_long"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "trigger requests validations",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{},
{"method":"GET", "url":"/valid"},
{"method":"invalid", "url":"/valid"},
{"method":"POST", "url":"` + strings.Repeat("a", 2001) + `"}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = true
app.Settings().Batch.MaxRequests = 100
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`"0":{"method":{"code":"validation_required"`,
`"2":{"method":{"code":"validation_in_invalid"`,
`"3":{"url":{"code":"validation_length_too_long"`,
},
NotExpectedContent: []string{
`"1":`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unknown batch request action",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"GET", "url":"/api/health"}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`0":{"code":"batch_request_failed"`,
`"response":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
},
},
{
Name: "base 2 successful and 1 failed (public collection)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": ""}}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"response":{`,
`"2":{"code":"batch_request_failed"`,
`"response":{"data":{"title":{"code":"validation_required"`,
},
NotExpectedContent: []string{
`"0":`,
`"1":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnRecordCreateRequest": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 2,
"OnModelAfterCreateError": 3,
"OnModelValidate": 3,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 2,
"OnRecordAfterCreateError": 3,
"OnRecordValidate": 3,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatalf("Expected no batch records to be persisted, got %d", len(records))
}
},
},
{
Name: "base 4 successful (public collection)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
{"method":"PUT", "url":"/api/collections/demo2/records", "body": {"title": "batch3"}},
{"method":"PUT", "url":"/api/collections/demo2/records?fields=*,id:excerpt(4,true)", "body": {"id":"achvryl401bhse3","title": "batch4"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch1"`,
`"title":"batch2"`,
`"title":"batch3"`,
`"title":"batch4"`,
`"id":"achv..."`,
`"active":false`,
`"active":true`,
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnModelValidate": 4,
"OnRecordValidate": 4,
"OnRecordEnrich": 4,
"OnRecordCreateRequest": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 3,
"OnRecordAfterCreateSuccess": 3,
"OnRecordUpdateRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 4 {
t.Fatalf("Expected %d batch records to be persisted, got %d", 3, len(records))
}
},
},
{
Name: "mixed create/update/delete (rules failure)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`"2":{"code":"batch_request_failed"`,
`"response":{`,
},
NotExpectedContent: []string{
// only demo3 requires authentication
`"0":`,
`"1":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateError": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteError": 1,
"OnModelValidate": 1,
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateError": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteError": 1,
"OnRecordEnrich": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err == nil {
t.Fatal("Expected record to not be created")
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err == nil {
t.Fatal("Expected record to not be updated")
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err != nil {
t.Fatal("Expected record to not be deleted")
}
},
},
{
Name: "mixed create/update/delete (rules success)",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, clients
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch_create"`,
`"title":"batch_update"`,
`"status":200`,
`"status":204`,
`"body":{`,
`"body":null`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
"OnRecordUpdateRequest": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err == nil {
t.Fatal("Expected record to be deleted")
}
},
},
{
Name: "mixed create/update/delete (superuser auth)",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch_create"`,
`"title":"batch_update"`,
`"status":200`,
`"status":204`,
`"body":{`,
`"body":null`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
"OnRecordUpdateRequest": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err == nil {
t.Fatal("Expected record to be deleted")
}
},
},
{
Name: "cascade delete/update",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"DELETE", "url":"/api/collections/demo3/records/1tmknxy2868d869"},
{"method":"DELETE", "url":"/api/collections/demo3/records/mk5fmymtx4wsprk"}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"status":204`,
`"body":null`,
},
NotExpectedContent: []string{
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelDelete": 3, // 2 batch + 1 cascade delete
"OnModelDeleteExecute": 3,
"OnModelAfterDeleteSuccess": 3,
"OnModelUpdate": 5, // 5 cascade update
"OnModelUpdateExecute": 5,
"OnModelAfterUpdateSuccess": 5,
// ---
"OnRecordDeleteRequest": 2,
"OnRecordDelete": 3,
"OnRecordDeleteExecute": 3,
"OnRecordAfterDeleteSuccess": 3,
"OnRecordUpdate": 5,
"OnRecordUpdateExecute": 5,
"OnRecordAfterUpdateSuccess": 5,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ids := []string{
"1tmknxy2868d869",
"mk5fmymtx4wsprk",
"qzaqccwrmva4o1n",
}
for _, id := range ids {
_, err := app.FindRecordById("demo2", id)
if err == nil {
t.Fatalf("Expected record %q to be deleted", id)
}
}
},
},
{
Name: "transaction timeout",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}}
]
}`),
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Timeout = 1
app.OnRecordCreateRequest("demo2").BindFunc(func(e *core.RecordRequestEvent) error {
time.Sleep(600 * time.Millisecond) // < 1s so that the first request can succeed
return e.Next()
})
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnRecordCreateRequest": 2,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateError": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateError": 1,
"OnRecordEnrich": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatalf("Expected %d batch records to be persisted, got %d", 0, len(records))
}
},
},
{
Name: "multipart/form-data + file upload",
Method: http.MethodPost,
URL: "/api/batch",
Body: formData,
Headers: map[string]string{
// test@example.com, clients
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
"Content-Type": mp.FormDataContentType(),
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch1"`,
`"title":"batch2"`,
`"title":"batch3"`,
`"id":"lcl9d87w22ml6jy"`,
`"files":["300_UhLKX91HVb.png"]`,
`"tmpfile_`,
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 4,
// ---
"OnRecordCreateRequest": 3,
"OnRecordUpdateRequest": 1,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 3,
"OnRecordAfterCreateSuccess": 3,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 4,
"OnRecordEnrich": 4,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
batch1, err := app.FindFirstRecordByFilter("demo3", `title="batch1"`)
if err != nil {
t.Fatalf("missing batch1: %v", err)
}
batch1Files := batch1.GetStringSlice("files")
if len(batch1Files) != 3 {
t.Fatalf("Expected %d batch1 file(s), got %d", 3, len(batch1Files))
}
batch2, err := app.FindFirstRecordByFilter("demo3", `title="batch2"`)
if err != nil {
t.Fatalf("missing batch2: %v", err)
}
batch2Files := batch2.GetStringSlice("files")
if len(batch2Files) != 0 {
t.Fatalf("Expected %d batch2 file(s), got %d", 0, len(batch2Files))
}
batch3, err := app.FindFirstRecordByFilter("demo3", `title="batch3"`)
if err != nil {
t.Fatalf("missing batch3: %v", err)
}
batch3Files := batch3.GetStringSlice("files")
if len(batch3Files) != 1 {
t.Fatalf("Expected %d batch3 file(s), got %d", 1, len(batch3Files))
}
batch4, err := app.FindRecordById("demo3", "lcl9d87w22ml6jy")
if err != nil {
t.Fatalf("missing batch4: %v", err)
}
batch4Files := batch4.GetStringSlice("files")
if len(batch4Files) != 1 {
t.Fatalf("Expected %d batch4 file(s), got %d", 1, len(batch4Files))
}
},
},
{
Name: "create/update with expand query params",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"body":{`,
`"id":"qjeql998mtp1azp"`,
`"id":"qzaqccwrmva4o1n"`,
`"id":"i9naidtvr6qsgb4"`,
`"expand":{"rel_one"`,
`"expand":{"rel_many"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordUpdateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 5,
},
},
{
Name: "check body limit middleware",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.MaxBodySize = 10
},
ExpectedStatus: 413,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -1,210 +1,186 @@
package apis package apis
import ( import (
"errors"
"net/http" "net/http"
"strings"
"github.com/labstack/echo/v5" validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms" "github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/search" "github.com/pocketbase/pocketbase/tools/search"
) )
// bindCollectionApi registers the collection api endpoints and the corresponding handlers. // bindCollectionApi registers the collection api endpoints and the corresponding handlers.
func bindCollectionApi(app core.App, rg *echo.Group) { func bindCollectionApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
api := collectionApi{app: app} subGroup := rg.Group("/collections").Bind(RequireSuperuserAuth())
subGroup.GET("", collectionsList)
subGroup := rg.Group("/collections", ActivityLogger(app), RequireAdminAuth()) subGroup.POST("", collectionCreate)
subGroup.GET("", api.list) subGroup.GET("/{collection}", collectionView)
subGroup.POST("", api.create) subGroup.PATCH("/{collection}", collectionUpdate)
subGroup.GET("/:collection", api.view) subGroup.DELETE("/{collection}", collectionDelete)
subGroup.PATCH("/:collection", api.update) subGroup.DELETE("/{collection}/truncate", collectionTruncate)
subGroup.DELETE("/:collection", api.delete) subGroup.PUT("/import", collectionsImport)
subGroup.PUT("/import", api.bulkImport) subGroup.GET("/meta/scaffolds", collectionScaffolds)
} }
type collectionApi struct { func collectionsList(e *core.RequestEvent) error {
app core.App
}
func (api *collectionApi) list(c echo.Context) error {
fieldResolver := search.NewSimpleFieldResolver( fieldResolver := search.NewSimpleFieldResolver(
"id", "created", "updated", "name", "system", "type", "id", "created", "updated", "name", "system", "type",
) )
collections := []*models.Collection{} collections := []*core.Collection{}
result, err := search.NewProvider(fieldResolver). result, err := search.NewProvider(fieldResolver).
Query(api.app.Dao().CollectionQuery()). Query(e.App.CollectionQuery()).
ParseAndExec(c.QueryParams().Encode(), &collections) ParseAndExec(e.Request.URL.Query().Encode(), &collections)
if err != nil { if err != nil {
return NewBadRequestError("", err) return e.BadRequestError("", err)
} }
event := new(core.CollectionsListEvent) event := new(core.CollectionsListRequestEvent)
event.HttpContext = c event.RequestEvent = e
event.Collections = collections event.Collections = collections
event.Result = result event.Result = result
return api.app.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListEvent) error { return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error {
if e.HttpContext.Response().Committed { return e.JSON(http.StatusOK, e.Result)
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Result)
}) })
} }
func (api *collectionApi) view(c echo.Context) error { func collectionView(e *core.RequestEvent) error {
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection")) collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil { if err != nil || collection == nil {
return NewNotFoundError("", err) return e.NotFoundError("", err)
} }
event := new(core.CollectionViewEvent) event := new(core.CollectionRequestEvent)
event.HttpContext = c event.RequestEvent = e
event.Collection = collection event.Collection = collection
return api.app.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionViewEvent) error { return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if e.HttpContext.Response().Committed { return e.JSON(http.StatusOK, e.Collection)
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Collection)
}) })
} }
func (api *collectionApi) create(c echo.Context) error { func collectionCreate(e *core.RequestEvent) error {
collection := &models.Collection{} // populate the minimal required factory collection data (if any)
factoryExtract := struct {
form := forms.NewCollectionUpsert(api.app, collection) Type string `form:"type" json:"type"`
Name string `form:"name" json:"name"`
// load request }{}
if err := c.Bind(form); err != nil { if err := e.BindBody(&factoryExtract); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err) return e.BadRequestError("Failed to load the collection type data due to invalid formatting.", err)
} }
event := new(core.CollectionCreateEvent) // create scaffold
event.HttpContext = c collection := core.NewCollection(factoryExtract.Type, factoryExtract.Name)
// merge the scaffold with the submitted request data
if err := e.BindBody(collection); err != nil {
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection event.Collection = collection
// create the collection return e.App.OnCollectionCreateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
return form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] { if err := e.App.Save(e.Collection); err != nil {
return func(m *models.Collection) error { // validation failure
event.Collection = m var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return api.app.OnCollectionBeforeCreateRequest().Trigger(event, func(e *core.CollectionCreateEvent) error { return e.BadRequestError("Failed to create collection.", validationErrors)
if err := next(e.Collection); err != nil {
return NewBadRequestError("Failed to create the collection.", err)
}
return api.app.OnCollectionAfterCreateRequest().Trigger(event, func(e *core.CollectionCreateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Collection)
})
})
}
})
}
func (api *collectionApi) update(c echo.Context) error {
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
if err != nil || collection == nil {
return NewNotFoundError("", err)
}
form := forms.NewCollectionUpsert(api.app, collection)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.CollectionUpdateEvent)
event.HttpContext = c
event.Collection = collection
// update the collection
return form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] {
return func(m *models.Collection) error {
event.Collection = m
return api.app.OnCollectionBeforeUpdateRequest().Trigger(event, func(e *core.CollectionUpdateEvent) error {
if err := next(e.Collection); err != nil {
return NewBadRequestError("Failed to update the collection.", err)
}
return api.app.OnCollectionAfterUpdateRequest().Trigger(event, func(e *core.CollectionUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Collection)
})
})
}
})
}
func (api *collectionApi) delete(c echo.Context) error {
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
if err != nil || collection == nil {
return NewNotFoundError("", err)
}
event := new(core.CollectionDeleteEvent)
event.HttpContext = c
event.Collection = collection
return api.app.OnCollectionBeforeDeleteRequest().Trigger(event, func(e *core.CollectionDeleteEvent) error {
if err := api.app.Dao().DeleteCollection(e.Collection); err != nil {
return NewBadRequestError("Failed to delete collection due to existing dependency.", err)
}
return api.app.OnCollectionAfterDeleteRequest().Trigger(event, func(e *core.CollectionDeleteEvent) error {
if e.HttpContext.Response().Committed {
return nil
} }
return e.HttpContext.NoContent(http.StatusNoContent) // other generic db error
}) return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil)
}
return e.JSON(http.StatusOK, e.Collection)
}) })
} }
func (api *collectionApi) bulkImport(c echo.Context) error { func collectionUpdate(e *core.RequestEvent) error {
form := forms.NewCollectionsImport(api.app) collection, err := e.App.FindCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
// load request data return e.NotFoundError("", err)
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
} }
event := new(core.CollectionsImportEvent) if err := e.BindBody(collection); err != nil {
event.HttpContext = c return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
event.Collections = form.Collections }
// import collections event := new(core.CollectionRequestEvent)
return form.Submit(func(next forms.InterceptorNextFunc[[]*models.Collection]) forms.InterceptorNextFunc[[]*models.Collection] { event.RequestEvent = e
return func(imports []*models.Collection) error { event.Collection = collection
event.Collections = imports
return api.app.OnCollectionsBeforeImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error { return event.App.OnCollectionUpdateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := next(e.Collections); err != nil { if err := e.App.Save(e.Collection); err != nil {
return NewBadRequestError("Failed to import the submitted collections.", err) // validation failure
} var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return e.BadRequestError("Failed to update collection.", validationErrors)
}
return api.app.OnCollectionsAfterImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error { // other generic db error
if e.HttpContext.Response().Committed { return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil)
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
} }
return e.JSON(http.StatusOK, e.Collection)
})
}
func collectionDelete(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("", err)
}
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
return e.App.OnCollectionDeleteRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := e.App.Delete(e.Collection); err != nil {
msg := "Failed to delete collection"
// check fo references
refs, _ := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
if len(refs) > 0 {
names := make([]string, 0, len(refs))
for ref := range refs {
names = append(names, ref.Name)
}
msg += " probably due to existing reference in " + strings.Join(names, ", ")
}
return e.BadRequestError(msg, err)
}
return e.NoContent(http.StatusNoContent)
})
}
func collectionTruncate(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("", err)
}
err = e.App.TruncateCollection(collection)
if err != nil {
return e.BadRequestError("Failed to truncate collection (most likely due to required cascade delete record references).", err)
}
return e.NoContent(http.StatusNoContent)
}
func collectionScaffolds(e *core.RequestEvent) error {
return e.JSON(http.StatusOK, map[string]*core.Collection{
core.CollectionTypeBase: core.NewBaseCollection(""),
core.CollectionTypeAuth: core.NewAuthCollection(""),
core.CollectionTypeView: core.NewViewCollection(""),
}) })
} }

60
apis/collection_import.go Normal file
View File

@ -0,0 +1,60 @@
package apis
import (
"errors"
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func collectionsImport(e *core.RequestEvent) error {
form := new(collectionsImportForm)
err := e.BindBody(form)
if err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
err = form.validate()
if err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.CollectionsImportRequestEvent)
event.RequestEvent = e
event.CollectionsData = form.Collections
event.DeleteMissing = form.DeleteMissing
return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error {
importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing)
if importErr == nil {
return e.NoContent(http.StatusNoContent)
}
// validation failure
var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return e.BadRequestError("Failed to import collections.", validationErrors)
}
// generic/db failure
return e.BadRequestError("Failed to import collections.", validation.Errors{"collections": validation.NewError(
"validation_collections_import_failure",
"Failed to import the collections configuration. Raw error:\n"+importErr.Error(),
)})
})
}
// -------------------------------------------------------------------
type collectionsImportForm struct {
Collections []map[string]any `form:"collections" json:"collections"`
DeleteMissing bool `form:"deleteMissing" json:"deleteMissing"`
}
func (form *collectionsImportForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Collections, validation.Required),
)
}

View File

@ -0,0 +1,257 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestCollectionsImport(t *testing.T) {
t.Parallel()
totalCollections := 16
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPut,
URL: "/api/collections/import",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPut,
URL: "/api/collections/import",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser + empty collections",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{"collections":[]}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"collections":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "authorized as superuser + collections validator failure",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"collections":[
{"name": "import1"},
{
"name": "import2",
"fields": [
{
"id": "koih1lqx",
"name": "expand",
"type": "text"
}
]
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"collections":{"code":"validation_collections_import_failure"`,
`import2`,
`fields`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
"OnCollectionCreate": 2,
"OnCollectionCreateExecute": 2,
"OnCollectionAfterCreateError": 2,
"OnModelCreate": 2,
"OnModelCreateExecute": 2,
"OnModelAfterCreateError": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "authorized as superuser + successful collections create",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"collections":[
{
"name": "import1",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
]
},
{
"name": "import2",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
],
"indexes": [
"create index idx_test on import2 (test)"
]
},
{
"name": "auth_without_fields",
"type": "auth"
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
"OnCollectionCreate": 3,
"OnCollectionCreateExecute": 3,
"OnCollectionAfterCreateSuccess": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections + 3
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
indexes, err := app.TableIndexes("import2")
if err != nil || indexes["idx_test"] == "" {
t.Fatalf("Missing index %s (%v)", "idx_test", err)
}
},
},
{
Name: "authorized as superuser + create/update/delete",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"deleteMissing": true,
"collections":[
{"name": "test123"},
{
"id":"wsmn24bux7wo113",
"name":"demo1",
"fields":[
{
"id":"_2hlxbmp",
"name":"title",
"type":"text",
"required":true
}
],
"indexes": []
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnCollectionCreate": 1,
"OnCollectionCreateExecute": 1,
"OnCollectionAfterCreateSuccess": 1,
// ---
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnCollectionUpdate": 1,
"OnCollectionUpdateExecute": 1,
"OnCollectionAfterUpdateSuccess": 1,
// ---
"OnModelDelete": 14,
"OnModelAfterDeleteSuccess": 14,
"OnModelDeleteExecute": 14,
"OnCollectionDelete": 9,
"OnCollectionDeleteExecute": 9,
"OnCollectionAfterDeleteSuccess": 9,
"OnRecordAfterDeleteSuccess": 5,
"OnRecordDelete": 5,
"OnRecordDeleteExecute": 5,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
systemCollections := 0
for _, c := range collections {
if c.System {
systemCollections++
}
}
expected := systemCollections + 2
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

138
apis/dashboard.go Normal file
View File

@ -0,0 +1,138 @@
package apis
import (
"fmt"
"net/http"
"regexp"
"strings"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
const installerParam = "pbinstal"
var wildcardPlaceholderRegex = regexp.MustCompile(`/{.+\.\.\.}$`)
func stripWildcard(pattern string) string {
return wildcardPlaceholderRegex.ReplaceAllString(pattern, "")
}
// installerRedirect redirects the user to the installer dashboard UI page
// when the application needs some preliminary configurations to be done.
func installerRedirect(app core.App, cpPath string) hook.HandlerFunc[*core.RequestEvent] {
// note: to avoid locks contention it is not concurrent safe but it
// is expected to be updated only once during initialization
var hasSuperuser bool
// strip named wildcard
cpPath = stripWildcard(cpPath)
updateHasSuperuser := func(app core.App) error {
total, err := app.CountRecords(core.CollectionNameSuperusers)
if err != nil {
return err
}
hasSuperuser = total > 0
return nil
}
// load initial state on app init
app.OnBootstrap().BindFunc(func(e *core.BootstrapEvent) error {
err := e.Next()
if err != nil {
return err
}
err = updateHasSuperuser(e.App)
if err != nil {
return fmt.Errorf("failed to check for existing superuser: %w", err)
}
return nil
})
// update on superuser create
app.OnRecordCreateRequest(core.CollectionNameSuperusers).BindFunc(func(e *core.RecordRequestEvent) error {
err := e.Next()
if err != nil {
return err
}
if !hasSuperuser {
hasSuperuser = true
}
return nil
})
return func(e *core.RequestEvent) error {
if hasSuperuser {
return e.Next()
}
isAPI := strings.HasPrefix(e.Request.URL.Path, "/api/")
isControlPanel := strings.HasPrefix(e.Request.URL.Path, cpPath)
wildcard := e.Request.PathValue(StaticWildcardParam)
// skip redirect checks for API and non-root level dashboard index.html requests (css, images, etc.)
if isAPI || (isControlPanel && wildcard != "" && wildcard != router.IndexPage) {
return e.Next()
}
// check again in case the superuser was created by some other process
if err := updateHasSuperuser(e.App); err != nil {
return err
}
if hasSuperuser {
return e.Next()
}
_, hasInstallerParam := e.Request.URL.Query()[installerParam]
// redirect to the installer page
if !hasInstallerParam {
return e.Redirect(http.StatusTemporaryRedirect, cpPath+"?"+installerParam+"#")
}
return e.Next()
}
}
// dashboardRemoveInstallerParam redirects to a non-installer
// query param in case there is already a superuser created.
//
// Note: intended to be registered only for the dashboard route
// to prevent excessive checks for every other route in installerRedirect.
func dashboardRemoveInstallerParam() hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) error {
_, hasInstallerParam := e.Request.URL.Query()[installerParam]
if !hasInstallerParam {
return e.Next() // nothing to remove
}
// clear installer param
total, _ := e.App.CountRecords(core.CollectionNameSuperusers)
if total > 0 {
return e.Redirect(http.StatusTemporaryRedirect, "?")
}
return e.Next()
}
}
// dashboardCacheControl adds default Cache-Control header for all
// dashboard UI resources (ignoring the root index.html path)
func dashboardCacheControl() hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) error {
if e.Request.PathValue(StaticWildcardParam) != "" {
e.Response.Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
}
return e.Next()
}
}

View File

@ -7,18 +7,12 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"runtime" "runtime"
"strings"
"time" "time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/filesystem" "github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/list" "github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/security" "github.com/pocketbase/pocketbase/tools/router"
"github.com/spf13/cast"
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
@ -27,23 +21,19 @@ var imageContentTypes = []string{"image/png", "image/jpg", "image/jpeg", "image/
var defaultThumbSizes = []string{"100x100"} var defaultThumbSizes = []string{"100x100"}
// bindFileApi registers the file api endpoints and the corresponding handlers. // bindFileApi registers the file api endpoints and the corresponding handlers.
func bindFileApi(app core.App, rg *echo.Group) { func bindFileApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
api := fileApi{ api := fileApi{
app: app,
thumbGenSem: semaphore.NewWeighted(int64(runtime.NumCPU() + 2)), // the value is arbitrary chosen and may change in the future thumbGenSem: semaphore.NewWeighted(int64(runtime.NumCPU() + 2)), // the value is arbitrary chosen and may change in the future
thumbGenPending: new(singleflight.Group), thumbGenPending: new(singleflight.Group),
thumbGenMaxWait: 60 * time.Second, thumbGenMaxWait: 60 * time.Second,
} }
subGroup := rg.Group("/files", ActivityLogger(app)) sub := rg.Group("/files")
subGroup.POST("/token", api.fileToken) sub.POST("/token", api.fileToken).Bind(RequireAuth())
subGroup.HEAD("/:collection/:recordId/:filename", api.download, LoadCollectionContext(api.app)) sub.GET("/{collection}/{recordId}/{filename}", api.download).Bind(collectionPathRateLimit("", "file"))
subGroup.GET("/:collection/:recordId/:filename", api.download, LoadCollectionContext(api.app))
} }
type fileApi struct { type fileApi struct {
app core.App
// thumbGenSem is a semaphore to prevent too much concurrent // thumbGenSem is a semaphore to prevent too much concurrent
// requests generating new thumbs at the same time. // requests generating new thumbs at the same time.
thumbGenSem *semaphore.Weighted thumbGenSem *semaphore.Weighted
@ -57,84 +47,67 @@ type fileApi struct {
thumbGenMaxWait time.Duration thumbGenMaxWait time.Duration
} }
func (api *fileApi) fileToken(c echo.Context) error { func (api *fileApi) fileToken(e *core.RequestEvent) error {
event := new(core.FileTokenEvent) if e.Auth == nil {
event.HttpContext = c return e.UnauthorizedError("Missing auth context.", nil)
if admin, _ := c.Get(ContextAdminKey).(*models.Admin); admin != nil {
event.Model = admin
event.Token, _ = tokens.NewAdminFileToken(api.app, admin)
} else if record, _ := c.Get(ContextAuthRecordKey).(*models.Record); record != nil {
event.Model = record
event.Token, _ = tokens.NewRecordFileToken(api.app, record)
} }
return api.app.OnFileBeforeTokenRequest().Trigger(event, func(e *core.FileTokenEvent) error { token, err := e.Auth.NewFileToken()
if e.Model == nil || e.Token == "" { if err != nil {
return NewBadRequestError("Failed to generate file token.", nil) return e.InternalServerError("Failed to generate file token", err)
} }
return api.app.OnFileAfterTokenRequest().Trigger(event, func(e *core.FileTokenEvent) error { event := new(core.FileTokenRequestEvent)
if e.HttpContext.Response().Committed { event.RequestEvent = e
return nil event.Token = token
}
return e.HttpContext.JSON(http.StatusOK, map[string]string{ return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error {
"token": e.Token, return e.JSON(http.StatusOK, map[string]string{
}) "token": e.Token,
}) })
}) })
} }
func (api *fileApi) download(c echo.Context) error { func (api *fileApi) download(e *core.RequestEvent) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection) collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if collection == nil {
return NewNotFoundError("", nil)
}
recordId := c.PathParam("recordId")
if recordId == "" {
return NewNotFoundError("", nil)
}
record, err := api.app.Dao().FindRecordById(collection.Id, recordId)
if err != nil { if err != nil {
return NewNotFoundError("", err) return e.NotFoundError("", nil)
} }
filename := c.PathParam("filename") recordId := e.Request.PathValue("recordId")
if recordId == "" {
return e.NotFoundError("", nil)
}
record, err := e.App.FindRecordById(collection, recordId)
if err != nil {
return e.NotFoundError("", err)
}
filename := e.Request.PathValue("filename")
fileField := record.FindFileFieldByFile(filename) fileField := record.FindFileFieldByFile(filename)
if fileField == nil { if fileField == nil {
return NewNotFoundError("", nil) return e.NotFoundError("", nil)
}
options, ok := fileField.Options.(*schema.FileOptions)
if !ok {
return NewBadRequestError("", errors.New("failed to load file options"))
} }
// check whether the request is authorized to view the protected file // check whether the request is authorized to view the protected file
if options.Protected { if fileField.Protected {
token := c.QueryParam("token") originalRequestInfo, err := e.RequestInfo()
if err != nil {
adminOrAuthRecord, _ := api.findAdminOrAuthRecordByFileToken(token) return e.InternalServerError("Failed to load request info", err)
// create a copy of the cached request data and adjust it for the current auth model
requestInfo := *RequestInfo(c)
requestInfo.Context = models.RequestInfoContextProtectedFile
requestInfo.Admin = nil
requestInfo.AuthRecord = nil
if adminOrAuthRecord != nil {
if admin, _ := adminOrAuthRecord.(*models.Admin); admin != nil {
requestInfo.Admin = admin
} else if record, _ := adminOrAuthRecord.(*models.Record); record != nil {
requestInfo.AuthRecord = record
}
} }
if ok, _ := api.app.Dao().CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok { token := e.Request.URL.Query().Get("token")
return NewForbiddenError("Insufficient permissions to access the file resource.", nil) authRecord, _ := e.App.FindAuthRecordByToken(token, core.TokenTypeFile)
// create a shallow copy of the cached request data and adjust it to the current auth record (if any)
requestInfo := *originalRequestInfo
requestInfo.Context = core.RequestInfoContextProtectedFile
requestInfo.Auth = authRecord
if ok, _ := e.App.CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
return e.NotFoundError("", errors.New("insufficient permissions to access the file resource"))
} }
} }
@ -142,16 +115,16 @@ func (api *fileApi) download(c echo.Context) error {
// fetch the original view file field related record // fetch the original view file field related record
if collection.IsView() { if collection.IsView() {
fileRecord, err := api.app.Dao().FindRecordByViewFile(collection.Id, fileField.Name, filename) fileRecord, err := e.App.FindRecordByViewFile(collection.Id, fileField.Name, filename)
if err != nil { if err != nil {
return NewNotFoundError("", fmt.Errorf("Failed to fetch view file field record: %w", err)) return e.NotFoundError("", fmt.Errorf("failed to fetch view file field record: %w", err))
} }
baseFilesPath = fileRecord.BaseFilesPath() baseFilesPath = fileRecord.BaseFilesPath()
} }
fsys, err := api.app.NewFilesystem() fsys, err := e.App.NewFilesystem()
if err != nil { if err != nil {
return NewBadRequestError("Filesystem initialization failure.", err) return e.InternalServerError("Filesystem initialization failure.", err)
} }
defer fsys.Close() defer fsys.Close()
@ -160,12 +133,12 @@ func (api *fileApi) download(c echo.Context) error {
servedName := filename servedName := filename
// check for valid thumb size param // check for valid thumb size param
thumbSize := c.QueryParam("thumb") thumbSize := e.Request.URL.Query().Get("thumb")
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, options.Thumbs)) { if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, fileField.Thumbs)) {
// extract the original file meta attributes and check it existence // extract the original file meta attributes and check it existence
oAttrs, oAttrsErr := fsys.Attributes(originalPath) oAttrs, oAttrsErr := fsys.Attributes(originalPath)
if oAttrsErr != nil { if oAttrsErr != nil {
return NewNotFoundError("", err) return e.NotFoundError("", err)
} }
// check if it is an image // check if it is an image
@ -176,8 +149,8 @@ func (api *fileApi) download(c echo.Context) error {
// create a new thumb if it doesn't exist // create a new thumb if it doesn't exist
if exists, _ := fsys.Exists(servedPath); !exists { if exists, _ := fsys.Exists(servedPath); !exists {
if err := api.createThumb(c, fsys, originalPath, servedPath, thumbSize); err != nil { if err := api.createThumb(e, fsys, originalPath, servedPath, thumbSize); err != nil {
api.app.Logger().Warn( e.App.Logger().Warn(
"Fallback to original - failed to create thumb "+servedName, "Fallback to original - failed to create thumb "+servedName,
slog.Any("error", err), slog.Any("error", err),
slog.String("original", originalPath), slog.String("original", originalPath),
@ -192,8 +165,8 @@ func (api *fileApi) download(c echo.Context) error {
} }
} }
event := new(core.FileDownloadEvent) event := new(core.FileDownloadRequestEvent)
event.HttpContext = c event.RequestEvent = e
event.Collection = collection event.Collection = collection
event.Record = record event.Record = record
event.FileField = fileField event.FileField = fileField
@ -203,61 +176,26 @@ func (api *fileApi) download(c echo.Context) error {
// clickjacking shouldn't be a concern when serving uploaded files, // clickjacking shouldn't be a concern when serving uploaded files,
// so it safe to unset the global X-Frame-Options to allow files embedding // so it safe to unset the global X-Frame-Options to allow files embedding
// (note: it is out of the hook to allow users to customize the behavior) // (note: it is out of the hook to allow users to customize the behavior)
c.Response().Header().Del("X-Frame-Options") e.Response.Header().Del("X-Frame-Options")
return api.app.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadEvent) error { return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error {
if e.HttpContext.Response().Committed { if err := fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName); err != nil {
return nil return e.NotFoundError("", err)
}
if err := fsys.Serve(e.HttpContext.Response(), e.HttpContext.Request(), e.ServedPath, e.ServedName); err != nil {
return NewNotFoundError("", err)
} }
return nil return nil
}) })
} }
func (api *fileApi) findAdminOrAuthRecordByFileToken(fileToken string) (models.Model, error) {
fileToken = strings.TrimSpace(fileToken)
if fileToken == "" {
return nil, errors.New("missing file token")
}
claims, _ := security.ParseUnverifiedJWT(strings.TrimSpace(fileToken))
tokenType := cast.ToString(claims["type"])
switch tokenType {
case tokens.TypeAdmin:
admin, err := api.app.Dao().FindAdminByToken(
fileToken,
api.app.Settings().AdminFileToken.Secret,
)
if err == nil && admin != nil {
return admin, nil
}
case tokens.TypeAuthRecord:
record, err := api.app.Dao().FindAuthRecordByToken(
fileToken,
api.app.Settings().RecordFileToken.Secret,
)
if err == nil && record != nil {
return record, nil
}
}
return nil, errors.New("missing or invalid file token")
}
func (api *fileApi) createThumb( func (api *fileApi) createThumb(
c echo.Context, e *core.RequestEvent,
fsys *filesystem.System, fsys *filesystem.System,
originalPath string, originalPath string,
thumbPath string, thumbPath string,
thumbSize string, thumbSize string,
) error { ) error {
ch := api.thumbGenPending.DoChan(thumbPath, func() (any, error) { ch := api.thumbGenPending.DoChan(thumbPath, func() (any, error) {
ctx, cancel := context.WithTimeout(c.Request().Context(), api.thumbGenMaxWait) ctx, cancel := context.WithTimeout(e.Request.Context(), api.thumbGenMaxWait)
defer cancel() defer cancel()
if err := api.thumbGenSem.Acquire(ctx, 1); err != nil { if err := api.thumbGenSem.Acquire(ctx, 1); err != nil {

View File

@ -10,11 +10,8 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types" "github.com/pocketbase/pocketbase/tools/types"
) )
@ -26,23 +23,54 @@ func TestFileToken(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/files/token", URL: "/api/files/token",
ExpectedStatus: 400, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "regular user",
Method: http.MethodPost,
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1, "*": 0,
"OnFileTokenRequest": 1,
}, },
}, },
{ {
Name: "unauthorized with model and token via hook", Name: "superuser",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/files/token", URL: "/api/files/token",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { Headers: map[string]string{
app.OnFileBeforeTokenRequest().Add(func(e *core.FileTokenEvent) error { "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
record, _ := app.Dao().FindAuthRecordByEmail("users", "test@example.com") },
e.Model = record ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileTokenRequest": 1,
},
},
{
Name: "hook token overwrite",
Method: http.MethodPost,
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnFileTokenRequest().BindFunc(func(e *core.FileTokenRequestEvent) error {
e.Token = "test" e.Token = "test"
return nil return e.Next()
}) })
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
@ -50,40 +78,8 @@ func TestFileToken(t *testing.T) {
`"token":"test"`, `"token":"test"`,
}, },
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1, "*": 0,
"OnFileAfterTokenRequest": 1, "OnFileTokenRequest": 1,
},
},
{
Name: "auth record",
Method: http.MethodPost,
Url: "/api/files/token",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1,
"OnFileAfterTokenRequest": 1,
},
},
{
Name: "admin",
Method: http.MethodPost,
Url: "/api/files/token",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1,
"OnFileAfterTokenRequest": 1,
}, },
}, },
} }
@ -152,233 +148,271 @@ func TestFileDownload(t *testing.T) {
{ {
Name: "missing collection", Name: "missing collection",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png", URL: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
ExpectedStatus: 404, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "missing record", Name: "missing record",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png", URL: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
ExpectedStatus: 404, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "missing file", Name: "missing file",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
ExpectedStatus: 404, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "existing image", Name: "existing image",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testImg)}, ExpectedContent: []string{string(testImg)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing image - missing thumb (should fallback to the original)", Name: "existing image - missing thumb (should fallback to the original)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testImg)}, ExpectedContent: []string{string(testImg)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing image - existing thumb (crop center)", Name: "existing image - existing thumb (crop center)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropCenter)}, ExpectedContent: []string{string(testThumbCropCenter)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing image - existing thumb (crop top)", Name: "existing image - existing thumb (crop top)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropTop)}, ExpectedContent: []string{string(testThumbCropTop)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing image - existing thumb (crop bottom)", Name: "existing image - existing thumb (crop bottom)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropBottom)}, ExpectedContent: []string{string(testThumbCropBottom)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing image - existing thumb (fit)", Name: "existing image - existing thumb (fit)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbFit)}, ExpectedContent: []string{string(testThumbFit)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing image - existing thumb (zero width)", Name: "existing image - existing thumb (zero width)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbZeroWidth)}, ExpectedContent: []string{string(testThumbZeroWidth)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing image - existing thumb (zero height)", Name: "existing image - existing thumb (zero height)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0", URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbZeroHeight)}, ExpectedContent: []string{string(testThumbZeroHeight)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "existing non image file - thumb parameter should be ignored", Name: "existing non image file - thumb parameter should be ignored",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100", URL: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{string(testFile)}, ExpectedContent: []string{string(testFile)},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
// protected file access checks // protected file access checks
{ {
Name: "protected file - expired token", Name: "protected file - superuser with expired file token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100", URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.hTNDzikwJdcoWrLnRnp7xbaifZ2vuYZ0oOYRHtJfnk4",
ExpectedStatus: 200, ExpectedStatus: 404,
ExpectedContent: []string{string(testFile)},
ExpectedEvents: map[string]int{
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file - admin with expired file token",
Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImFkbWluIn0.g7Q_3UX6H--JWJ7yt1Hoe-1ugTX1KpbKzdt0zjGSe-E",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "protected file - admin with valid file token", Name: "protected file - superuser with valid file token",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU", URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{"PNG"}, ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "protected file - guest without view access", Name: "protected file - guest without view access",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png", URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
ExpectedStatus: 403, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "protected file - guest with view access", Name: "protected file - guest with view access",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png", URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
dao := daos.New(app.Dao().DB())
// mock public view access // mock public view access
c, err := dao.FindCollectionByNameOrId("demo1") c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil { if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err) t.Fatalf("Failed to fetch mock collection: %v", err)
} }
c.ViewRule = types.Pointer("") c.ViewRule = types.Pointer("")
if err := dao.SaveCollection(c); err != nil { if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err) t.Fatalf("Failed to update mock collection: %v", err)
} }
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{"PNG"}, ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "protected file - auth record without view access", Name: "protected file - auth record without view access",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk", URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
dao := daos.New(app.Dao().DB())
// mock restricted user view access // mock restricted user view access
c, err := dao.FindCollectionByNameOrId("demo1") c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil { if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err) t.Fatalf("Failed to fetch mock collection: %v", err)
} }
c.ViewRule = types.Pointer("@request.auth.verified = true") c.ViewRule = types.Pointer("@request.auth.verified = true")
if err := dao.SaveCollection(c); err != nil { if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err) t.Fatalf("Failed to update mock collection: %v", err)
} }
}, },
ExpectedStatus: 403, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "protected file - auth record with view access", Name: "protected file - auth record with view access",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk", URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
dao := daos.New(app.Dao().DB())
// mock user view access // mock user view access
c, err := dao.FindCollectionByNameOrId("demo1") c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil { if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err) t.Fatalf("Failed to fetch mock collection: %v", err)
} }
c.ViewRule = types.Pointer("@request.auth.verified = false") c.ViewRule = types.Pointer("@request.auth.verified = false")
if err := dao.SaveCollection(c); err != nil { if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err) t.Fatalf("Failed to update mock collection: %v", err)
} }
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{"PNG"}, ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
{ {
Name: "protected file in view (view's View API rule failure)", Name: "protected file in view (view's View API rule failure)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk", URL: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
ExpectedStatus: 403, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "protected file in view (view's View API rule success)", Name: "protected file in view (view's View API rule success)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk", URL: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{"test"}, ExpectedContent: []string{"test"},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1, "OnFileDownloadRequest": 1,
}, },
}, },
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:file",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:file"},
{MaxRequests: 0, Label: "users:file"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:file",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:file"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
} }
for _, scenario := range scenarios { for _, scenario := range scenarios {
@ -410,30 +444,23 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
defer fsys.Close() defer fsys.Close()
// create a dummy file field collection // create a dummy file field collection
demo1, err := app.Dao().FindCollectionByNameOrId("demo1") demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
fileField := demo1.Schema.GetFieldByName("file_one") fileField := demo1.Fields.GetByName("file_one").(*core.FileField)
fileField.Options = &schema.FileOptions{ fileField.Protected = false
Protected: false, fileField.MaxSelect = 1
MaxSelect: 1, fileField.MaxSize = 999999
MaxSize: 999999, // new thumbs
// new thumbs fileField.Thumbs = []string{"111x111", "111x222", "111x333"}
Thumbs: []string{"111x111", "111x222", "111x333"}, demo1.Fields.Add(fileField)
} if err = app.Save(demo1); err != nil {
demo1.Schema.AddField(fileField)
if err := app.Dao().SaveCollection(demo1); err != nil {
t.Fatal(err) t.Fatal(err)
} }
fileKey := "wsmn24bux7wo113/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png" fileKey := "wsmn24bux7wo113/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png"
e, err := apis.InitApi(app)
if err != nil {
t.Fatal(err)
}
urls := []string{ urls := []string{
"/api/files/" + fileKey + "?thumb=111x111", "/api/files/" + fileKey + "?thumb=111x111",
"/api/files/" + fileKey + "?thumb=111x111", // should still result in single thumb "/api/files/" + fileKey + "?thumb=111x111", // should still result in single thumb
@ -446,7 +473,6 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
wg.Add(len(urls)) wg.Add(len(urls))
for _, url := range urls { for _, url := range urls {
url := url
go func() { go func() {
defer wg.Done() defer wg.Done()
@ -454,7 +480,11 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", url, nil)
e.ServeHTTP(recorder, req) pbRouter, _ := apis.NewRouter(app)
mux, _ := pbRouter.BuildMux()
if mux != nil {
mux.ServeHTTP(recorder, req)
}
}() }()
} }

View File

@ -2,42 +2,52 @@ package apis
import ( import (
"net/http" "net/http"
"slices"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
) )
// bindHealthApi registers the health api endpoint. // bindHealthApi registers the health api endpoint.
func bindHealthApi(app core.App, rg *echo.Group) { func bindHealthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
api := healthApi{app: app}
subGroup := rg.Group("/health") subGroup := rg.Group("/health")
subGroup.HEAD("", api.healthCheck) subGroup.GET("", healthCheck)
subGroup.GET("", api.healthCheck)
}
type healthApi struct {
app core.App
}
type healthCheckResponse struct {
Message string `json:"message"`
Code int `json:"code"`
Data struct {
CanBackup bool `json:"canBackup"`
} `json:"data"`
} }
// healthCheck returns a 200 OK response if the server is healthy. // healthCheck returns a 200 OK response if the server is healthy.
func (api *healthApi) healthCheck(c echo.Context) error { func healthCheck(e *core.RequestEvent) error {
if c.Request().Method == http.MethodHead { resp := struct {
return c.NoContent(http.StatusOK) Message string `json:"message"`
Code int `json:"code"`
Data map[string]any `json:"data"`
}{
Code: http.StatusOK,
Message: "API is healthy.",
} }
resp := new(healthCheckResponse) if e.HasSuperuserAuth() {
resp.Code = http.StatusOK resp.Data = make(map[string]any, 3)
resp.Message = "API is healthy." resp.Data["canBackup"] = !e.App.Store().Has(core.StoreKeyActiveBackup)
resp.Data.CanBackup = !api.app.Store().Has(core.StoreKeyActiveBackup) resp.Data["realIP"] = e.RealIP()
return c.JSON(http.StatusOK, resp) // loosely check if behind a reverse proxy
// (usually used in the dashboard to remind superusers in case deployed behind reverse-proxy)
possibleProxyHeader := ""
headersToCheck := append(
slices.Clone(e.App.Settings().TrustedProxy.Headers),
// common proxy headers
"CF-Connecting-IP", "Fly-Client-IP", "XForwarded-For",
)
for _, header := range headersToCheck {
if e.Request.Header.Get(header) != "" {
possibleProxyHeader = header
break
}
}
resp.Data["possibleProxyHeader"] = possibleProxyHeader
} else {
resp.Data = map[string]any{} // ensure that it is returned as object
}
return e.JSON(http.StatusOK, resp)
} }

View File

@ -12,21 +12,56 @@ func TestHealthAPI(t *testing.T) {
scenarios := []tests.ApiScenario{ scenarios := []tests.ApiScenario{
{ {
Name: "HEAD health status", Name: "GET health status (guest)",
Method: http.MethodHead, Method: http.MethodGet, // automatically matches also HEAD as a side-effect of the Go std mux
Url: "/api/health", URL: "/api/health",
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{}`,
},
NotExpectedContent: []string{
"canBackup",
"realIP",
"possibleProxyHeader",
},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "GET health status", Name: "GET health status (regular user)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/health", URL: "/api/health",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{}`,
},
NotExpectedContent: []string{
"canBackup",
"realIP",
"possibleProxyHeader",
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "GET health status (superuser)",
Method: http.MethodGet,
URL: "/api/health",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
`"code":200`, `"code":200`,
`"data":{`, `"data":{`,
`"canBackup":true`, `"canBackup":true`,
`"realIP"`,
`"possibleProxyHeader"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
} }

View File

@ -3,79 +3,71 @@ package apis
import ( import (
"net/http" "net/http"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search" "github.com/pocketbase/pocketbase/tools/search"
) )
// bindLogsApi registers the request logs api endpoints. // bindLogsApi registers the request logs api endpoints.
func bindLogsApi(app core.App, rg *echo.Group) { func bindLogsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
api := logsApi{app: app} sub := rg.Group("/logs").Bind(RequireSuperuserAuth(), SkipSuccessActivityLog())
sub.GET("", logsList)
subGroup := rg.Group("/logs", RequireAdminAuth()) sub.GET("/stats", logsStats)
subGroup.GET("", api.list) sub.GET("/{id}", logsView)
subGroup.GET("/stats", api.stats)
subGroup.GET("/:id", api.view)
}
type logsApi struct {
app core.App
} }
var logFilterFields = []string{ var logFilterFields = []string{
"rowid", "id", "created", "updated", "id", "created", "level", "message", "data",
"level", "message", "data",
`^data\.[\w\.\:]*\w+$`, `^data\.[\w\.\:]*\w+$`,
} }
func (api *logsApi) list(c echo.Context) error { func logsList(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...) fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
result, err := search.NewProvider(fieldResolver). result, err := search.NewProvider(fieldResolver).
Query(api.app.LogsDao().LogQuery()). Query(e.App.AuxModelQuery(&core.Log{})).
ParseAndExec(c.QueryParams().Encode(), &[]*models.Log{}) ParseAndExec(e.Request.URL.Query().Encode(), &[]*core.Log{})
if err != nil { if err != nil {
return NewBadRequestError("", err) return e.BadRequestError("", err)
} }
return c.JSON(http.StatusOK, result) return e.JSON(http.StatusOK, result)
} }
func (api *logsApi) stats(c echo.Context) error { func logsStats(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...) fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
filter := c.QueryParam(search.FilterQueryParam) filter := e.Request.URL.Query().Get(search.FilterQueryParam)
var expr dbx.Expression var expr dbx.Expression
if filter != "" { if filter != "" {
var err error var err error
expr, err = search.FilterData(filter).BuildExpr(fieldResolver) expr, err = search.FilterData(filter).BuildExpr(fieldResolver)
if err != nil { if err != nil {
return NewBadRequestError("Invalid filter format.", err) return e.BadRequestError("Invalid filter format.", err)
} }
} }
stats, err := api.app.LogsDao().LogsStats(expr) stats, err := e.App.LogsStats(expr)
if err != nil { if err != nil {
return NewBadRequestError("Failed to generate logs stats.", err) return e.BadRequestError("Failed to generate logs stats.", err)
} }
return c.JSON(http.StatusOK, stats) return e.JSON(http.StatusOK, stats)
} }
func (api *logsApi) view(c echo.Context) error { func logsView(e *core.RequestEvent) error {
id := c.PathParam("id") id := e.Request.PathValue("id")
if id == "" { if id == "" {
return NewNotFoundError("", nil) return e.NotFoundError("", nil)
} }
log, err := api.app.LogsDao().FindLogById(id) log, err := e.App.FindLogById(id)
if err != nil || log == nil { if err != nil || log == nil {
return NewNotFoundError("", err) return e.NotFoundError("", err)
} }
return c.JSON(http.StatusOK, log) return e.JSON(http.StatusOK, log)
} }

View File

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/labstack/echo/v5" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
) )
@ -15,29 +15,31 @@ func TestLogsList(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs", URL: "/api/logs",
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs", URL: "/api/logs",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin", Name: "authorized as superuser",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs", URL: "/api/logs",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.MockLogsData(app); err != nil { if err := tests.StubLogsData(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
@ -50,16 +52,17 @@ func TestLogsList(t *testing.T) {
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`, `"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`, `"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin + filter", Name: "authorized as superuser + filter",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs?filter=data.status>200", URL: "/api/logs?filter=data.status>200",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.MockLogsData(app); err != nil { if err := tests.StubLogsData(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
@ -71,6 +74,7 @@ func TestLogsList(t *testing.T) {
`"items":[{`, `"items":[{`,
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`, `"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -86,44 +90,47 @@ func TestLogView(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91", URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91", URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (nonexisting request log)", Name: "authorized as superuser (nonexisting request log)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91", URL: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.MockLogsData(app); err != nil { if err := tests.StubLogsData(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 404, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (existing request log)", Name: "authorized as superuser (existing request log)",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91", URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.MockLogsData(app); err != nil { if err := tests.StubLogsData(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
@ -131,6 +138,7 @@ func TestLogView(t *testing.T) {
ExpectedContent: []string{ ExpectedContent: []string{
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`, `"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -146,52 +154,54 @@ func TestLogsStats(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/stats", URL: "/api/logs/stats",
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/stats", URL: "/api/logs/stats",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin", Name: "authorized as superuser",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/stats", URL: "/api/logs/stats",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.MockLogsData(app); err != nil { if err := tests.StubLogsData(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
`[{"total":1,"date":"2022-05-01 10:00:00.000Z"},{"total":1,"date":"2022-05-02 10:00:00.000Z"}]`, `[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
}, },
}, },
{ {
Name: "authorized as admin + filter", Name: "authorized as superuser + filter",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/logs/stats?filter=data.status>200", URL: "/api/logs/stats?filter=data.status>200",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.MockLogsData(app); err != nil { if err := tests.StubLogsData(app); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
`[{"total":1,"date":"2022-05-02 10:00:00.000Z"}]`, `[{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
}, },
}, },
} }

View File

@ -3,303 +3,321 @@ package apis
import ( import (
"fmt" "fmt"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"net/url" "net/url"
"slices"
"strings" "strings"
"time" "time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/list" "github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine" "github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast" "github.com/spf13/cast"
) )
// Common request context keys used by the middlewares and api handlers. // Common request event store keys used by the middlewares and api handlers.
const ( const (
ContextAdminKey string = "admin" RequestEventKeyLogMeta = "pbLogMeta" // extra data to store with the request activity log
ContextAuthRecordKey string = "authRecord"
ContextCollectionKey string = "collection" requestEventKeyExecStart = "__execStart" // the value must be time.Time
ContextExecStartKey string = "execStart" requestEventKeySkipSuccessActivityLog = "__skipSuccessActivityLogger" // the value must be bool
)
const (
DefaultWWWRedirectMiddlewarePriority = -99999
DefaultWWWRedirectMiddlewareId = "pbWWWRedirect"
DefaultActivityLoggerMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 30
DefaultActivityLoggerMiddlewareId = "pbActivityLogger"
DefaultSkipSuccessActivityLogMiddlewareId = "pbSkipSuccessActivityLog"
DefaultEnableAuthIdActivityLog = "pbEnableAuthIdActivityLog"
DefaultLoadAuthTokenMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 20
DefaultLoadAuthTokenMiddlewareId = "pbLoadAuthToken"
DefaultSecurityHeadersMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 10
DefaultSecurityHeadersMiddlewareId = "pbSecurityHeaders"
DefaultRequireGuestOnlyMiddlewareId = "pbRequireGuestOnly"
DefaultRequireAuthMiddlewareId = "pbRequireAuth"
DefaultRequireSuperuserAuthMiddlewareId = "pbRequireSuperuserAuth"
DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId = "pbRequireSuperuserAuthOnlyIfAny"
DefaultRequireSuperuserOrOwnerAuthMiddlewareId = "pbRequireSuperuserOrOwnerAuth"
DefaultRequireSameCollectionContextAuthMiddlewareId = "pbRequireSameCollectionContextAuth"
) )
// RequireGuestOnly middleware requires a request to NOT have a valid // RequireGuestOnly middleware requires a request to NOT have a valid
// Authorization header. // Authorization header.
// //
// This middleware is the opposite of [apis.RequireAdminOrRecordAuth()]. // This middleware is the opposite of [apis.RequireAuth()].
func RequireGuestOnly() echo.MiddlewareFunc { func RequireGuestOnly() *hook.Handler[*core.RequestEvent] {
return func(next echo.HandlerFunc) echo.HandlerFunc { return &hook.Handler[*core.RequestEvent]{
return func(c echo.Context) error { Id: DefaultRequireGuestOnlyMiddlewareId,
err := NewBadRequestError("The request can be accessed only by guests.", nil) Func: func(e *core.RequestEvent) error {
if e.Auth != nil {
record, _ := c.Get(ContextAuthRecordKey).(*models.Record) return router.NewBadRequestError("The request can be accessed only by guests.", nil)
if record != nil {
return err
} }
admin, _ := c.Get(ContextAdminKey).(*models.Admin) return e.Next()
if admin != nil { },
return err
}
return next(c)
}
} }
} }
// RequireRecordAuth middleware requires a request to have // RequireAuth middleware requires a request to have a valid record Authorization header.
// a valid record auth Authorization header.
// //
// The auth record could be from any collection. // The auth record could be from any collection.
// // You can further filter the allowed record auth collections by specifying their names.
// You can further filter the allowed record auth collections by
// specifying their names.
// //
// Example: // Example:
// //
// apis.RequireRecordAuth() // apis.RequireAuth() // any auth collection
// // apis.RequireAuth("_superusers", "users") // only the listed auth collections
// Or: func RequireAuth(optCollectionNames ...string) *hook.Handler[*core.RequestEvent] {
// return &hook.Handler[*core.RequestEvent]{
// apis.RequireRecordAuth("users", "supervisors") Id: DefaultRequireAuthMiddlewareId,
// Func: requireAuth(optCollectionNames...),
// To restrict the auth record only to the loaded context collection,
// use [apis.RequireSameContextRecordAuth()] instead.
func RequireRecordAuth(optCollectionNames ...string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewUnauthorizedError("The request requires valid record authorization token to be set.", nil)
}
// check record collection name
if len(optCollectionNames) > 0 && !list.ExistInSlice(record.Collection().Name, optCollectionNames) {
return NewForbiddenError("The authorized record model is not allowed to perform this action.", nil)
}
return next(c)
}
} }
} }
// RequireSameContextRecordAuth middleware requires a request to have func requireAuth(optCollectionNames ...string) hook.HandlerFunc[*core.RequestEvent] {
// a valid record Authorization header. return func(e *core.RequestEvent) error {
// if e.Auth == nil {
// The auth record must be from the same collection already loaded in the context. return e.UnauthorizedError("The request requires valid record authorization token.", nil)
func RequireSameContextRecordAuth() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewUnauthorizedError("The request requires valid record authorization token to be set.", nil)
}
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil || record.Collection().Id != collection.Id {
return NewForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", record.Collection().Name), nil)
}
return next(c)
} }
// check record collection name
if len(optCollectionNames) > 0 && !slices.Contains(optCollectionNames, e.Auth.Collection().Name) {
return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil)
}
return e.Next()
} }
} }
// RequireAdminAuth middleware requires a request to have // RequireSuperuserAuth middleware requires a request to have
// a valid admin Authorization header. // a valid superuser Authorization header.
func RequireAdminAuth() echo.MiddlewareFunc { func RequireSuperuserAuth() *hook.Handler[*core.RequestEvent] {
return func(next echo.HandlerFunc) echo.HandlerFunc { return &hook.Handler[*core.RequestEvent]{
return func(c echo.Context) error { Id: DefaultRequireSuperuserAuthMiddlewareId,
admin, _ := c.Get(ContextAdminKey).(*models.Admin) Func: requireAuth(core.CollectionNameSuperusers),
if admin == nil {
return NewUnauthorizedError("The request requires valid admin authorization token to be set.", nil)
}
return next(c)
}
} }
} }
// RequireAdminAuthOnlyIfAny middleware requires a request to have // RequireSuperuserAuthOnlyIfAny middleware requires a request to have
// a valid admin Authorization header ONLY if the application has // a valid superuser Authorization header ONLY if the application has
// at least 1 existing Admin model. // at least 1 existing superuser.
func RequireAdminAuthOnlyIfAny(app core.App) echo.MiddlewareFunc { func RequireSuperuserAuthOnlyIfAny() *hook.Handler[*core.RequestEvent] {
return func(next echo.HandlerFunc) echo.HandlerFunc { return &hook.Handler[*core.RequestEvent]{
return func(c echo.Context) error { Id: DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId,
admin, _ := c.Get(ContextAdminKey).(*models.Admin) Func: func(e *core.RequestEvent) error {
if admin != nil { if e.HasSuperuserAuth() {
return next(c) return e.Next()
} }
totalAdmins, err := app.Dao().TotalAdmins() totalSuperusers, err := e.App.CountRecords(core.CollectionNameSuperusers)
if err != nil { if err != nil {
return NewBadRequestError("Failed to fetch admins info.", err) return e.InternalServerError("Failed to fetch superusers info.", err)
} }
if totalAdmins == 0 { if totalSuperusers == 0 {
return next(c) return e.Next()
} }
return NewUnauthorizedError("The request requires valid admin authorization token to be set.", nil) return requireAuth(core.CollectionNameSuperusers)(e)
} },
} }
} }
// RequireAdminOrRecordAuth middleware requires a request to have // RequireSuperuserOrOwnerAuth middleware requires a request to have
// a valid admin or record Authorization header set. // a valid superuser or regular record owner Authorization header set.
// //
// You can further filter the allowed auth record collections by providing their names. // This middleware is similar to [apis.RequireAuth()] but
//
// This middleware is the opposite of [apis.RequireGuestOnly()].
func RequireAdminOrRecordAuth(optCollectionNames ...string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if admin == nil && record == nil {
return NewUnauthorizedError("The request requires admin or record authorization token to be set.", nil)
}
if record != nil && len(optCollectionNames) > 0 && !list.ExistInSlice(record.Collection().Name, optCollectionNames) {
return NewForbiddenError("The authorized record model is not allowed to perform this action.", nil)
}
return next(c)
}
}
}
// RequireAdminOrOwnerAuth middleware requires a request to have
// a valid admin or auth record owner Authorization header set.
//
// This middleware is similar to [apis.RequireAdminOrRecordAuth()] but
// for the auth record token expects to have the same id as the path // for the auth record token expects to have the same id as the path
// parameter ownerIdParam (default to "id" if empty). // parameter ownerIdPathParam (default to "id" if empty).
func RequireAdminOrOwnerAuth(ownerIdParam string) echo.MiddlewareFunc { func RequireSuperuserOrOwnerAuth(ownerIdPathParam string) *hook.Handler[*core.RequestEvent] {
return func(next echo.HandlerFunc) echo.HandlerFunc { return &hook.Handler[*core.RequestEvent]{
return func(c echo.Context) error { Id: DefaultRequireSuperuserOrOwnerAuthMiddlewareId,
admin, _ := c.Get(ContextAdminKey).(*models.Admin) Func: func(e *core.RequestEvent) error {
if admin != nil { if e.Auth == nil {
return next(c) return e.UnauthorizedError("The request requires superuser or record authorization token.", nil)
} }
record, _ := c.Get(ContextAuthRecordKey).(*models.Record) if e.Auth.IsSuperuser() {
if record == nil { return e.Next()
return NewUnauthorizedError("The request requires admin or record authorization token to be set.", nil)
} }
if ownerIdParam == "" { if ownerIdPathParam == "" {
ownerIdParam = "id" ownerIdPathParam = "id"
} }
ownerId := c.PathParam(ownerIdParam) ownerId := e.Request.PathValue(ownerIdPathParam)
// note: it is "safe" to compare only the record id since the auth // note: it is considered "safe" to compare only the record id
// record ids are treated as unique across all auth collections // since the auth record ids are treated as unique across all auth collections
if record.Id != ownerId { if e.Auth.Id != ownerId {
return NewForbiddenError("You are not allowed to perform this request.", nil) return e.ForbiddenError("You are not allowed to perform this request.", nil)
} }
return next(c) return e.Next()
} },
} }
} }
// LoadAuthContext middleware reads the Authorization request header // RequireSameCollectionContextAuth middleware requires a request to have
// and loads the token related record or admin instance into the // a valid record Authorization header and the auth record's collection to
// request's context. // match the one from the route path parameter (default to "collection" if collectionParam is empty).
func RequireSameCollectionContextAuth(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSameCollectionContextAuthMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
}
if collectionPathParam == "" {
collectionPathParam = "collection"
}
collection, _ := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if collection == nil || e.Auth.Collection().Id != collection.Id {
return e.ForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", e.Auth.Collection().Name), nil)
}
return e.Next()
},
}
}
// loadAuthToken attempts to load the auth context based on the "Authorization: TOKEN" header value.
// //
// This middleware is expected to be already registered by default for all routes. // This middleware does nothing in case of missing, invalid or expired token.
func LoadAuthContext(app core.App) echo.MiddlewareFunc { //
return func(next echo.HandlerFunc) echo.HandlerFunc { // This middleware is registered by default for all routes.
return func(c echo.Context) error { //
token := c.Request().Header.Get("Authorization") // Note: We don't throw an error on invalid or expired token to allow
// users to extend with their own custom handling in external middleware(s).
func loadAuthToken() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultLoadAuthTokenMiddlewareId,
Priority: DefaultLoadAuthTokenMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
token := getAuthTokenFromRequest(e)
if token == "" { if token == "" {
return next(c) return e.Next()
} }
// the schema is not required and it is only for record, err := e.App.FindAuthRecordByToken(token, core.TokenTypeAuth)
// compatibility with the defaults of some HTTP clients if err != nil {
token = strings.TrimPrefix(token, "Bearer ") e.App.Logger().Debug("loadAuthToken failure", "error", err)
} else if record != nil {
claims, _ := security.ParseUnverifiedJWT(token) e.Auth = record
tokenType := cast.ToString(claims["type"])
switch tokenType {
case tokens.TypeAdmin:
admin, err := app.Dao().FindAdminByToken(
token,
app.Settings().AdminAuthToken.Secret,
)
if err == nil && admin != nil {
c.Set(ContextAdminKey, admin)
}
case tokens.TypeAuthRecord:
record, err := app.Dao().FindAuthRecordByToken(
token,
app.Settings().RecordAuthToken.Secret,
)
if err == nil && record != nil {
c.Set(ContextAuthRecordKey, record)
}
} }
return next(c) return e.Next()
} },
} }
} }
// LoadCollectionContext middleware finds the collection with related func getAuthTokenFromRequest(e *core.RequestEvent) string {
// path identifier and loads it into the request context. token := e.Request.Header.Get("Authorization")
if token != "" {
// the schema prefix is not required and it is only for
// compatibility with the defaults of some HTTP clients
token = strings.TrimPrefix(token, "Bearer ")
}
return token
}
// wwwRedirect performs www->non-www redirect(s) if the request host
// matches with one of the values in redirectHosts.
// //
// Set optCollectionTypes to further filter the found collection by its type. // This middleware is registered by default on Serve for all routes.
func LoadCollectionContext(app core.App, optCollectionTypes ...string) echo.MiddlewareFunc { func wwwRedirect(redirectHosts []string) *hook.Handler[*core.RequestEvent] {
return func(next echo.HandlerFunc) echo.HandlerFunc { return &hook.Handler[*core.RequestEvent]{
return func(c echo.Context) error { Id: DefaultWWWRedirectMiddlewareId,
if param := c.PathParam("collection"); param != "" { Priority: DefaultWWWRedirectMiddlewarePriority,
collection, err := core.FindCachedCollectionByNameOrId(app, param) Func: func(e *core.RequestEvent) error {
if err != nil || collection == nil { host := e.Request.Host
return NewNotFoundError("", err)
}
if len(optCollectionTypes) > 0 && !list.ExistInSlice(collection.Type, optCollectionTypes) { if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, redirectHosts) {
return NewBadRequestError("Unsupported collection type.", nil) return e.Redirect(
} http.StatusTemporaryRedirect,
(e.Request.URL.Scheme + "://" + host[4:] + e.Request.RequestURI),
c.Set(ContextCollectionKey, collection) )
} }
return next(c) return e.Next()
} },
} }
} }
// ActivityLogger middleware takes care to save the request information // securityHeaders middleware adds common security headers to the response.
//
// This middleware is registered by default for all routes.
func securityHeaders() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultSecurityHeadersMiddlewareId,
Priority: DefaultSecurityHeadersMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
e.Response.Header().Set("X-XSS-Protection", "1; mode=block")
e.Response.Header().Set("X-Content-Type-Options", "nosniff")
e.Response.Header().Set("X-Frame-Options", "SAMEORIGIN")
// @todo consider a default HSTS?
// (see also https://webkit.org/blog/8146/protecting-against-hsts-abuse/)
return e.Next()
},
}
}
// SkipSuccessActivityLog is a helper middleware that instructs the global
// activity logger to log only requests that have failed/returned an error.
func SkipSuccessActivityLog() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultSkipSuccessActivityLogMiddlewareId,
Func: func(e *core.RequestEvent) error {
e.Set(requestEventKeySkipSuccessActivityLog, true)
return e.Next()
},
}
}
// activityLogger middleware takes care to save the request information
// into the logs database. // into the logs database.
// //
// This middleware is registered by default for all routes.
//
// The middleware does nothing if the app logs retention period is zero // The middleware does nothing if the app logs retention period is zero
// (aka. app.Settings().Logs.MaxDays = 0). // (aka. app.Settings().Logs.MaxDays = 0).
func ActivityLogger(app core.App) echo.MiddlewareFunc { //
return func(next echo.HandlerFunc) echo.HandlerFunc { // Users can attach the [apis.SkipSuccessActivityLog()] middleware if
return func(c echo.Context) error { // you want to log only the failed requests.
if err := next(c); err != nil { func activityLogger() *hook.Handler[*core.RequestEvent] {
return err return &hook.Handler[*core.RequestEvent]{
} Id: DefaultActivityLoggerMiddlewareId,
Priority: DefaultActivityLoggerMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
e.Set(requestEventKeyExecStart, time.Now())
logRequest(app, c, nil) err := e.Next()
return nil logRequest(e, err)
}
return err
},
} }
} }
func logRequest(app core.App, c echo.Context, err *ApiError) { func logRequest(event *core.RequestEvent, err error) {
// no logs retention // no logs retention
if app.Settings().Logs.MaxDays == 0 { if event.App.Settings().Logs.MaxDays == 0 {
return
}
// the non-error route has explicitly disabled the activity logger
if err == nil && event.Get(requestEventKeySkipSuccessActivityLog) != nil {
return return
} }
@ -307,32 +325,31 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
attrs = append(attrs, slog.String("type", "request")) attrs = append(attrs, slog.String("type", "request"))
started := cast.ToTime(c.Get(ContextExecStartKey)) started := cast.ToTime(event.Get(requestEventKeyExecStart))
if !started.IsZero() { if !started.IsZero() {
attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond))) attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond)))
} }
httpRequest := c.Request() if meta := event.Get(RequestEventKeyLogMeta); meta != nil {
httpResponse := c.Response() attrs = append(attrs, slog.Any("meta", meta))
method := strings.ToUpper(httpRequest.Method) }
status := httpResponse.Status
requestUri := httpRequest.URL.RequestURI() status := event.Status()
method := cutStr(strings.ToUpper(event.Request.Method), 50)
requestUri := cutStr(event.Request.URL.RequestURI(), 3000)
// parse the request error // parse the request error
if err != nil { if err != nil {
status = err.Code if apiErr, ok := err.(*router.ApiError); ok {
attrs = append( status = apiErr.Status
attrs, attrs = append(
slog.String("error", err.Message), attrs,
slog.Any("details", err.RawData()), slog.String("error", apiErr.Message),
) slog.Any("details", apiErr.RawData()),
} )
} else {
requestAuth := models.RequestAuthGuest attrs = append(attrs, slog.String("error", err.Error()))
if c.Get(ContextAuthRecordKey) != nil { }
requestAuth = models.RequestAuthRecord
} else if c.Get(ContextAdminKey) != nil {
requestAuth = models.RequestAuthAdmin
} }
attrs = append( attrs = append(
@ -340,17 +357,33 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
slog.String("url", requestUri), slog.String("url", requestUri),
slog.String("method", method), slog.String("method", method),
slog.Int("status", status), slog.Int("status", status),
slog.String("auth", requestAuth), slog.String("referer", cutStr(event.Request.Referer(), 2000)),
slog.String("referer", httpRequest.Referer()), slog.String("userAgent", cutStr(event.Request.UserAgent(), 2000)),
slog.String("userAgent", httpRequest.UserAgent()),
) )
if app.Settings().Logs.LogIp { if event.Auth != nil {
ip, _, _ := net.SplitHostPort(httpRequest.RemoteAddr) attrs = append(attrs, slog.String("auth", event.Auth.Collection().Name))
if event.App.Settings().Logs.LogAuthId {
attrs = append(attrs, slog.String("authId", event.Auth.Id))
}
} else {
attrs = append(attrs, slog.String("auth", ""))
}
if event.App.Settings().Logs.LogIP {
var userIP string
if len(event.App.Settings().TrustedProxy.Headers) > 0 {
userIP = event.RealIP()
} else {
// fallback to the legacy behavior (it is "safe" since it is only for log purposes)
userIP = cutStr(event.UnsafeRealIP(), 50)
}
attrs = append( attrs = append(
attrs, attrs,
slog.String("userIp", realUserIp(httpRequest, ip)), slog.String("userIP", userIP),
slog.String("remoteIp", ip), slog.String("remoteIP", event.RemoteIP()),
) )
} }
@ -358,64 +391,23 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
routine.FireAndForget(func() { routine.FireAndForget(func() {
message := method + " " message := method + " "
if escaped, err := url.PathUnescape(requestUri); err == nil { if escaped, unescapeErr := url.PathUnescape(requestUri); unescapeErr == nil {
message += escaped message += escaped
} else { } else {
message += requestUri message += requestUri
} }
if err != nil { if err != nil {
app.Logger().Error(message, attrs...) event.App.Logger().Error(message, attrs...)
} else { } else {
app.Logger().Info(message, attrs...) event.App.Logger().Info(message, attrs...)
} }
}) })
} }
// Returns the "real" user IP from common proxy headers (or fallbackIp if none is found). func cutStr(str string, max int) string {
// if len(str) > max {
// The returned IP value shouldn't be trusted if not behind a trusted reverse proxy! return str[:max] + "..."
func realUserIp(r *http.Request, fallbackIp string) string {
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("Fly-Client-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ipsList := r.Header.Get("X-Forwarded-For"); ipsList != "" {
// extract the first non-empty leftmost-ish ip
ips := strings.Split(ipsList, ",")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if ip != "" {
return ip
}
}
}
return fallbackIp
}
// @todo consider removing as this may no longer be needed due to the custom rest.MultiBinder.
//
// eagerRequestInfoCache ensures that the request data is cached in the request
// context to allow reading for example the json request body data more than once.
func eagerRequestInfoCache(app core.App) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
switch c.Request().Method {
// currently we are eagerly caching only the requests with body
case "POST", "PUT", "PATCH", "DELETE":
RequestInfo(c)
}
return next(c)
}
} }
return str
} }

View File

@ -0,0 +1,123 @@
package apis
import (
"io"
"net/http"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
var ErrRequestEntityTooLarge = router.NewApiError(http.StatusRequestEntityTooLarge, "Request entity too large", nil)
const DefaultMaxBodySize int64 = 32 << 20
const (
DefaultBodyLimitMiddlewareId = "pbBodyLimit"
DefaultBodyLimitMiddlewarePriority = DefaultRateLimitMiddlewarePriority + 10
)
// BodyLimit returns a middleware function that changes the default request body size limit.
//
// Note that in order to have effect this middleware should be registered
// before other middlewares that reads the request body.
//
// If limitBytes <= 0, no limit is applied.
//
// Otherwise, if the request body size exceeds the configured limitBytes,
// it sends 413 error response.
func BodyLimit(limitBytes int64) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultBodyLimitMiddlewareId,
Priority: DefaultBodyLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
err := applyBodyLimit(e, limitBytes)
if err != nil {
return err
}
return e.Next()
},
}
}
func dynamicCollectionBodyLimit(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
if collectionPathParam == "" {
collectionPathParam = "collection"
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultBodyLimitMiddlewareId,
Priority: DefaultBodyLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if err != nil {
return e.NotFoundError("Missing or invalid collection context.", err)
}
limitBytes := DefaultMaxBodySize
if !collection.IsView() {
for _, f := range collection.Fields {
if calc, ok := f.(core.MaxBodySizeCalculator); ok {
limitBytes += calc.CalculateMaxBodySize()
}
}
}
err = applyBodyLimit(e, limitBytes)
if err != nil {
return err
}
return e.Next()
},
}
}
func applyBodyLimit(e *core.RequestEvent, limitBytes int64) error {
// no limit
if limitBytes <= 0 {
return nil
}
// optimistically check the submitted request content length
if e.Request.ContentLength > limitBytes {
return ErrRequestEntityTooLarge
}
// replace the request body
//
// note: we don't use sync.Pool since the size of the elements could vary too much
// and it might not be efficient (see https://github.com/golang/go/issues/23199)
e.Request.Body = &limitedReader{ReadCloser: e.Request.Body, limit: limitBytes}
return nil
}
type limitedReader struct {
io.ReadCloser
limit int64
totalRead int64
}
func (r *limitedReader) Read(b []byte) (int, error) {
n, err := r.ReadCloser.Read(b)
if err != nil {
return n, err
}
r.totalRead += int64(n)
if r.totalRead > r.limit {
return n, ErrRequestEntityTooLarge
}
return n, nil
}
func (r *limitedReader) Reread() {
rr, ok := r.ReadCloser.(router.Rereader)
if ok {
rr.Reread()
}
}

View File

@ -0,0 +1,60 @@
package apis_test
import (
"bytes"
"fmt"
"net/http/httptest"
"testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestBodyLimitMiddleware(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
pbRouter, err := apis.NewRouter(app)
if err != nil {
t.Fatal(err)
}
pbRouter.POST("/a", func(e *core.RequestEvent) error {
return e.String(200, "a")
}) // default global BodyLimit check
pbRouter.POST("/b", func(e *core.RequestEvent) error {
return e.String(200, "b")
}).Bind(apis.BodyLimit(20))
mux, err := pbRouter.BuildMux()
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
url string
size int64
expectedStatus int
}{
{"/a", 21, 200},
{"/a", apis.DefaultMaxBodySize + 1, 413},
{"/b", 20, 200},
{"/b", 21, 413},
}
for _, s := range scenarios {
t.Run(fmt.Sprintf("%s_%d", s.url, s.size), func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("POST", s.url, bytes.NewReader(make([]byte, s.size)))
mux.ServeHTTP(rec, req)
result := rec.Result()
defer result.Body.Close()
if result.StatusCode != s.expectedStatus {
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
}
})
}
}

307
apis/middlewares_cors.go Normal file
View File

@ -0,0 +1,307 @@
package apis
// -------------------------------------------------------------------
// This middleware is ported from echo/middleware to minimize the breaking
// changes and differences in the API behavior from earlier PocketBase versions
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/cors.go).
//
// I doubt that this would matter for most cases, but the only major difference
// is that for non-supported routes this middleware doesn't return 405 and fallbacks
// to the default catch-all PocketBase route (aka. returns 404) to avoid
// the extra overhead of further hijacking and wrapping the Go default mux
// (https://github.com/golang/go/issues/65648#issuecomment-1955328807).
// -------------------------------------------------------------------
import (
"net/http"
"regexp"
"strconv"
"strings"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
)
const (
DefaultCorsMiddlewareId = "pbCors"
DefaultCorsMiddlewarePriority = DefaultActivityLoggerMiddlewarePriority - 1 // before the activity logger and rate limit so that OPTIONS preflight requests are not counted
)
// CORSConfig defines the config for CORS middleware.
type CORSConfig struct {
// AllowOrigins determines the value of the Access-Control-Allow-Origin
// response header. This header defines a list of origins that may access the
// resource. The wildcard characters '*' and '?' are supported and are
// converted to regex fragments '.*' and '.' accordingly.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional. Default value []string{"*"}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional.
AllowOriginFunc func(origin string) (bool, error)
// AllowMethods determines the value of the Access-Control-Allow-Methods
// response header. This header specified the list of methods allowed when
// accessing the resource. This is used in response to a preflight request.
//
// Optional. Default value DefaultCORSConfig.AllowMethods.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
AllowMethods []string
// AllowHeaders determines the value of the Access-Control-Allow-Headers
// response header. This header is used in response to a preflight request to
// indicate which HTTP headers can be used when making the actual request.
//
// Optional. Default value []string{}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
AllowHeaders []string
// AllowCredentials determines the value of the
// Access-Control-Allow-Credentials response header. This header indicates
// whether or not the response to the request can be exposed when the
// credentials mode (Request.credentials) is true. When used as part of a
// response to a preflight request, this indicates whether or not the actual
// request can be made using credentials. See also
// [MDN: Access-Control-Allow-Credentials].
//
// Optional. Default value false, in which case the header is not set.
//
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
AllowCredentials bool
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
//
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
//
// Optional. Default value is false.
UnsafeWildcardOriginWithAllowCredentials bool
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
// defines a list of headers that clients are allowed to access.
//
// Optional. Default value []string{}, in which case the header is not set.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
ExposeHeaders []string
// MaxAge determines the value of the Access-Control-Max-Age response header.
// This header indicates how long (in seconds) the results of a preflight
// request can be cached.
// The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
//
// Optional. Default value 0 - meaning header is not sent.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
MaxAge int
}
// DefaultCORSConfig is the default CORS middleware config.
var DefaultCORSConfig = CORSConfig{
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}
// CORSWithConfig returns a CORS middleware with config.
func CORSWithConfig(config CORSConfig) hook.HandlerFunc[*core.RequestEvent] {
// Defaults
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}
allowOriginPatterns := []string{}
for _, origin := range config.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
pattern = strings.ReplaceAll(pattern, "\\?", ".")
pattern = "^" + pattern + "$"
allowOriginPatterns = append(allowOriginPatterns, pattern)
}
allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
maxAge := "0"
if config.MaxAge > 0 {
maxAge = strconv.Itoa(config.MaxAge)
}
return func(e *core.RequestEvent) error {
req := e.Request
res := e.Response
origin := req.Header.Get("Origin")
allowOrigin := ""
res.Header().Add("Vary", "Origin")
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
// For simplicity we just consider method type and later `Origin` header.
preflight := req.Method == http.MethodOptions
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" {
if !preflight {
return e.Next()
}
return e.NoContent(http.StatusNoContent)
}
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
checkPatterns := false
if allowOrigin == "" {
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
checkPatterns = true
}
}
if checkPatterns {
for _, re := range allowOriginPatterns {
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return e.Next()
}
return e.NoContent(http.StatusNoContent)
}
res.Header().Set("Access-Control-Allow-Origin", allowOrigin)
if config.AllowCredentials {
res.Header().Set("Access-Control-Allow-Credentials", "true")
}
// Simple request
if !preflight {
if exposeHeaders != "" {
res.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
}
return e.Next()
}
// Preflight request
res.Header().Add("Vary", "Access-Control-Request-Method")
res.Header().Add("Vary", "Access-Control-Request-Headers")
res.Header().Set("Access-Control-Allow-Methods", allowMethods)
if allowHeaders != "" {
res.Header().Set("Access-Control-Allow-Headers", allowHeaders)
} else {
h := req.Header.Get("Access-Control-Request-Headers")
if h != "" {
res.Header().Set("Access-Control-Allow-Headers", h)
}
}
if config.MaxAge != 0 {
res.Header().Set("Access-Control-Max-Age", maxAge)
}
return e.NoContent(http.StatusNoContent)
}
}
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// matchSubdomain compares authority with wildcard
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}
didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}
for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
}
}
return false
}

237
apis/middlewares_gzip.go Normal file
View File

@ -0,0 +1,237 @@
package apis
// -------------------------------------------------------------------
// This middleware is ported from echo/middleware to minimize the breaking
// changes and differences in the API behavior from earlier PocketBase versions
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/compress.go).
// -------------------------------------------------------------------
import (
"bufio"
"bytes"
"compress/gzip"
"errors"
"io"
"net"
"net/http"
"strings"
"sync"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
const (
gzipScheme = "gzip"
)
// GzipConfig defines the config for Gzip middleware.
type GzipConfig struct {
// Gzip compression level.
// Optional. Default value -1.
Level int
// Length threshold before gzip compression is applied.
// Optional. Default value 0.
//
// Most of the time you will not need to change the default. Compressing
// a short response might increase the transmitted data because of the
// gzip format overhead. Compressing the response will also consume CPU
// and time on the server and the client (for decompressing). Depending on
// your use case such a threshold might be useful.
//
// See also:
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
MinLength int
}
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
func Gzip() hook.HandlerFunc[*core.RequestEvent] {
return GzipWithConfig(GzipConfig{})
}
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) hook.HandlerFunc[*core.RequestEvent] {
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
panic(errors.New("invalid gzip level"))
}
if config.Level == 0 {
config.Level = -1
}
if config.MinLength < 0 {
config.MinLength = 0
}
pool := sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(io.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
bpool := sync.Pool{
New: func() interface{} {
b := &bytes.Buffer{}
return b
},
}
return func(e *core.RequestEvent) error {
e.Response.Header().Add("Vary", "Accept-Encoding")
if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) {
w, ok := pool.Get().(*gzip.Writer)
if !ok {
return e.InternalServerError("", errors.New("failed to get gzip.Writer"))
}
rw := e.Response
w.Reset(rw)
buf := bpool.Get().(*bytes.Buffer)
buf.Reset()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
defer func() {
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
if !grw.wroteBody {
if rw.Header().Get("Content-Encoding") == gzipScheme {
rw.Header().Del("Content-Encoding")
}
if grw.wroteHeader {
rw.WriteHeader(grw.code)
}
// We have to reset response to it's pristine state when
// nothing is written to body or error is returned.
// See issue echo#424, echo#407.
e.Response = rw
w.Reset(io.Discard)
} else if !grw.minLengthExceeded {
// Write uncompressed response
e.Response = rw
if grw.wroteHeader {
rw.WriteHeader(grw.code)
}
grw.buffer.WriteTo(rw)
w.Reset(io.Discard)
}
w.Close()
bpool.Put(buf)
pool.Put(w)
}()
e.Response = grw
}
return e.Next()
}
}
type gzipResponseWriter struct {
http.ResponseWriter
io.Writer
buffer *bytes.Buffer
minLength int
code int
wroteHeader bool
wroteBody bool
minLengthExceeded bool
}
func (w *gzipResponseWriter) WriteHeader(code int) {
w.Header().Del("Content-Length") // Issue echo#444
w.wroteHeader = true
// Delay writing of the header until we know if we'll actually compress the response
w.code = code
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", http.DetectContentType(b))
}
w.wroteBody = true
if !w.minLengthExceeded {
n, err := w.buffer.Write(b)
if w.buffer.Len() >= w.minLength {
w.minLengthExceeded = true
// The minimum length is exceeded, add Content-Encoding header and write the header
w.Header().Set("Content-Encoding", gzipScheme)
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
return w.Writer.Write(w.buffer.Bytes())
}
return n, err
}
return w.Writer.Write(b)
}
func (w *gzipResponseWriter) Flush() {
if !w.minLengthExceeded {
// Enforce compression because we will not know how much more data will come
w.minLengthExceeded = true
w.Header().Set("Content-Encoding", gzipScheme)
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
_, _ = w.Writer.Write(w.buffer.Bytes())
}
_ = w.Writer.(*gzip.Writer).Flush()
_ = http.NewResponseController(w.ResponseWriter).Flush()
}
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(w.ResponseWriter).Hijack()
}
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
rw := w.ResponseWriter
for {
switch p := rw.(type) {
case http.Pusher:
return p.Push(target, opts)
case router.RWUnwrapper:
rw = p.Unwrap()
default:
return http.ErrNotSupported
}
}
}
func (w *gzipResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
rw := w.ResponseWriter
for {
switch rf := rw.(type) {
case io.ReaderFrom:
return rf.ReadFrom(r)
case router.RWUnwrapper:
rw = rf.Unwrap()
default:
return io.Copy(w.ResponseWriter, r)
}
}
}
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@ -0,0 +1,298 @@
package apis
import (
"sync"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/store"
)
const (
DefaultRateLimitMiddlewareId = "pbRateLimit"
DefaultRateLimitMiddlewarePriority = -1000
)
const (
rateLimitersStoreKey = "__pbRateLimiters__"
rateLimitersCronKey = "__pbRateLimitersCleanup__"
rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__"
)
// rateLimit defines the global rate limit middleware.
//
// This middleware is registered by default for all routes.
func rateLimit() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRateLimitMiddlewareId,
Priority: DefaultRateLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
if skipRateLimit(e) {
return e.Next()
}
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(defaultRateLimitLabels(e))
if ok {
err := checkRateLimit(e, e.Request.Pattern, rule)
if err != nil {
return err
}
}
return e.Next()
},
}
}
// collectionPathRateLimit defines a rate limit middleware for the internal collection handlers.
func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] {
if collectionPathParam == "" {
collectionPathParam = "collection"
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRateLimitMiddlewareId,
Priority: DefaultRateLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if err != nil {
return e.NotFoundError("Missing or invalid collection context.", err)
}
if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil {
return err
}
return e.Next()
},
}
}
// checkCollectionRateLimit checks whether the current request satisfy the
// rate limit configuration for the specific collection.
//
// Each baseTags entry will be prefixed with the collection name and its wildcard variant.
func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error {
if skipRateLimit(e) {
return nil
}
labels := make([]string, 0, 2+len(baseTags)*2)
rtId := collection.Id + e.Request.Pattern
// add first the primary labels (aka. ["collectionName:action1", "collectionName:action2"])
for _, baseTag := range baseTags {
rtId += baseTag
labels = append(labels, collection.Name+":"+baseTag)
}
// add the wildcard labels (aka. [..., "*:action1","*:action2", "*"])
for _, baseTag := range baseTags {
labels = append(labels, "*:"+baseTag)
}
labels = append(labels, defaultRateLimitLabels(e)...)
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels)
if ok {
return checkRateLimit(e, rtId, rule)
}
return nil
}
// -------------------------------------------------------------------
// @todo consider exporting as RateLimit helper?
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
return initRateLimitersStore(e.App)
}).(*store.Store[*rateLimiter])
if rateLimiters == nil {
e.App.Logger().Warn("Failed to retrieve app rate limiters store")
return nil
}
rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter {
return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800)
})
if rt == nil {
e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId)
return nil
}
key := e.RealIP()
if key == "" {
e.App.Logger().Warn("Empty rate limit client key")
return nil
}
if !rt.isAllowed(key) {
return e.TooManyRequestsError("", nil)
}
return nil
}
func skipRateLimit(e *core.RequestEvent) bool {
return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth()
}
func defaultRateLimitLabels(e *core.RequestEvent) []string {
return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path}
}
func destroyRateLimitersStore(app core.App) {
app.OnSettingsReload().Unbind(rateLimitersSettingsHookId)
app.Cron().Remove(rateLimitersCronKey)
app.Store().Remove(rateLimitersStoreKey)
}
func initRateLimitersStore(app core.App) *store.Store[*rateLimiter] {
app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at 00
limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[*rateLimiter])
if !ok {
return
}
limiters := limitersStore.GetAll()
for _, limiter := range limiters {
limiter.clean()
}
})
app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{
Id: rateLimitersSettingsHookId,
Func: func(e *core.SettingsReloadEvent) error {
err := e.Next()
if err != nil {
return err
}
// reset
destroyRateLimitersStore(e.App)
return nil
},
})
return store.New[*rateLimiter](nil)
}
func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter {
return &rateLimiter{
maxAllowed: maxAllowed,
interval: intervalInSec,
minDeleteInterval: minDeleteIntervalInSec,
clients: map[string]*fixedWindow{},
}
}
type rateLimiter struct {
clients map[string]*fixedWindow
maxAllowed int
interval int64
minDeleteInterval int64
totalDeleted int64
sync.RWMutex
}
func (rt *rateLimiter) isAllowed(key string) bool {
// lock only reads to minimize locks contention
rt.RLock()
client, ok := rt.clients[key]
rt.RUnlock()
if !ok {
rt.Lock()
// check again in case the client was added by another request
client, ok = rt.clients[key]
if !ok {
client = newFixedWindow(rt.maxAllowed, rt.interval)
rt.clients[key] = client
}
rt.Unlock()
}
return client.consume()
}
func (rt *rateLimiter) clean() {
rt.Lock()
defer rt.Unlock()
nowUnix := time.Now().Unix()
for k, client := range rt.clients {
if client.hasExpired(nowUnix, rt.minDeleteInterval) {
delete(rt.clients, k)
rt.totalDeleted++
}
}
// "shrink" the map if too may items were deleted
//
// @todo remove after https://github.com/golang/go/issues/20135
if rt.totalDeleted >= 300 {
shrunk := make(map[string]*fixedWindow, len(rt.clients))
for k, v := range rt.clients {
shrunk[k] = v
}
rt.clients = shrunk
rt.totalDeleted = 0
}
}
func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow {
return &fixedWindow{
maxAllowed: maxAllowed,
interval: intervalInSec,
}
}
type fixedWindow struct {
// use plain Mutex instead of RWMutex since the operations are expected
// to be mostly writes (e.g. consume()) and it should perform better
sync.Mutex
maxAllowed int // the max allowed tokens per interval
available int // the total available tokens
interval int64 // in seconds
lastConsume int64 // the time of the last consume
}
// hasExpired checks whether it has been at least minElapsed seconds since the lastConsume time.
// (usually used to perform periodic cleanup of staled instances).
func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool {
l.Lock()
defer l.Unlock()
return relativeNow-l.lastConsume > minElapsed
}
// consume decrease the current window allowance with 1 (if not exhausted already).
//
// It returns false if the allowance has been already exhausted and the user
// has to wait until it resets back to its maxAllowed value.
func (l *fixedWindow) consume() bool {
l.Lock()
defer l.Unlock()
nowUnix := time.Now().Unix()
// reset consumed counter
if nowUnix-l.lastConsume >= l.interval {
l.available = l.maxAllowed
}
if l.available > 0 {
l.available--
l.lastConsume = nowUnix
return true
}
return false
}

View File

@ -0,0 +1,103 @@
package apis_test
import (
"net/http/httptest"
"testing"
"time"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestDefaultRateLimitMiddleware(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{
Label: "/rate/",
MaxRequests: 2,
Duration: 1,
},
{
Label: "/rate/b",
MaxRequests: 3,
Duration: 1,
},
{
Label: "POST /rate/b",
MaxRequests: 1,
Duration: 1,
},
}
pbRouter, err := apis.NewRouter(app)
if err != nil {
t.Fatal(err)
}
pbRouter.GET("/norate", func(e *core.RequestEvent) error {
return e.String(200, "norate")
}).BindFunc(func(e *core.RequestEvent) error {
return e.Next()
})
pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
return e.String(200, "a")
})
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
return e.String(200, "b")
})
mux, err := pbRouter.BuildMux()
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
url string
wait float64
expectedStatus int
}{
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 429},
{"/rate/a", 0, 429},
{"/rate/a", 1.1, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 429},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 429},
{"/rate/b", 1.1, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 429},
}
for _, s := range scenarios {
t.Run(s.url, func(t *testing.T) {
if s.wait > 0 {
time.Sleep(time.Duration(s.wait) * time.Second)
}
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", s.url, nil)
mux.ServeHTTP(rec, req)
result := rec.Result()
if result.StatusCode != s.expectedStatus {
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
}
})
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +1,17 @@
package apis_test package apis_test
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/subscriptions" "github.com/pocketbase/pocketbase/tools/subscriptions"
) )
@ -22,7 +19,7 @@ func TestRealtimeConnect(t *testing.T) {
scenarios := []tests.ApiScenario{ scenarios := []tests.ApiScenario{
{ {
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/realtime", URL: "/api/realtime",
Timeout: 100 * time.Millisecond, Timeout: 100 * time.Millisecond,
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
@ -31,12 +28,11 @@ func TestRealtimeConnect(t *testing.T) {
`data:{"clientId":`, `data:{"clientId":`,
}, },
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnRealtimeConnectRequest": 1, "*": 0,
"OnRealtimeBeforeMessageSend": 1, "OnRealtimeConnectRequest": 1,
"OnRealtimeAfterMessageSend": 1, "OnRealtimeMessageSend": 1,
"OnRealtimeDisconnectRequest": 1,
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 { if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
} }
@ -45,23 +41,23 @@ func TestRealtimeConnect(t *testing.T) {
{ {
Name: "PB_CONNECT interrupt", Name: "PB_CONNECT interrupt",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/realtime", URL: "/api/realtime",
Timeout: 100 * time.Millisecond, Timeout: 100 * time.Millisecond,
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnRealtimeConnectRequest": 1, "*": 0,
"OnRealtimeBeforeMessageSend": 1, "OnRealtimeConnectRequest": 1,
"OnRealtimeDisconnectRequest": 1, "OnRealtimeMessageSend": 1,
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error { app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
if e.Message.Name == "PB_CONNECT" { if e.Message.Name == "PB_CONNECT" {
return errors.New("PB_CONNECT error") return errors.New("PB_CONNECT error")
} }
return nil return e.Next()
}) })
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 { if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
} }
@ -70,20 +66,20 @@ func TestRealtimeConnect(t *testing.T) {
{ {
Name: "Skipping/ignoring messages", Name: "Skipping/ignoring messages",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/realtime", URL: "/api/realtime",
Timeout: 100 * time.Millisecond, Timeout: 100 * time.Millisecond,
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnRealtimeConnectRequest": 1, "*": 0,
"OnRealtimeBeforeMessageSend": 1, "OnRealtimeConnectRequest": 1,
"OnRealtimeDisconnectRequest": 1, "OnRealtimeMessageSend": 1,
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error { app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
return hook.StopPropagation return nil
}) })
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 { if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
} }
@ -101,34 +97,34 @@ func TestRealtimeSubscribe(t *testing.T) {
resetClient := func() { resetClient := func() {
client.Unsubscribe() client.Unsubscribe()
client.Set(apis.ContextAdminKey, nil) client.Set(apis.RealtimeClientAuthKey, nil)
client.Set(apis.ContextAuthRecordKey, nil)
} }
scenarios := []tests.ApiScenario{ scenarios := []tests.ApiScenario{
{ {
Name: "missing client", Name: "missing client",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/realtime", URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`), Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 404, ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "existing client - empty subscriptions", Name: "existing client - empty subscriptions",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/realtime", URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`), Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`),
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1, "*": 0,
"OnRealtimeAfterSubscribeRequest": 1, "OnRealtimeSubscribeRequest": 1,
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0") client.Subscribe("test0")
app.SubscriptionsBroker().Register(client) app.SubscriptionsBroker().Register(client)
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(client.Subscriptions()) != 0 { if len(client.Subscriptions()) != 0 {
t.Errorf("Expected no subscriptions, got %v", client.Subscriptions()) t.Errorf("Expected no subscriptions, got %v", client.Subscriptions())
} }
@ -138,18 +134,18 @@ func TestRealtimeSubscribe(t *testing.T) {
{ {
Name: "existing client - 2 new subscriptions", Name: "existing client - 2 new subscriptions",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/realtime", URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1, "*": 0,
"OnRealtimeAfterSubscribeRequest": 1, "OnRealtimeSubscribeRequest": 1,
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0") client.Subscribe("test0")
app.SubscriptionsBroker().Register(client) app.SubscriptionsBroker().Register(client)
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
expectedSubs := []string{"test1", "test2"} expectedSubs := []string{"test1", "test2"}
if len(expectedSubs) != len(client.Subscriptions()) { if len(expectedSubs) != len(client.Subscriptions()) {
t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions()) t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions())
@ -164,49 +160,49 @@ func TestRealtimeSubscribe(t *testing.T) {
}, },
}, },
{ {
Name: "existing client - authorized admin", Name: "existing client - authorized superuser",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/realtime", URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1, "*": 0,
"OnRealtimeAfterSubscribeRequest": 1, "OnRealtimeSubscribeRequest": 1,
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client) app.SubscriptionsBroker().Register(client)
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
admin, _ := client.Get(apis.ContextAdminKey).(*models.Admin) authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if admin == nil { if authRecord == nil || !authRecord.IsSuperuser() {
t.Errorf("Expected admin auth model, got nil") t.Errorf("Expected superuser auth record, got %v", authRecord)
} }
resetClient() resetClient()
}, },
}, },
{ {
Name: "existing client - authorized record", Name: "existing client - authorized regular record",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/realtime", URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1, "*": 0,
"OnRealtimeAfterSubscribeRequest": 1, "OnRealtimeSubscribeRequest": 1,
}, },
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client) app.SubscriptionsBroker().Register(client)
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil { if authRecord == nil {
t.Errorf("Expected auth record model, got nil") t.Errorf("Expected regular user auth record, got %v", authRecord)
} }
resetClient() resetClient()
}, },
@ -214,22 +210,50 @@ func TestRealtimeSubscribe(t *testing.T) {
{ {
Name: "existing client - mismatched auth", Name: "existing client - mismatched auth",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/realtime", URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 403, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
initialAuth := &models.Record{} user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
initialAuth.RefreshId() if err != nil {
client.Set(apis.ContextAuthRecordKey, initialAuth) t.Fatal(err)
}
client.Set(apis.RealtimeClientAuthKey, user)
app.SubscriptionsBroker().Register(client) app.SubscriptionsBroker().Register(client)
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected auth record model, got nil")
}
resetClient()
},
},
{
Name: "existing client - unauthorized client",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
client.Set(apis.RealtimeClientAuthKey, user)
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil { if authRecord == nil {
t.Errorf("Expected auth record model, got nil") t.Errorf("Expected auth record model, got nil")
} }
@ -247,24 +271,29 @@ func TestRealtimeAuthRecordDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp() testApp, _ := tests.NewTestApp()
defer testApp.Cleanup() defer testApp.Cleanup()
apis.InitApi(testApp) // init realtime handlers
apis.NewRouter(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
client := subscriptions.NewDefaultClient() client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord) client.Set(apis.RealtimeClientAuthKey, authRecord)
testApp.SubscriptionsBroker().Register(client) testApp.SubscriptionsBroker().Register(client)
// mock delete event
e := new(core.ModelEvent) e := new(core.ModelEvent)
e.Dao = testApp.Dao() e.App = testApp
e.Type = core.ModelEventTypeDelete
e.Context = context.Background()
e.Model = authRecord e.Model = authRecord
testApp.OnModelAfterDelete().Trigger(e)
if len(testApp.SubscriptionsBroker().Clients()) != 0 { testApp.OnModelAfterDeleteSuccess().Trigger(e)
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
if total := len(testApp.SubscriptionsBroker().Clients()); total != 0 {
t.Fatalf("Expected no subscription clients, found %d", total)
} }
} }
@ -272,111 +301,58 @@ func TestRealtimeAuthRecordUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp() testApp, _ := tests.NewTestApp()
defer testApp.Cleanup() defer testApp.Cleanup()
apis.InitApi(testApp) // init realtime handlers
apis.NewRouter(testApp)
authRecord1, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
client := subscriptions.NewDefaultClient() client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord1) client.Set(apis.RealtimeClientAuthKey, authRecord1)
testApp.SubscriptionsBroker().Register(client) testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord and change its email // refetch the authRecord and change its email
authRecord2, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") authRecord2, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
authRecord2.SetEmail("new@example.com") authRecord2.SetEmail("new@example.com")
// mock update event
e := new(core.ModelEvent) e := new(core.ModelEvent)
e.Dao = testApp.Dao() e.App = testApp
e.Type = core.ModelEventTypeUpdate
e.Context = context.Background()
e.Model = authRecord2 e.Model = authRecord2
testApp.OnModelAfterUpdate().Trigger(e)
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) testApp.OnModelAfterUpdateSuccess().Trigger(e)
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if clientAuthRecord.Email() != authRecord2.Email() { if clientAuthRecord.Email() != authRecord2.Email() {
t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email()) t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email())
} }
} }
func TestRealtimeAdminDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
admin, err := testApp.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAdminKey, admin)
testApp.SubscriptionsBroker().Register(client)
e := new(core.ModelEvent)
e.Dao = testApp.Dao()
e.Model = admin
testApp.OnModelAfterDelete().Trigger(e)
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
}
}
func TestRealtimeAdminUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
admin1, err := testApp.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAdminKey, admin1)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord and change its email
admin2, err := testApp.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
admin2.Email = "new@example.com"
e := new(core.ModelEvent)
e.Dao = testApp.Dao()
e.Model = admin2
testApp.OnModelAfterUpdate().Trigger(e)
clientAdmin, _ := client.Get(apis.ContextAdminKey).(*models.Admin)
if clientAdmin.Email != admin2.Email {
t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email)
}
}
// Custom auth record model struct // Custom auth record model struct
// ------------------------------------------------------------------- // -------------------------------------------------------------------
var _ models.Model = (*CustomUser)(nil) var _ core.Model = (*CustomUser)(nil)
type CustomUser struct { type CustomUser struct {
models.BaseModel core.BaseModel
Email string `db:"email" json:"email"` Email string `db:"email" json:"email"`
} }
func (m *CustomUser) TableName() string { func (m *CustomUser) TableName() string {
return "users" // the name of your collection return "users"
} }
func findCustomUserByEmail(dao *daos.Dao, email string) (*CustomUser, error) { func findCustomUserByEmail(app core.App, email string) (*CustomUser, error) {
model := &CustomUser{} model := &CustomUser{}
err := dao.ModelQuery(model). err := app.ModelQuery(model).
AndWhere(dbx.HashExp{"email": email}). AndWhere(dbx.HashExp{"email": email}).
Limit(1). Limit(1).
One(model) One(model)
@ -392,30 +368,31 @@ func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp() testApp, _ := tests.NewTestApp()
defer testApp.Cleanup() defer testApp.Cleanup()
apis.InitApi(testApp) // init realtime handlers
apis.NewRouter(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
client := subscriptions.NewDefaultClient() client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord) client.Set(apis.RealtimeClientAuthKey, authRecord)
testApp.SubscriptionsBroker().Register(client) testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord as CustomUser // refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com") customUser, err := findCustomUserByEmail(testApp, "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// delete the custom user (should unset the client auth record) // delete the custom user (should unset the client auth record)
if err := testApp.Dao().Delete(customUser); err != nil { if err := testApp.Delete(customUser); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(testApp.SubscriptionsBroker().Clients()) != 0 { if total := len(testApp.SubscriptionsBroker().Clients()); total != 0 {
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients())) t.Fatalf("Expected no subscription clients, found %d", total)
} }
} }
@ -423,30 +400,31 @@ func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp() testApp, _ := tests.NewTestApp()
defer testApp.Cleanup() defer testApp.Cleanup()
apis.InitApi(testApp) // init realtime handlers
apis.NewRouter(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
client := subscriptions.NewDefaultClient() client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord) client.Set(apis.RealtimeClientAuthKey, authRecord)
testApp.SubscriptionsBroker().Register(client) testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord as CustomUser // refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com") customUser, err := findCustomUserByEmail(testApp, "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// change its email // change its email
customUser.Email = "new@example.com" customUser.Email = "new@example.com"
if err := testApp.Dao().Save(customUser); err != nil { if err := testApp.Save(customUser); err != nil {
t.Fatal(err) t.Fatal(err)
} }
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if clientAuthRecord.Email() != customUser.Email { if clientAuthRecord.Email() != customUser.Email {
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email()) t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
} }

View File

@ -1,765 +1,75 @@
package apis package apis
import ( import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"sort"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
) )
// bindRecordAuthApi registers the auth record api endpoints and // bindRecordAuthApi registers the auth record api endpoints and
// the corresponding handlers. // the corresponding handlers.
func bindRecordAuthApi(app core.App, rg *echo.Group) { func bindRecordAuthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
api := recordAuthApi{app: app}
// global oauth2 subscription redirect handler // global oauth2 subscription redirect handler
rg.GET("/oauth2-redirect", api.oauth2SubscriptionRedirect) rg.GET("/oauth2-redirect", oauth2SubscriptionRedirect)
rg.POST("/oauth2-redirect", api.oauth2SubscriptionRedirect) // needed in case of response_mode=form_post // add again as POST in case of response_mode=form_post
rg.POST("/oauth2-redirect", oauth2SubscriptionRedirect)
// common collection record related routes sub := rg.Group("/collections/{collection}")
subGroup := rg.Group(
"/collections/:collection", sub.GET("/auth-methods", recordAuthMethods).Bind(
ActivityLogger(app), collectionPathRateLimit("", "listAuthMethods"),
LoadCollectionContext(app, models.CollectionTypeAuth),
) )
subGroup.GET("/auth-methods", api.authMethods)
subGroup.POST("/auth-refresh", api.authRefresh, RequireSameContextRecordAuth()) sub.POST("/auth-refresh", recordAuthRefresh).Bind(
subGroup.POST("/auth-with-oauth2", api.authWithOAuth2) collectionPathRateLimit("", "authRefresh"),
subGroup.POST("/auth-with-password", api.authWithPassword) RequireSameCollectionContextAuth(""),
subGroup.POST("/request-password-reset", api.requestPasswordReset) )
subGroup.POST("/confirm-password-reset", api.confirmPasswordReset)
subGroup.POST("/request-verification", api.requestVerification) sub.POST("/auth-with-password", recordAuthWithPassword).Bind(
subGroup.POST("/confirm-verification", api.confirmVerification) collectionPathRateLimit("", "authWithPassword", "auth"),
subGroup.POST("/request-email-change", api.requestEmailChange, RequireSameContextRecordAuth()) )
subGroup.POST("/confirm-email-change", api.confirmEmailChange)
subGroup.GET("/records/:id/external-auths", api.listExternalAuths, RequireAdminOrOwnerAuth("id")) sub.POST("/auth-with-oauth2", recordAuthWithOAuth2).Bind(
subGroup.DELETE("/records/:id/external-auths/:provider", api.unlinkExternalAuth, RequireAdminOrOwnerAuth("id")) collectionPathRateLimit("", "authWithOAuth2", "auth"),
)
sub.POST("/request-otp", recordRequestOTP).Bind(
collectionPathRateLimit("", "requestOTP"),
)
sub.POST("/auth-with-otp", recordAuthWithOTP).Bind(
collectionPathRateLimit("", "authWithOTP", "auth"),
)
sub.POST("/request-password-reset", recordRequestPasswordReset).Bind(
collectionPathRateLimit("", "requestPasswordReset"),
)
sub.POST("/confirm-password-reset", recordConfirmPasswordReset).Bind(
collectionPathRateLimit("", "confirmPasswordReset"),
)
sub.POST("/request-verification", recordRequestVerification).Bind(
collectionPathRateLimit("", "requestVerification"),
)
sub.POST("/confirm-verification", recordConfirmVerification).Bind(
collectionPathRateLimit("", "confirmVerification"),
)
sub.POST("/request-email-change", recordRequestEmailChange).Bind(
collectionPathRateLimit("", "requestEmailChange"),
RequireSameCollectionContextAuth(""),
)
sub.POST("/confirm-email-change", recordConfirmEmailChange).Bind(
collectionPathRateLimit("", "confirmEmailChange"),
)
sub.POST("/impersonate/{id}", recordAuthImpersonate).Bind(RequireSuperuserAuth())
} }
type recordAuthApi struct { func findAuthCollection(e *core.RequestEvent) (*core.Collection, error) {
app core.App collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
}
if err != nil || !collection.IsAuth() {
func (api *recordAuthApi) authRefresh(c echo.Context) error { return nil, e.NotFoundError("Missing or invalid auth collection context.", err)
record, _ := c.Get(ContextAuthRecordKey).(*models.Record) }
if record == nil {
return NewNotFoundError("Missing auth record context.", nil) return collection, nil
}
event := new(core.RecordAuthRefreshEvent)
event.HttpContext = c
event.Collection = record.Collection()
event.Record = record
return api.app.OnRecordBeforeAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshEvent) error {
return api.app.OnRecordAfterAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshEvent) error {
return RecordAuthResponse(api.app, e.HttpContext, e.Record, nil)
})
})
}
type providerInfo struct {
Name string `json:"name"`
DisplayName string `json:"displayName"`
State string `json:"state"`
AuthUrl string `json:"authUrl"`
// technically could be omitted if the provider doesn't support PKCE,
// but to avoid breaking existing typed clients we'll return them as empty string
CodeVerifier string `json:"codeVerifier"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
func (api *recordAuthApi) authMethods(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
authOptions := collection.AuthOptions()
result := struct {
AuthProviders []providerInfo `json:"authProviders"`
UsernamePassword bool `json:"usernamePassword"`
EmailPassword bool `json:"emailPassword"`
OnlyVerified bool `json:"onlyVerified"`
}{
UsernamePassword: authOptions.AllowUsernameAuth,
EmailPassword: authOptions.AllowEmailAuth,
OnlyVerified: authOptions.OnlyVerified,
AuthProviders: []providerInfo{},
}
if !authOptions.AllowOAuth2Auth {
return c.JSON(http.StatusOK, result)
}
nameConfigMap := api.app.Settings().NamedAuthProviderConfigs()
for name, config := range nameConfigMap {
if !config.Enabled {
continue
}
provider, err := auth.NewProviderByName(name)
if err != nil {
api.app.Logger().Debug("Missing or invalid provider name", slog.String("name", name))
continue // skip provider
}
if err := config.SetupProvider(provider); err != nil {
api.app.Logger().Debug(
"Failed to setup provider",
slog.String("name", name),
slog.String("error", err.Error()),
)
continue // skip provider
}
info := providerInfo{
Name: name,
DisplayName: provider.DisplayName(),
State: security.RandomString(30),
}
if info.DisplayName == "" {
info.DisplayName = name
}
urlOpts := []oauth2.AuthCodeOption{}
// custom providers url options
switch name {
case auth.NameApple:
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
}
if provider.PKCE() {
info.CodeVerifier = security.RandomString(43)
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
info.CodeChallengeMethod = "S256"
urlOpts = append(urlOpts,
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
)
}
info.AuthUrl = provider.BuildAuthUrl(
info.State,
urlOpts...,
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
result.AuthProviders = append(result.AuthProviders, info)
}
// sort providers
sort.SliceStable(result.AuthProviders, func(i, j int) bool {
return result.AuthProviders[i].Name < result.AuthProviders[j].Name
})
return c.JSON(http.StatusOK, result)
}
func (api *recordAuthApi) authWithOAuth2(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
if !collection.AuthOptions().AllowOAuth2Auth {
return NewBadRequestError("The collection is not configured to allow OAuth2 authentication.", nil)
}
var fallbackAuthRecord *models.Record
loggedAuthRecord, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if loggedAuthRecord != nil && loggedAuthRecord.Collection().Id == collection.Id {
fallbackAuthRecord = loggedAuthRecord
}
form := forms.NewRecordOAuth2Login(api.app, collection, fallbackAuthRecord)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordAuthWithOAuth2Event)
event.HttpContext = c
event.Collection = collection
event.ProviderName = form.Provider
form.SetBeforeNewRecordCreateFunc(func(createForm *forms.RecordUpsert, authRecord *models.Record, authUser *auth.AuthUser) error {
return createForm.DrySubmit(func(txDao *daos.Dao) error {
event.IsNewRecord = true
// clone the current request data and assign the form create data as its body data
requestInfo := *RequestInfo(c)
requestInfo.Context = models.RequestInfoContextOAuth2
requestInfo.Data = form.CreateData
createRuleFunc := func(q *dbx.SelectQuery) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
return nil // either admin or the rule is empty
}
if collection.CreateRule == nil {
return errors.New("Only admins can create new accounts with OAuth2")
}
if *collection.CreateRule != "" {
resolver := resolvers.NewRecordFieldResolver(txDao, collection, &requestInfo, true)
expr, err := search.FilterData(*collection.CreateRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
if _, err := txDao.FindRecordById(collection.Id, createForm.Id, createRuleFunc); err != nil {
return fmt.Errorf("Failed create rule constraint: %w", err)
}
return nil
})
})
_, _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData]) forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData] {
return func(data *forms.RecordOAuth2LoginData) error {
event.Record = data.Record
event.OAuth2User = data.OAuth2User
event.ProviderClient = data.ProviderClient
event.IsNewRecord = data.Record == nil
return api.app.OnRecordBeforeAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2Event) error {
data.Record = e.Record
data.OAuth2User = e.OAuth2User
if err := next(data); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
e.Record = data.Record
e.OAuth2User = data.OAuth2User
meta := struct {
*auth.AuthUser
IsNew bool `json:"isNew"`
}{
AuthUser: e.OAuth2User,
IsNew: event.IsNewRecord,
}
return api.app.OnRecordAfterAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2Event) error {
// clear the lastLoginAlertSentAt field so that we can enforce password auth notifications
if !e.Record.LastLoginAlertSentAt().IsZero() {
e.Record.Set(schema.FieldNameLastLoginAlertSentAt, "")
if err := api.app.Dao().SaveRecord(e.Record); err != nil {
api.app.Logger().Warn("Failed to reset lastLoginAlertSentAt", "error", err, "recordId", e.Record.Id)
}
}
return RecordAuthResponse(api.app, e.HttpContext, e.Record, meta)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) authWithPassword(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordPasswordLogin(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordAuthWithPasswordEvent)
event.HttpContext = c
event.Collection = collection
event.Password = form.Password
event.Identity = form.Identity
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
// @todo remove after the refactoring
if collection.AuthOptions().AllowOAuth2Auth && e.Record.Email() != "" {
externalAuths, err := api.app.Dao().FindAllExternalAuthsByRecord(e.Record)
if err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
if len(externalAuths) > 0 {
lastLoginAlert := e.Record.LastLoginAlertSentAt().Time()
// send an email alert if the password auth is after OAuth2 auth (lastLoginAlert will be empty)
// or if it has been ~7 days since the last alert
if lastLoginAlert.IsZero() || time.Now().UTC().Sub(lastLoginAlert).Hours() > 168 {
providerNames := make([]string, len(externalAuths))
for i, ea := range externalAuths {
var name string
if provider, err := auth.NewProviderByName(ea.Provider); err == nil {
name = provider.DisplayName()
}
if name == "" {
name = ea.Provider
}
providerNames[i] = name
}
if err := mails.SendRecordPasswordLoginAlert(api.app, e.Record, providerNames...); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
e.Record.SetLastLoginAlertSentAt(types.NowDateTime())
if err := api.app.Dao().SaveRecord(e.Record); err != nil {
api.app.Logger().Warn("Failed to update lastLoginAlertSentAt", "error", err, "recordId", e.Record.Id)
}
}
}
}
return api.app.OnRecordAfterAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordEvent) error {
return RecordAuthResponse(api.app, e.HttpContext, e.Record, nil)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) requestPasswordReset(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
authOptions := collection.AuthOptions()
if !authOptions.AllowUsernameAuth && !authOptions.AllowEmailAuth {
return NewBadRequestError("The collection is not configured to allow password authentication.", nil)
}
form := forms.NewRecordPasswordResetRequest(api.app, collection)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
if err := form.Validate(); err != nil {
return NewBadRequestError("An error occurred while validating the form.", err)
}
event := new(core.RecordRequestPasswordResetEvent)
event.HttpContext = c
event.Collection = collection
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetEvent) error {
// run in background because we don't need to show the result to the client
routine.FireAndForget(func() {
if err := next(e.Record); err != nil {
api.app.Logger().Debug(
"Failed to send password reset email",
slog.String("error", err.Error()),
)
}
})
return api.app.OnRecordAfterRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
// eagerly write 204 response and skip submit errors
// as a measure against emails enumeration
if !c.Response().Committed {
c.NoContent(http.StatusNoContent)
}
return submitErr
}
func (api *recordAuthApi) confirmPasswordReset(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordPasswordResetConfirm(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordConfirmPasswordResetEvent)
event.HttpContext = c
event.Collection = collection
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to set new password.", err)
}
return api.app.OnRecordAfterConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) requestVerification(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordVerificationRequest(api.app, collection)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
if err := form.Validate(); err != nil {
return NewBadRequestError("An error occurred while validating the form.", err)
}
event := new(core.RecordRequestVerificationEvent)
event.HttpContext = c
event.Collection = collection
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationEvent) error {
// run in background because we don't need to show the result to the client
routine.FireAndForget(func() {
if err := next(e.Record); err != nil {
api.app.Logger().Debug(
"Failed to send verification email",
slog.String("error", err.Error()),
)
}
})
return api.app.OnRecordAfterRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
// eagerly write 204 response and skip submit errors
// as a measure against users enumeration
if !c.Response().Committed {
c.NoContent(http.StatusNoContent)
}
return submitErr
}
func (api *recordAuthApi) confirmVerification(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordVerificationConfirm(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordConfirmVerificationEvent)
event.HttpContext = c
event.Collection = collection
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("An error occurred while submitting the form.", err)
}
return api.app.OnRecordAfterConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) requestEmailChange(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewUnauthorizedError("The request requires valid auth record.", nil)
}
form := forms.NewRecordEmailChangeRequest(api.app, record)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
event := new(core.RecordRequestEmailChangeEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
return api.app.OnRecordBeforeRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to request email change.", err)
}
return api.app.OnRecordAfterRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
}
func (api *recordAuthApi) confirmEmailChange(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordEmailChangeConfirm(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordConfirmEmailChangeEvent)
event.HttpContext = c
event.Collection = collection
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to confirm email change.", err)
}
return api.app.OnRecordAfterConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) listExternalAuths(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
record, err := api.app.Dao().FindRecordById(collection.Id, id)
if err != nil || record == nil {
return NewNotFoundError("", err)
}
externalAuths, err := api.app.Dao().FindAllExternalAuthsByRecord(record)
if err != nil {
return NewBadRequestError("Failed to fetch the external auths for the specified auth record.", err)
}
event := new(core.RecordListExternalAuthsEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
event.ExternalAuths = externalAuths
return api.app.OnRecordListExternalAuthsRequest().Trigger(event, func(e *core.RecordListExternalAuthsEvent) error {
return e.HttpContext.JSON(http.StatusOK, e.ExternalAuths)
})
}
func (api *recordAuthApi) unlinkExternalAuth(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
id := c.PathParam("id")
provider := c.PathParam("provider")
if id == "" || provider == "" {
return NewNotFoundError("", nil)
}
record, err := api.app.Dao().FindRecordById(collection.Id, id)
if err != nil || record == nil {
return NewNotFoundError("", err)
}
externalAuth, err := api.app.Dao().FindExternalAuthByRecordAndProvider(record, provider)
if err != nil {
return NewNotFoundError("Missing external auth provider relation.", err)
}
event := new(core.RecordUnlinkExternalAuthEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
event.ExternalAuth = externalAuth
return api.app.OnRecordBeforeUnlinkExternalAuthRequest().Trigger(event, func(e *core.RecordUnlinkExternalAuthEvent) error {
if err := api.app.Dao().DeleteExternalAuth(externalAuth); err != nil {
return NewBadRequestError("Cannot unlink the external auth provider.", err)
}
return api.app.OnRecordAfterUnlinkExternalAuthRequest().Trigger(event, func(e *core.RecordUnlinkExternalAuthEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
const (
oauth2SubscriptionTopic string = "@oauth2"
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
)
type oauth2RedirectData struct {
State string `form:"state" query:"state" json:"state"`
Code string `form:"code" query:"code" json:"code"`
Error string `form:"error" query:"error" json:"error,omitempty"`
}
func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
redirectStatusCode := http.StatusTemporaryRedirect
if c.Request().Method != http.MethodGet {
redirectStatusCode = http.StatusSeeOther
}
data := oauth2RedirectData{}
if err := c.Bind(&data); err != nil {
api.app.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
if data.State == "" {
api.app.Logger().Debug("Missing OAuth2 state parameter")
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
client, err := api.app.SubscriptionsBroker().ClientById(data.State)
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
api.app.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
defer client.Unsubscribe(oauth2SubscriptionTopic)
encodedData, err := json.Marshal(data)
if err != nil {
api.app.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
msg := subscriptions.Message{
Name: oauth2SubscriptionTopic,
Data: encodedData,
}
client.Send(msg)
if data.Error != "" || data.Code == "" {
api.app.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
return c.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
} }

View File

@ -0,0 +1,121 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
)
func recordConfirmEmailChange(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers can change their emails directly.", nil)
}
form := newEmailChangeConfirmForm(e.App, collection)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
authRecord, newEmail, err := form.parseToken()
if err != nil {
return firstApiError(err, e.BadRequestError("Invalid or expired token.", err))
}
event := new(core.RecordConfirmEmailChangeRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = authRecord
event.NewEmail = newEmail
return e.App.OnRecordConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeRequestEvent) error {
authRecord.Set(core.FieldNameEmail, e.NewEmail)
authRecord.Set(core.FieldNameVerified, true)
authRecord.RefreshTokenKey() // invalidate old tokens
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err))
}
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
func newEmailChangeConfirmForm(app core.App, collection *core.Collection) *EmailChangeConfirmForm {
return &EmailChangeConfirmForm{
app: app,
collection: collection,
}
}
type EmailChangeConfirmForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
Password string `form:"password" json:"password"`
}
func (form *EmailChangeConfirmForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 100), validation.By(form.checkPassword)),
)
}
func (form *EmailChangeConfirmForm) checkToken(value any) error {
_, _, err := form.parseToken()
return err
}
func (form *EmailChangeConfirmForm) checkPassword(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
authRecord, _, _ := form.parseToken()
if authRecord == nil || !authRecord.ValidatePassword(v) {
return validation.NewError("validation_invalid_password", "Missing or invalid auth record password.")
}
return nil
}
func (form *EmailChangeConfirmForm) parseToken() (*core.Record, string, error) {
// check token payload
claims, _ := security.ParseUnverifiedJWT(form.Token)
newEmail, _ := claims[core.TokenClaimNewEmail].(string)
if newEmail == "" {
return nil, "", validation.NewError("validation_invalid_token_payload", "Invalid token payload - newEmail must be set.")
}
// ensure that there aren't other users with the new email
_, err := form.app.FindAuthRecordByEmail(form.collection, newEmail)
if err == nil {
return nil, "", validation.NewError("validation_existing_token_email", "The new email address is already registered: "+newEmail)
}
// verify that the token is not expired and its signature is valid
authRecord, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeEmailChange)
if err != nil {
return nil, "", validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if authRecord.Collection().Id != form.collection.Id {
return nil, "", validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
return authRecord, newEmail, nil
}

View File

@ -0,0 +1,205 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmEmailChange(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-email-change",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"token":{"code":"validation_required"`,
`"password":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{"token`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token and correct password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoxNjQwOTkxNjYxfQ.dff842MO0mgRTHY8dktp0dqG9-7LGQOgRuiAbQpYBls",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{`,
`"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-email change token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{`,
`"code":"validation_invalid_token_payload"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and incorrect password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567891"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"password":{`,
`"code":"validation_invalid_password"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and correct password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmEmailChangeRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByEmail("users", "change@example.com")
if err != nil {
t.Fatalf("Expected to find user with email %q, got error: %v", "change@example.com", err)
}
},
},
{
Name: "valid token in different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordAfterConfirmEmailChangeRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmEmailChangeRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:confirmEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmEmailChange"},
{MaxRequests: 0, Label: "users:confirmEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,90 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
)
func recordRequestEmailChange(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers can change their emails directly.", nil)
}
record := e.Auth
if record == nil {
return e.UnauthorizedError("The request requires valid auth record.", nil)
}
form := newEmailChangeRequestForm(e.App, record)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.RecordRequestEmailChangeRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
event.NewEmail = form.NewEmail
return e.App.OnRecordRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeRequestEvent) error {
if err := mails.SendRecordChangeEmail(e.App, e.Record, e.NewEmail); err != nil {
return firstApiError(err, e.BadRequestError("Failed to request email change.", err))
}
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
func newEmailChangeRequestForm(app core.App, record *core.Record) *emailChangeRequestForm {
return &emailChangeRequestForm{
app: app,
record: record,
}
}
type emailChangeRequestForm struct {
app core.App
record *core.Record
NewEmail string `form:"newEmail" json:"newEmail"`
}
func (form *emailChangeRequestForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.NewEmail,
validation.Required,
validation.Length(1, 255),
is.EmailFormat,
validation.NotIn(form.record.Email()),
validation.By(form.checkUniqueEmail),
),
)
}
func (form *emailChangeRequestForm) checkUniqueEmail(value any) error {
v, _ := value.(string)
if v == "" {
return nil
}
found, _ := form.app.FindAuthRecordByEmail(form.record.Collection(), v)
if found != nil && found.Id != form.record.Id {
return validation.NewError("validation_invalid_new_email", "Invalid new email address.")
}
return nil
}

View File

@ -0,0 +1,168 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestEmailChange(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "record authentication but from different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superuser authentication",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"newEmail":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid data (existing email)",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"test2@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"newEmail":{"code":"validation_invalid_new_email"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid data (new email)",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestEmailChangeRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordEmailChangeSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-email-change") {
t.Fatalf("Expected email change email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestEmailChange"},
{MaxRequests: 0, Label: "users:requestEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,54 @@
package apis
import (
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
// note: for now allow superusers but it may change in the future to allow access
// also to users with "Manage API" rule access depending on the use cases that will arise
func recordAuthImpersonate(e *core.RequestEvent) error {
if !e.HasSuperuserAuth() {
return e.ForbiddenError("", nil)
}
collection, err := findAuthCollection(e)
if err != nil {
return err
}
record, err := e.App.FindRecordById(collection, e.Request.PathValue("id"))
if err != nil {
return e.NotFoundError("", err)
}
form := &impersonateForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second)
if err != nil {
e.InternalServerError("Failed to generate static auth token", err)
}
return recordAuthResponse(e, record, token, "", nil)
}
// -------------------------------------------------------------------
type impersonateForm struct {
// Duration is the optional custom token duration in seconds.
Duration int64 `form:"duration" json:"duration"`
}
func (form *impersonateForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Duration, validation.Min(0)),
)
}

View File

@ -0,0 +1,109 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthImpersonate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as different user",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as the same user",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"id":"4q1xlclmfloku33"`,
`"record":{`,
},
NotExpectedContent: []string{
// hidden fields should remain hidden even though we are authenticated as superuser
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "authorized as superuser with custom invalid duration",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{"duration":-1}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"duration":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser with custom valid duration",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{"duration":100}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"id":"4q1xlclmfloku33"`,
`"record":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

170
apis/record_auth_methods.go Normal file
View File

@ -0,0 +1,170 @@
package apis
import (
"log/slog"
"net/http"
"slices"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/security"
"golang.org/x/oauth2"
)
type otpResponse struct {
Enabled bool `json:"enabled"`
Duration int64 `json:"duration"` // in seconds
}
type mfaResponse struct {
Enabled bool `json:"enabled"`
Duration int64 `json:"duration"` // in seconds
}
type passwordResponse struct {
IdentityFields []string `json:"identityFields"`
Enabled bool `json:"enabled"`
}
type oauth2Response struct {
Providers []providerInfo `json:"providers"`
Enabled bool `json:"enabled"`
}
type providerInfo struct {
Name string `json:"name"`
DisplayName string `json:"displayName"`
State string `json:"state"`
AuthURL string `json:"authURL"`
// @todo
// deprecated: use AuthURL instead
// AuthUrl will be removed after dropping v0.22 support
AuthUrl string `json:"authUrl"`
// technically could be omitted if the provider doesn't support PKCE,
// but to avoid breaking existing typed clients we'll return them as empty string
CodeVerifier string `json:"codeVerifier"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
type authMethodsResponse struct {
Password passwordResponse `json:"password"`
OAuth2 oauth2Response `json:"oauth2"`
MFA mfaResponse `json:"mfa"`
OTP otpResponse `json:"otp"`
// legacy fields
// @todo remove after dropping v0.22 support
AuthProviders []providerInfo `json:"authProviders"`
UsernamePassword bool `json:"usernamePassword"`
EmailPassword bool `json:"emailPassword"`
}
func (amr *authMethodsResponse) fillLegacyFields() {
amr.EmailPassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "email")
amr.UsernamePassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "username")
if amr.OAuth2.Enabled {
amr.AuthProviders = amr.OAuth2.Providers
}
}
func recordAuthMethods(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
result := authMethodsResponse{
Password: passwordResponse{
IdentityFields: make([]string, 0, len(collection.PasswordAuth.IdentityFields)),
},
OAuth2: oauth2Response{
Providers: make([]providerInfo, 0, len(collection.OAuth2.Providers)),
},
OTP: otpResponse{
Enabled: collection.OTP.Enabled,
},
MFA: mfaResponse{
Enabled: collection.MFA.Enabled,
},
}
if collection.PasswordAuth.Enabled {
result.Password.Enabled = true
result.Password.IdentityFields = collection.PasswordAuth.IdentityFields
}
if collection.OTP.Enabled {
result.OTP.Duration = collection.OTP.Duration
}
if collection.MFA.Enabled {
result.MFA.Duration = collection.MFA.Duration
}
if !collection.OAuth2.Enabled {
result.fillLegacyFields()
return e.JSON(http.StatusOK, result)
}
result.OAuth2.Enabled = true
for _, config := range collection.OAuth2.Providers {
provider, err := config.InitProvider()
if err != nil {
e.App.Logger().Debug(
"Failed to setup OAuth2 provider",
slog.String("name", config.Name),
slog.String("error", err.Error()),
)
continue // skip provider
}
info := providerInfo{
Name: config.Name,
DisplayName: provider.DisplayName(),
State: security.RandomString(30),
}
if info.DisplayName == "" {
info.DisplayName = config.Name
}
urlOpts := []oauth2.AuthCodeOption{}
// custom providers url options
switch config.Name {
case auth.NameApple:
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
}
if provider.PKCE() {
info.CodeVerifier = security.RandomString(43)
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
info.CodeChallengeMethod = "S256"
urlOpts = append(urlOpts,
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
)
}
info.AuthURL = provider.BuildAuthURL(
info.State,
urlOpts...,
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
info.AuthUrl = info.AuthURL
result.OAuth2.Providers = append(result.OAuth2.Providers, info)
}
result.fillLegacyFields()
return e.JSON(http.StatusOK, result)
}

View File

@ -0,0 +1,106 @@
package apis_test
import (
"net/http"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthMethodsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "missing collection",
Method: http.MethodGet,
URL: "/api/collections/missing/auth-methods",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodGet,
URL: "/api/collections/demo1/auth-methods",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with none auth methods allowed",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
ExpectedStatus: 200,
ExpectedContent: []string{
`"password":{"identityFields":[],"enabled":false}`,
`"oauth2":{"providers":[],"enabled":false}`,
`"mfa":{"enabled":false,"duration":0}`,
`"otp":{"enabled":false,"duration":0}`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with all auth methods allowed",
Method: http.MethodGet,
URL: "/api/collections/users/auth-methods",
ExpectedStatus: 200,
ExpectedContent: []string{
`"password":{"identityFields":["email","username"],"enabled":true}`,
`"mfa":{"enabled":true,"duration":1800}`,
`"otp":{"enabled":true,"duration":300}`,
`"oauth2":{`,
`"providers":[{`,
`"name":"google"`,
`"name":"gitlab"`,
`"state":`,
`"displayName":`,
`"codeVerifier":`,
`"codeChallenge":`,
`"codeChallengeMethod":`,
`"authURL":`,
`redirect_uri="`, // ensures that the redirect_uri is the last url param
},
ExpectedEvents: map[string]int{"*": 0},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - nologin:listAuthMethods",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:listAuthMethods"},
{MaxRequests: 0, Label: "nologin:listAuthMethods"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:listAuthMethods",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:listAuthMethods"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,118 @@
package apis
import (
"errors"
"fmt"
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/security"
)
func recordRequestOTP(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OTP.Enabled {
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
}
form := &createOTPForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write a dummy 200 response as a very rudimentary user emails enumeration protection
e.JSON(http.StatusOK, map[string]string{
"otpId": core.GenerateDefaultRandomId(),
})
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
event := new(core.RecordCreateOTPRequestEvent)
event.RequestEvent = e
event.Password = security.RandomStringWithAlphabet(collection.OTP.Length, "1234567890")
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestOTPRequest().Trigger(event, func(e *core.RecordCreateOTPRequestEvent) error {
var otp *core.OTP
// limit the new OTP creations for a single user
if !e.App.IsDev() {
otps, err := e.App.FindAllOTPsByRecord(e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to fetch previous record OTPs.", err))
}
totalRecent := 0
for _, existingOTP := range otps {
if !existingOTP.HasExpired(collection.OTP.DurationTime()) {
totalRecent++
}
// use the last issued one
if totalRecent > 9 {
otp = otps[0] // otps are DESC sorted
e.App.Logger().Warn(
"Too many OTP requests - reusing the last issued",
"email", form.Email,
"recordId", e.Record.Id,
"otpId", existingOTP.Id,
)
break
}
}
}
if otp == nil {
// create new OTP
// ---
otp = core.NewOTP(e.App)
otp.SetCollectionRef(e.Record.Collection().Id)
otp.SetRecordRef(e.Record.Id)
otp.SetPassword(e.Password)
err = e.App.Save(otp)
if err != nil {
return err
}
// send OTP email
// (in the background as a very basic timing attacks and emails enumeration protection)
// ---
app := e.App
routine.FireAndForget(func() {
err = mails.SendRecordOTP(app, e.Record, otp.Id, e.Password)
if err != nil {
app.Logger().Error("Failed to send OTP email", "error", errors.Join(err, e.App.Delete(otp)))
}
})
}
return e.JSON(http.StatusOK, map[string]string{
"otpId": otp.Id,
})
})
}
// -------------------------------------------------------------------
type createOTPForm struct {
Email string `form:"email" json:"email"`
}
func (form createOTPForm) validate() error {
return validation.ValidateStruct(&form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}

View File

@ -0,0 +1,231 @@
package apis_test
import (
"net/http"
"strconv"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRecordRequestOTP(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with disabled otp",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
usersCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
usersCol.OTP.Enabled = false
if err := app.Save(usersCol); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid request data",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"invalid"}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"email":{"code":"validation_is_email`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"`, // some fake random generated string
},
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record (with < 9 non-expired)",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert 8 non-expired and 2 expired
for i := 0; i < 10; i++ {
otp := core.NewOTP(app)
otp.Id = "otp_" + strconv.Itoa(i)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if i >= 8 {
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
otp.SetRaw("created", expiredDate)
otp.SetRaw("updated", expiredDate)
}
if err := app.SaveNoValidate(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"`,
},
NotExpectedContent: []string{
`"otpId":"otp_`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordOTPSend": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("Expected 1 email, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record (with > 9 non-expired)",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert 10 non-expired
for i := 0; i < 10; i++ {
otp := core.NewOTP(app)
otp.Id = "otp_" + strconv.Itoa(i)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.SaveNoValidate(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"otp_9"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected 0 sent emails, got %d", app.TestMailer.TotalSend())
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestOTP",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestOTP"},
{MaxRequests: 0, Label: "users:requestOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestOTP",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,102 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordConfirmPasswordReset(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
form := new(recordConfirmPasswordResetForm)
form.app = e.App
form.collection = collection
if err = e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
authRecord, err := e.App.FindAuthRecordByToken(form.Token, core.TokenTypePasswordReset)
if err != nil {
return firstApiError(err, e.BadRequestError("Invalid or expired password reset token.", err))
}
event := new(core.RecordConfirmPasswordResetRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = authRecord
return e.App.OnRecordConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetRequestEvent) error {
authRecord.SetPassword(form.Password)
if !authRecord.Verified() {
payload, err := security.ParseUnverifiedJWT(form.Token)
if err == nil && authRecord.Email() == cast.ToString(payload[core.TokenClaimEmail]) {
// mark as verified if the email hasn't changed
authRecord.SetVerified(true)
}
}
err = form.app.Save(authRecord)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to set new password.", err))
}
form.app.Store().Remove(getPasswordResetResendKey(authRecord))
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordConfirmPasswordResetForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
Password string `form:"password" json:"password"`
PasswordConfirm string `form:"passwordConfirm" json:"passwordConfirm"`
}
func (form *recordConfirmPasswordResetForm) validate() error {
min := 1
passField, ok := form.collection.Fields.GetByName(core.FieldNamePassword).(*core.PasswordField)
if ok && passField != nil && passField.Min > 0 {
min = passField.Min
}
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
validation.Field(&form.Password, validation.Required, validation.Length(min, 255)), // the FieldPassword validator will check further the specicic length constraints
validation.Field(&form.PasswordConfirm, validation.Required, validation.By(validators.Equal(form.Password))),
)
}
func (form *recordConfirmPasswordResetForm) checkToken(value any) error {
v, _ := value.(string)
if v == "" {
return nil
}
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypePasswordReset)
if err != nil || record == nil {
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if record.Collection().Id != form.collection.Id {
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
return nil
}

View File

@ -0,0 +1,345 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"password":{"code":"validation_required"`,
`"passwordConfirm":{"code":"validation_required"`,
`"token":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data format",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token and invalid password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.5Tm6_6amQqOlX3urAnXlEdmxwG5qQJfiTg6U0hHR1hk",
"password":"1234567",
"passwordConfirm":"7654321"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
`"password":{"code":"validation_length_out_of_range"`,
`"passwordConfirm":{"code":"validation_values_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-password reset token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-password-reset?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-password-reset?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and data (unverified user)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to be unverified")
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatal("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if !user.Verified() {
t.Fatal("Expected the user to be marked as verified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "valid token and data (unverified user with different email from the one in the token)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to be unverified")
}
// manually change the email to check whether the verified state will be updated
user.SetEmail("test_update@example.com")
if err := app.Save(user); err != nil {
t.Fatalf("Failed to update user test email: %v", err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatalf("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test_update@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to remain unverified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "valid token and data (verified user)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
// ensure that the user is already verified
user.SetVerified(true)
if err := app.Save(user); err != nil {
t.Fatalf("Failed to update user verified state")
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatal("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if !user.Verified() {
t.Fatal("Expected the user to remain verified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "OnRecordAfterConfirmPasswordResetRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:confirmPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmPasswordReset"},
{MaxRequests: 0, Label: "users:confirmPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,86 @@
package apis
import (
"errors"
"fmt"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
)
func recordRequestPasswordReset(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.PasswordAuth.Enabled {
return e.BadRequestError("The collection is not configured to allow password authentication.", nil)
}
form := new(recordRequestPasswordResetForm)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
resendKey := getPasswordResetResendKey(record)
if e.App.Store().Has(resendKey) {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return errors.New("try again later - you've already requested a password reset email")
}
event := new(core.RecordRequestPasswordResetRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetRequestEvent) error {
// run in background because we don't need to show the result to the client
app := e.App
routine.FireAndForget(func() {
if err := mails.SendRecordPasswordReset(app, e.Record); err != nil {
app.Logger().Error("Failed to send password reset email", "error", err)
return
}
app.Store().Set(resendKey, struct{}{})
time.AfterFunc(2*time.Minute, func() {
app.Store().Remove(resendKey)
})
})
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordRequestPasswordResetForm struct {
Email string `form:"email" json:"email"`
}
func (form *recordRequestPasswordResetForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}
func getPasswordResetResendKey(record *core.Record) string {
return "@limitPasswordResetEmail_" + record.Collection().Id + record.Id
}

View File

@ -0,0 +1,145 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing auth record in a collection with disabled password login",
Method: http.MethodPost,
URL: "/api/collections/nologin/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestPasswordResetRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordPasswordResetSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-password-reset") {
t.Fatalf("Expected password reset email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
{
Name: "existing auth record (after already sent)",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// simulate recent verification sent
authRecord, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
resendKey := "@limitPasswordResetEmail_" + authRecord.Collection().Id + authRecord.Id
app.Store().Set(resendKey, struct{}{})
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestPasswordReset"},
{MaxRequests: 0, Label: "users:requestPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,29 @@
package apis
import (
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordAuthRefresh(e *core.RequestEvent) error {
record := e.Auth
if record == nil {
return e.NotFoundError("Missing auth record context.", nil)
}
currentToken := getAuthTokenFromRequest(e)
claims, _ := security.ParseUnverifiedJWT(currentToken)
if v, ok := claims[core.TokenClaimRefreshable]; !ok || !cast.ToBool(v) {
return e.ForbiddenError("The current auth token is not refreshable.", nil)
}
event := new(core.RecordAuthRefreshRequestEvent)
event.RequestEvent = e
event.Collection = record.Collection()
event.Record = record
return e.App.OnRecordAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshRequestEvent) error {
return RecordAuthResponse(e.RequestEvent, e.Record, "", nil)
})
}

View File

@ -0,0 +1,196 @@
package apis_test
import (
"errors"
"net/http"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthRefresh(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superuser trying to refresh the auth of another auth collection",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + same auth collection as the token",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":`,
`"record":`,
`"id":"4q1xlclmfloku33"`,
`"emailVisibility":false`,
`"email":"test@example.com"`, // the owner can always view their email address
`"expand":`,
`"rel":`,
`"id":"llvuca81nly1qls"`,
},
NotExpectedContent: []string{
`"missing":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 2,
},
},
{
Name: "auth record + same auth collection as the token but static/unrefreshable",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unverified auth record in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im8xeTBkZDBzcGQ3ODZtZCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.Zi0yXE-CNmnbTdVaQEzYZVuECqRdn3LgEM6pmB3XWBE",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
},
},
{
Name: "verified auth record in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":`,
`"record":`,
`"id":"gk390qegs4y47wn"`,
`"verified":true`,
`"email":"test@example.com"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "OnRecordAfterAuthRefreshRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authRefresh",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authRefresh"},
{MaxRequests: 0, Label: "users:authRefresh"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authRefresh",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:authRefresh"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,102 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordConfirmVerification(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers are verified by default.", nil)
}
form := new(recordConfirmVerificationForm)
form.app = e.App
form.collection = collection
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeVerification)
if err != nil {
return e.BadRequestError("Invalid or expired verification token.", err)
}
wasVerified := record.Verified()
event := new(core.RecordConfirmVerificationRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error {
if wasVerified {
return e.NoContent(http.StatusNoContent)
}
e.Record.SetVerified(true)
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
}
e.App.Store().Remove(getVerificationResendKey(e.Record))
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordConfirmVerificationForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
}
func (form *recordConfirmVerificationForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
)
}
func (form *recordConfirmVerificationForm) checkToken(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
claims, _ := security.ParseUnverifiedJWT(v)
email := cast.ToString(claims["email"])
if email == "" {
return validation.NewError("validation_invalid_token_claims", "Missing email token claim.")
}
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypeVerification)
if err != nil || record == nil {
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if record.Collection().Id != form.collection.Id {
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
if record.Email() != email {
return validation.NewError("validation_token_email_mismatch", "The record email doesn't match with the requested token claims.")
}
return nil
}

View File

@ -0,0 +1,210 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmVerification(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data format",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.qqelNNL2Udl6K_TJ282sNHYCpASgA6SIuSVKGfBHMZU"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-verification token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-verification?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-verification?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
"OnModelUpdate": 1,
"OnModelValidate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordValidate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
},
{
Name: "valid token (already verified)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdDJAZXhhbXBsZS5jb20ifQ.QQmM3odNFVk6u4J4-5H8IBM3dfk9YCD7mPW-8PhBAI8"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
},
},
{
Name: "valid verification token from a collection without allowed login",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
"OnModelUpdate": 1,
"OnModelValidate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordValidate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
},
{
Name: "OnRecordAfterConfirmVerificationRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - nologin:confirmVerification",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmVerification"},
{MaxRequests: 0, Label: "nologin:confirmVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmVerification",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,89 @@
package apis
import (
"errors"
"fmt"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
)
func recordRequestVerification(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers are verified by default.", nil)
}
form := new(recordRequestVerificationForm)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
resendKey := getVerificationResendKey(record)
if !record.Verified() && e.App.Store().Has(resendKey) {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return errors.New("try again later - you've already requested a verification email")
}
event := new(core.RecordRequestVerificationRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationRequestEvent) error {
if e.Record.Verified() {
return e.NoContent(http.StatusNoContent)
}
// run in background because we don't need to show the result to the client
app := e.App
routine.FireAndForget(func() {
if err := mails.SendRecordVerification(app, e.Record); err != nil {
app.Logger().Error("Failed to send verification email", "error", err)
}
app.Store().Set(resendKey, struct{}{})
time.AfterFunc(2*time.Minute, func() {
app.Store().Remove(resendKey)
})
})
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordRequestVerificationForm struct {
Email string `form:"email" json:"email"`
}
func (form *recordRequestVerificationForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}
func getVerificationResendKey(record *core.Record) string {
return "@limitVerificationEmail_" + record.Collection().Id + record.Id
}

View File

@ -0,0 +1,162 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestVerification(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-verification",
Body: strings.NewReader(``),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "already verified auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test2@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestVerificationRequest": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestVerificationRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordVerificationSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-verification") {
t.Fatalf("Expected verification email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
{
Name: "existing auth record (after already sent)",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
// terminated before firing the event
// "OnRecordRequestVerificationRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// simulate recent verification sent
authRecord, err := app.FindFirstRecordByData("users", "email", "test@example.com")
if err != nil {
t.Fatal(err)
}
resendKey := "@limitVerificationEmail_" + authRecord.Collection().Id + authRecord.Id
app.Store().Set(resendKey, struct{}{})
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestVerification",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestVerification"},
{MaxRequests: 0, Label: "users:requestVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestVerification",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,355 @@
package apis
import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/security"
"golang.org/x/oauth2"
)
func recordAuthWithOAuth2(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OAuth2.Enabled {
return e.ForbiddenError("The collection is not configured to allow OAuth2 authentication.", nil)
}
var fallbackAuthRecord *core.Record
if e.Auth != nil && e.Auth.Collection().Id == collection.Id {
fallbackAuthRecord = e.Auth
}
form := new(recordOAuth2LoginForm)
form.collection = collection
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if form.RedirectUrl != "" && form.RedirectURL == "" {
e.App.Logger().Warn("[recordAuthWithOAuth2] redirectUrl body param is deprecated and will be removed in the future. Please replace it with redirectURL.")
form.RedirectURL = form.RedirectUrl
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
// exchange token for OAuth2 user info and locate existing ExternalAuth rel
// ---------------------------------------------------------------
// load provider configuration
providerConfig, ok := collection.OAuth2.GetProviderConfig(form.Provider)
if !ok {
return e.InternalServerError("Missing or invalid provider config.", nil)
}
provider, err := providerConfig.InitProvider()
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to init provider "+form.Provider, err))
}
ctx, cancel := context.WithTimeout(e.Request.Context(), 30*time.Second)
defer cancel()
provider.SetContext(ctx)
provider.SetRedirectURL(form.RedirectURL)
var opts []oauth2.AuthCodeOption
if provider.PKCE() {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier))
}
// fetch token
token, err := provider.FetchToken(form.Code, opts...)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 token.", err))
}
// fetch external auth user
authUser, err := provider.FetchAuthUser(token)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 user.", err))
}
var authRecord *core.Record
// check for existing relation with the auth record
externalAuthRel, err := e.App.FindFirstExternalAuthByExpr(dbx.HashExp{
"collectionRef": form.collection.Id,
"provider": form.Provider,
"providerId": authUser.Id,
})
switch {
case err == nil && externalAuthRel != nil:
authRecord, err = e.App.FindRecordById(form.collection, externalAuthRel.RecordRef())
if err != nil {
return err
}
case fallbackAuthRecord != nil && fallbackAuthRecord.Collection().Id == form.collection.Id:
// fallback to the logged auth record (if any)
authRecord = fallbackAuthRecord
case authUser.Email != "":
// look for an existing auth record by the external auth record's email
authRecord, _ = e.App.FindAuthRecordByEmail(form.collection.Id, authUser.Email)
}
// ---------------------------------------------------------------
event := new(core.RecordAuthWithOAuth2RequestEvent)
event.RequestEvent = e
event.Collection = collection
event.ProviderName = form.Provider
event.ProviderClient = provider
event.OAuth2User = authUser
event.CreateData = form.CreateData
event.Record = authRecord
event.IsNewRecord = authRecord == nil
return e.App.OnRecordAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2RequestEvent) error {
if err := oauth2Submit(e, externalAuthRel); err != nil {
return firstApiError(err, e.BadRequestError("Failed to authenticate.", err))
}
meta := struct {
*auth.AuthUser
IsNew bool `json:"isNew"`
}{
AuthUser: e.OAuth2User,
IsNew: e.IsNewRecord,
}
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOAuth2, meta)
})
}
// -------------------------------------------------------------------
type recordOAuth2LoginForm struct {
collection *core.Collection
// Additional data that will be used for creating a new auth record
// if an existing OAuth2 account doesn't exist.
CreateData map[string]any `form:"createData" json:"createData"`
// The name of the OAuth2 client provider (eg. "google")
Provider string `form:"provider" json:"provider"`
// The authorization code returned from the initial request.
Code string `form:"code" json:"code"`
// The optional PKCE code verifier as part of the code_challenge sent with the initial request.
CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`
// The redirect url sent with the initial request.
RedirectURL string `form:"redirectURL" json:"redirectURL"`
// @todo
// deprecated: use RedirectURL instead
// RedirectUrl will be removed after dropping v0.22 support
RedirectUrl string `form:"redirectUrl" json:"redirectUrl"`
}
func (form *recordOAuth2LoginForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Provider, validation.Required, validation.By(form.checkProviderName)),
validation.Field(&form.Code, validation.Required),
validation.Field(&form.RedirectURL, validation.Required),
)
}
func (form *recordOAuth2LoginForm) checkProviderName(value any) error {
name, _ := value.(string)
_, ok := form.collection.OAuth2.GetProviderConfig(name)
if !ok {
return validation.NewError("validation_invalid_provider", fmt.Sprintf("Provider with name %q is missing or is not enabled.", name)).
SetParams(map[string]any{"name": name})
}
return nil
}
func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool {
// ensure that username is unique
checkUnique := dbutils.HasSingleColumnUniqueIndex(collection.OAuth2.MappedFields.Username, collection.Indexes)
if checkUnique {
if _, err := txApp.FindFirstRecordByData(collection, collection.OAuth2.MappedFields.Username, username); err == nil {
return false // already exist
}
}
// ensure that the value matches the pattern of the username field (if text)
txtField, _ := collection.Fields.GetByName(collection.OAuth2.MappedFields.Username).(*core.TextField)
return txtField != nil && txtField.ValidatePlainValue(username) == nil
}
func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *core.ExternalAuth) error {
return e.App.RunInTransaction(func(txApp core.App) error {
if e.Record == nil {
// extra check to prevent creating a superuser record via
// OAuth2 in case the method is used by another action
if e.Collection.Name == core.CollectionNameSuperusers {
return errors.New("superusers are not allowed to sign-up with OAuth2")
}
payload := maps.Clone(e.CreateData)
if payload == nil {
payload = map[string]any{}
}
payload[core.FieldNameEmail] = e.OAuth2User.Email
// set a random password if none is set
if v, _ := payload[core.FieldNamePassword].(string); v == "" {
payload[core.FieldNamePassword] = security.RandomString(30)
payload[core.FieldNamePassword+"Confirm"] = payload[core.FieldNamePassword]
}
// map known fields (unless the field was explicitly submitted as part of CreateData)
if _, ok := payload[e.Collection.OAuth2.MappedFields.Id]; !ok && e.Collection.OAuth2.MappedFields.Id != "" {
payload[e.Collection.OAuth2.MappedFields.Id] = e.OAuth2User.Id
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.Name]; !ok && e.Collection.OAuth2.MappedFields.Name != "" {
payload[e.Collection.OAuth2.MappedFields.Name] = e.OAuth2User.Name
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.Username]; !ok &&
// no explicit username payload value and existing OAuth2 mapping
e.Collection.OAuth2.MappedFields.Username != "" &&
// extra checks for backward compatibility with earlier versions
oldCanAssignUsername(txApp, e.Collection, e.OAuth2User.Username) {
payload[e.Collection.OAuth2.MappedFields.Username] = e.OAuth2User.Username
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.AvatarURL]; !ok && e.Collection.OAuth2.MappedFields.AvatarURL != "" {
mappedField := e.Collection.Fields.GetByName(e.Collection.OAuth2.MappedFields.AvatarURL)
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
// download the avatar if the mapped field is a file
avatarFile, err := func() (*filesystem.File, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
}()
if err != nil {
return err
}
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = avatarFile
} else {
// otherwise - assign the url string
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = e.OAuth2User.AvatarURL
}
}
createdRecord, err := sendOAuth2RecordCreateRequest(txApp, e, payload)
if err != nil {
return err
}
e.Record = createdRecord
if e.Record.Email() == e.OAuth2User.Email && !e.Record.Verified() {
// mark as verified as long as it matches the OAuth2 data (even if the email is empty)
e.Record.SetVerified(true)
if err := txApp.Save(e.Record); err != nil {
return err
}
}
} else {
var needUpdate bool
isLoggedAuthRecord := e.Auth != nil &&
e.Auth.Id == e.Record.Id &&
e.Auth.Collection().Id == e.Record.Collection().Id
// set random password for users with unverified email
// (this is in case a malicious actor has registered previously with the user email)
if !isLoggedAuthRecord && e.Record.Email() != "" && !e.Record.Verified() {
e.Record.SetPassword(security.RandomString(30))
needUpdate = true
}
// update the existing auth record empty email if the data.OAuth2User has one
// (this is in case previously the auth record was created
// with an OAuth2 provider that didn't return an email address)
if e.Record.Email() == "" && e.OAuth2User.Email != "" {
e.Record.SetEmail(e.OAuth2User.Email)
needUpdate = true
}
// update the existing auth record verified state
// (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User)
if !e.Record.Verified() && (e.Record.Email() == "" || e.Record.Email() == e.OAuth2User.Email) {
e.Record.SetVerified(true)
needUpdate = true
}
if needUpdate {
if err := txApp.Save(e.Record); err != nil {
return err
}
}
}
// create ExternalAuth relation if missing
if optExternalAuth == nil {
optExternalAuth = core.NewExternalAuth(txApp)
optExternalAuth.SetCollectionRef(e.Record.Collection().Id)
optExternalAuth.SetRecordRef(e.Record.Id)
optExternalAuth.SetProvider(e.ProviderName)
optExternalAuth.SetProviderId(e.OAuth2User.Id)
if err := txApp.Save(optExternalAuth); err != nil {
return fmt.Errorf("failed to save linked rel: %w", err)
}
}
return nil
})
}
func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2RequestEvent, payload map[string]any) (*core.Record, error) {
ir := &core.InternalRequest{
Method: http.MethodPost,
URL: "/api/collections/" + e.Collection.Name + "/records",
Body: payload,
}
response, err := processInternalRequest(txApp, e.RequestEvent, ir, core.RequestInfoContextOAuth2, nil)
if err != nil {
return nil, err
}
if response.Status != http.StatusOK {
return nil, errors.New("failed to create OAuth2 auth record")
}
recordResponse := struct {
Id string `json:"id"`
}{}
raw, err := json.Marshal(response.Body)
if err != nil {
return nil, err
}
if err = json.Unmarshal(raw, &recordResponse); err != nil {
return nil, err
}
return txApp.FindRecordById(e.Collection, recordResponse.Id)
}

View File

@ -0,0 +1,74 @@
package apis
import (
"encoding/json"
"net/http"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/subscriptions"
)
const (
oauth2SubscriptionTopic string = "@oauth2"
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
)
type oauth2RedirectData struct {
State string `form:"state" json:"state"`
Code string `form:"code" json:"code"`
Error string `form:"error" json:"error,omitempty"`
}
func oauth2SubscriptionRedirect(e *core.RequestEvent) error {
redirectStatusCode := http.StatusTemporaryRedirect
if e.Request.Method != http.MethodGet {
redirectStatusCode = http.StatusSeeOther
}
data := oauth2RedirectData{}
if e.Request.Method == http.MethodPost {
if err := e.BindBody(&data); err != nil {
e.App.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
} else {
query := e.Request.URL.Query()
data.State = query.Get("state")
data.Code = query.Get("code")
data.Error = query.Get("error")
}
if data.State == "" {
e.App.Logger().Debug("Missing OAuth2 state parameter")
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
client, err := e.App.SubscriptionsBroker().ClientById(data.State)
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
e.App.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
defer client.Unsubscribe(oauth2SubscriptionTopic)
encodedData, err := json.Marshal(data)
if err != nil {
e.App.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
msg := subscriptions.Message{
Name: oauth2SubscriptionTopic,
Data: encodedData,
}
client.Send(msg)
if data.Error != "" || data.Code == "" {
e.App.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
return e.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
}

View File

@ -0,0 +1,252 @@
package apis_test
import (
"context"
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/subscriptions"
)
func TestRecordAuthWithOAuth2Redirect(t *testing.T) {
t.Parallel()
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
for i := 0; i < 10; i++ {
c1 := subscriptions.NewDefaultClient()
c2 := subscriptions.NewDefaultClient()
c2.Subscribe("@oauth2")
c3 := subscriptions.NewDefaultClient()
c3.Subscribe("test1", "@oauth2")
c4 := subscriptions.NewDefaultClient()
c4.Subscribe("test1", "test2")
c5 := subscriptions.NewDefaultClient()
c5.Subscribe("@oauth2")
c5.Discard()
clientStubs = append(clientStubs, map[string]subscriptions.Client{
"c1": c1,
"c2": c2,
"c3": c3,
"c4": c4,
"c5": c5,
})
}
checkFailureRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-failure") {
t.Fatalf("Expected failure redirect, got %q", loc)
}
}
checkSuccessRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-success") {
t.Fatalf("Expected success redirect, got %q", loc)
}
}
checkClientMessages := func(t testing.TB, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
if len(expectedMessages[clientId]) == 0 {
t.Fatalf("Unexpected client %q message, got %s:\n%s", clientId, msg.Name, msg.Data)
}
if msg.Name != "@oauth2" {
t.Fatalf("Expected @oauth2 msg.Name, got %q", msg.Name)
}
for _, txt := range expectedMessages[clientId] {
if !strings.Contains(string(msg.Data), txt) {
t.Fatalf("Failed to find %q in \n%s", txt, msg.Data)
}
}
}
beforeTestFunc := func(
clients map[string]subscriptions.Client,
expectedMessages map[string][]string,
) func(testing.TB, *tests.TestApp, *core.ServeEvent) {
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
for _, client := range clients {
app.SubscriptionsBroker().Register(client)
}
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
// add to the app store so that it can be cancelled manually after test completion
app.Store().Set("cancelFunc", cancelFunc)
go func() {
defer cancelFunc()
for {
select {
case msg := <-clients["c1"].Channel():
checkClientMessages(t, "c1", msg, expectedMessages)
case msg := <-clients["c2"].Channel():
checkClientMessages(t, "c2", msg, expectedMessages)
case msg := <-clients["c3"].Channel():
checkClientMessages(t, "c3", msg, expectedMessages)
case msg := <-clients["c4"].Channel():
checkClientMessages(t, "c4", msg, expectedMessages)
case msg := <-clients["c5"].Channel():
checkClientMessages(t, "c5", msg, expectedMessages)
case <-ctx.Done():
for _, c := range clients {
close(c.Channel())
}
return
}
}
}()
}
}
scenarios := []tests.ApiScenario{
{
Name: "no state query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123",
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "invalid or missing client",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=missing",
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "no code query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "error query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "discarded client with @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client without @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client with @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkSuccessRedirect(t, app, res)
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "(POST) client with @oauth2 subscription",
Method: http.MethodPost,
URL: "/api/oauth2-redirect",
Body: strings.NewReader("code=123&state=" + clientStubs[7]["c3"].Id()),
Headers: map[string]string{
"content-type": "application/x-www-form-urlencoded",
},
BeforeTestFunc: beforeTestFunc(clientStubs[7], map[string][]string{
"c3": {`"state":"` + clientStubs[7]["c3"].Id(), `"code":"123"`},
}),
ExpectedStatus: http.StatusSeeOther,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkSuccessRedirect(t, app, res)
if clientStubs[7]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,99 @@
package apis
import (
"errors"
"fmt"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func recordAuthWithOTP(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OTP.Enabled {
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
}
form := &authWithOTPForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.RecordAuthWithOTPRequestEvent)
event.RequestEvent = e
event.Collection = collection
// extra validations
// (note: returns a generic 400 as a very basic OTPs enumeration protection)
// ---
event.OTP, err = e.App.FindOTPById(form.OTPId)
if err != nil {
return e.BadRequestError("Invalid or expired OTP", err)
}
if event.OTP.CollectionRef() != collection.Id {
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is for a different collection"))
}
if event.OTP.HasExpired(collection.OTP.DurationTime()) {
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is expired"))
}
event.Record, err = e.App.FindRecordById(event.OTP.CollectionRef(), event.OTP.RecordRef())
if err != nil {
return e.BadRequestError("Invalid or expired OTP", fmt.Errorf("missing auth record: %w", err))
}
// since otps are usually simple digit numbers we enforce an extra rate limit rule to prevent enumerations
err = checkRateLimit(e, "@pb_otp_"+event.OTP.Id+event.Record.Id, core.RateLimitRule{MaxRequests: 4, Duration: 180})
if err != nil {
return e.TooManyRequestsError("Too many attempts, please try again later with a new OTP.", nil)
}
if !event.OTP.ValidatePassword(form.Password) {
return e.BadRequestError("Invalid or expired OTP", errors.New("incorrect password"))
}
// ---
return e.App.OnRecordAuthWithOTPRequest().Trigger(event, func(e *core.RecordAuthWithOTPRequestEvent) error {
err = RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil)
if err != nil {
return err
}
// try to delete the used otp
if e.OTP != nil {
err = e.App.Delete(e.OTP)
if err != nil {
e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id)
}
}
// note: we don't update the user verified state the same way as in the password reset confirmation
// at the moment because it is not clear whether the otp confirmation came from the user email
// (e.g. it could be from an sms or some other channel)
return nil
})
}
// -------------------------------------------------------------------
type authWithOTPForm struct {
OTPId string `form:"otpId" json:"otpId"`
Password string `form:"password" json:"password"`
}
func (form *authWithOTPForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.OTPId, validation.Required, validation.Length(1, 255)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 71)),
)
}

View File

@ -0,0 +1,438 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRecordAuthWithOTP(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-with-otp",
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with disabled otp",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
usersCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
usersCol.OTP.Enabled = false
if err := app.Save(usersCol); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"otpId":{"code":"validation_required"`,
`"password":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid request data",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 256) + `",
"password":"` + strings.Repeat("a", 72) + `"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"otpId":{"code":"validation_length_out_of_range"`,
`"password":{"code":"validation_length_out_of_range"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing otp",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"missing",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "otp for different collection",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client, err := app.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(client.Collection().Id)
otp.SetRecordRef(client.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "otp with wrong password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("1234567890")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired otp with valid password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
otp.SetRaw("created", expiredDate)
otp.SetRaw("updated", expiredDate)
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid otp with valid password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedContent: []string{`"mfaId":"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithOTPRequest": 1,
"OnRecordAuthRequest": 1,
// ---
"OnModelValidate": 1,
"OnModelCreate": 1, // mfa record
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1, // otp delete
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
// ---
"OnRecordValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
{
Name: "valid otp with valid password (disabled MFA)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user.Collection().MFA.Enabled = false
if err := app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"record":{`,
`"email":"test@example.com"`,
},
NotExpectedContent: []string{
`"meta":`,
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithOTPRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// ---
"OnModelValidate": 1,
"OnModelCreate": 1, // authOrigin
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1, // otp delete
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
// ---
"OnRecordValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authWithOTP",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithOTP"},
{MaxRequests: 100, Label: "users:auth"},
{MaxRequests: 0, Label: "users:authWithOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authWithOTP",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:auth"},
{MaxRequests: 0, Label: "*:authWithOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - users:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithOTP"},
{MaxRequests: 0, Label: "users:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordAuthWithOTPManualRateLimiterCheck(t *testing.T) {
t.Parallel()
var storeCache map[string]any
otpAId := strings.Repeat("a", 15)
otpBId := strings.Repeat("b", 15)
scenarios := []struct {
otpId string
password string
expectedStatus int
}{
{otpAId, "12345", 400},
{otpAId, "12345", 400},
{otpAId, "12345", 400},
{otpAId, "12345", 400},
{otpAId, "123456", 429},
{otpBId, "12345", 400},
{otpBId, "123456", 200},
}
for _, s := range scenarios {
(&tests.ApiScenario{
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + s.otpId + `",
"password":"` + s.password + `"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
for k, v := range storeCache {
app.Store().Set(k, v)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user.Collection().MFA.Enabled = false
if err := app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
for _, id := range []string{otpAId, otpBId} {
otp := core.NewOTP(app)
otp.Id = id
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: s.expectedStatus,
ExpectedContent: []string{`"`}, // it doesn't matter anything non-empty
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
storeCache = app.Store().GetAll()
},
}).Test(t)
}
}

View File

@ -0,0 +1,97 @@
package apis
import (
"database/sql"
"errors"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/list"
)
func recordAuthWithPassword(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.PasswordAuth.Enabled {
return e.ForbiddenError("The collection is not configured to allow password authentication.", nil)
}
form := &authWithPasswordForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(collection); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
var foundRecord *core.Record
var foundErr error
if form.IdentityField != "" {
foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, form.IdentityField, form.Identity)
} else {
// prioritize email lookup
isEmail := is.EmailFormat.Validate(form.Identity) == nil
if isEmail && list.ExistInSlice(core.FieldNameEmail, collection.PasswordAuth.IdentityFields) {
foundRecord, foundErr = e.App.FindAuthRecordByEmail(collection.Id, form.Identity)
}
// search by the other identity fields
if !isEmail || foundErr != nil {
for _, name := range collection.PasswordAuth.IdentityFields {
if !isEmail && name == core.FieldNameEmail {
continue // no need to search by the email field if it is not an email
}
foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, name, form.Identity)
if foundErr == nil {
break
}
}
}
}
// ignore not found errors to allow custom record find implementations
if foundErr != nil && !errors.Is(foundErr, sql.ErrNoRows) {
return e.InternalServerError("", foundErr)
}
event := new(core.RecordAuthWithPasswordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = foundRecord
event.Identity = form.Identity
event.Password = form.Password
event.IdentityField = form.IdentityField
return e.App.OnRecordAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordRequestEvent) error {
if e.Record == nil || !e.Record.ValidatePassword(e.Password) {
return e.BadRequestError("Failed to authenticate.", errors.New("invalid login credentials"))
}
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodPassword, nil)
})
}
// -------------------------------------------------------------------
type authWithPasswordForm struct {
Identity string `form:"identity" json:"identity"`
Password string `form:"password" json:"password"`
// IdentityField specifies the field to use to search for the identity
// (leave it empty for "auto" detection).
IdentityField string `form:"identityField" json:"identityField"`
}
func (form *authWithPasswordForm) validate(collection *core.Collection) error {
return validation.ValidateStruct(form,
validation.Field(&form.Identity, validation.Required, validation.Length(1, 255)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 255)),
validation.Field(&form.IdentityField, validation.In(list.ToInterfaceSlice(collection.PasswordAuth.IdentityFields)...)),
)
}

View File

@ -0,0 +1,514 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthWithPassword(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "disabled password auth",
Method: http.MethodPost,
URL: "/api/collections/nologin/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body format",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{"identity`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body params",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{"identity":"","password":""}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"identity":{`,
`"password":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordAuthWithPasswordRequest error response",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field and invalid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"invalid"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field (email) and valid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity field (username) and valid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"clients57772",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity field and valid password with mismatched explicit identityField",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field and valid password with matched explicit identityField",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"clients57772",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity (unverified) and valid password in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test2@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "already authenticated record",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"gk390qegs4y47wn"`,
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "with mfa first auth check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 401,
ExpectedContent: []string{
`"mfaId":"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
// mfa create
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
mfas, err := app.FindAllMFAsByRecord(user)
if err != nil {
t.Fatal(err)
}
if v := len(mfas); v != 1 {
t.Fatalf("Expected 1 mfa record to be created, got %d", v)
}
},
},
{
Name: "with mfa second auth check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"mfaId": "` + strings.Repeat("a", 15) + `",
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert a dummy mfa record
mfa := core.NewMFA(app)
mfa.Id = strings.Repeat("a", 15)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("test")
if err := app.Save(mfa); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 0, // disabled auth email alerts
"OnMailerRecordAuthAlertSend": 0,
// mfa delete
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
{
Name: "with enabled mfa but unsatisfied mfa rule (aka. skip the mfa check)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
users, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
users.MFA.Enabled = true
users.MFA.Rule = "1=2"
if err := app.Save(users); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 0, // disabled auth email alerts
"OnMailerRecordAuthAlertSend": 0,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
mfas, err := app.FindAllMFAsByRecord(user)
if err != nil {
t.Fatal(err)
}
if v := len(mfas); v != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", v)
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authWithPassword",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithPassword"},
{MaxRequests: 100, Label: "users:auth"},
{MaxRequests: 0, Label: "users:authWithPassword"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authWithPassword",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:auth"},
{MaxRequests: 0, Label: "*:authWithPassword"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - users:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithPassword"},
{MaxRequests: 0, Label: "users:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -1,121 +1,123 @@
package apis package apis
import ( import (
"errors"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"strings"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/forms" "github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/resolvers" "github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search" "github.com/pocketbase/pocketbase/tools/search"
) )
// bindRecordCrudApi registers the record crud api endpoints and // bindRecordCrudApi registers the record crud api endpoints and
// the corresponding handlers. // the corresponding handlers.
func bindRecordCrudApi(app core.App, rg *echo.Group) { //
api := recordApi{app: app} // note: the rate limiter is "inlined" because some of the crud actions are also used in the batch APIs
func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group( subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId)
"/collections/:collection", subGroup.GET("", recordsList)
ActivityLogger(app), subGroup.GET("/{id}", recordView)
) subGroup.POST("", recordCreate(nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.PATCH("/{id}", recordUpdate(nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.GET("/records", api.list, LoadCollectionContext(app)) subGroup.DELETE("/{id}", recordDelete(nil))
subGroup.GET("/records/:id", api.view, LoadCollectionContext(app))
subGroup.POST("/records", api.create, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
subGroup.PATCH("/records/:id", api.update, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
subGroup.DELETE("/records/:id", api.delete, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
} }
type recordApi struct { func recordsList(e *core.RequestEvent) error {
app core.App collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
} if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
func (api *recordApi) list(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", "Missing collection context.")
} }
requestInfo := RequestInfo(c) err = checkCollectionRateLimit(e, collection, "list")
if err != nil {
// forbid users and guests to query special filter/sort fields
if err := checkForAdminOnlyRuleFields(requestInfo); err != nil {
return err return err
} }
if requestInfo.Admin == nil && collection.ListRule == nil { requestInfo, err := e.RequestInfo()
// only admins can access if the rule is nil if err != nil {
return NewForbiddenError("Only admins can perform this action.", nil) return firstApiError(err, e.BadRequestError("", err))
} }
fieldsResolver := resolvers.NewRecordFieldResolver( if collection.ListRule == nil && !requestInfo.HasSuperuserAuth() {
api.app.Dao(), return e.ForbiddenError("Only superusers can perform this action.", nil)
}
// forbid users and guests to query special filter/sort fields
err = checkForSuperuserOnlyRuleFields(requestInfo)
if err != nil {
return err
}
fieldsResolver := core.NewRecordFieldResolver(
e.App,
collection, collection,
requestInfo, requestInfo,
// hidden fields are searchable only by admins // hidden fields are searchable only by superusers
requestInfo.Admin != nil, requestInfo.HasSuperuserAuth(),
) )
searchProvider := search.NewProvider(fieldsResolver). searchProvider := search.NewProvider(fieldsResolver).
Query(api.app.Dao().RecordQuery(collection)) Query(e.App.RecordQuery(collection))
if requestInfo.Admin == nil && collection.ListRule != nil { if !requestInfo.HasSuperuserAuth() && collection.ListRule != nil {
searchProvider.AddFilter(search.FilterData(*collection.ListRule)) searchProvider.AddFilter(search.FilterData(*collection.ListRule))
} }
records := []*models.Record{} records := []*core.Record{}
result, err := searchProvider.ParseAndExec(c.QueryParams().Encode(), &records) result, err := searchProvider.ParseAndExec(e.Request.URL.Query().Encode(), &records)
if err != nil { if err != nil {
return NewBadRequestError("", err) return firstApiError(err, e.BadRequestError("", err))
} }
event := new(core.RecordsListEvent) event := new(core.RecordsListRequestEvent)
event.HttpContext = c event.RequestEvent = e
event.Collection = collection event.Collection = collection
event.Records = records event.Records = records
event.Result = result event.Result = result
return api.app.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListEvent) error { return e.App.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListRequestEvent) error {
if e.HttpContext.Response().Committed { if err := EnrichRecords(e.RequestEvent, e.Records); err != nil {
return nil return firstApiError(err, e.InternalServerError("Failed to enrich records", err))
} }
if err := EnrichRecords(e.HttpContext, api.app.Dao(), e.Records); err != nil { return e.JSON(http.StatusOK, e.Result)
api.app.Logger().Debug("Failed to enrich list records", slog.String("error", err.Error()))
}
return e.HttpContext.JSON(http.StatusOK, e.Result)
}) })
} }
func (api *recordApi) view(c echo.Context) error { func recordView(e *core.RequestEvent) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection) collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if collection == nil { if err != nil || collection == nil {
return NewNotFoundError("", "Missing collection context.") return e.NotFoundError("Missing collection context.", err)
} }
recordId := c.PathParam("id") err = checkCollectionRateLimit(e, collection, "view")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" { if recordId == "" {
return NewNotFoundError("", nil) return e.NotFoundError("", nil)
} }
requestInfo := RequestInfo(c) requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
if requestInfo.Admin == nil && collection.ViewRule == nil { if collection.ViewRule == nil && !requestInfo.HasSuperuserAuth() {
// only admins can access if the rule is nil return e.ForbiddenError("Only superusers can perform this action.", nil)
return NewForbiddenError("Only admins can perform this action.", nil)
} }
ruleFunc := func(q *dbx.SelectQuery) error { ruleFunc := func(q *dbx.SelectQuery) error {
if requestInfo.Admin == nil && collection.ViewRule != nil && *collection.ViewRule != "" { if !requestInfo.HasSuperuserAuth() && collection.ViewRule != nil && *collection.ViewRule != "" {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true) resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.ViewRule).BuildExpr(resolver) expr, err := search.FilterData(*collection.ViewRule).BuildExpr(resolver)
if err != nil { if err != nil {
return err return err
@ -126,290 +128,472 @@ func (api *recordApi) view(c echo.Context) error {
return nil return nil
} }
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc) record, fetchErr := e.App.FindRecordById(collection, recordId, ruleFunc)
if fetchErr != nil || record == nil { if fetchErr != nil || record == nil {
return NewNotFoundError("", fetchErr) return firstApiError(err, e.NotFoundError("", fetchErr))
} }
event := new(core.RecordViewEvent) event := new(core.RecordRequestEvent)
event.HttpContext = c event.RequestEvent = e
event.Collection = collection event.Collection = collection
event.Record = record event.Record = record
return api.app.OnRecordViewRequest().Trigger(event, func(e *core.RecordViewEvent) error { return e.App.OnRecordViewRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
if e.HttpContext.Response().Committed { if err := EnrichRecord(e.RequestEvent, e.Record); err != nil {
return nil return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
} }
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil { return e.JSON(http.StatusOK, e.Record)
api.app.Logger().Debug(
"Failed to enrich view record",
slog.String("id", e.Record.Id),
slog.String("collectionName", e.Record.Collection().Name),
slog.String("error", err.Error()),
)
}
return e.HttpContext.JSON(http.StatusOK, e.Record)
}) })
} }
func (api *recordApi) create(c echo.Context) error { func recordCreate(optFinalizer func() error) func(e *core.RequestEvent) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection) return func(e *core.RequestEvent) error {
if collection == nil { collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
return NewNotFoundError("", "Missing collection context.") if err != nil || collection == nil {
} return e.NotFoundError("Missing collection context.", err)
}
requestInfo := RequestInfo(c) if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
if requestInfo.Admin == nil && collection.CreateRule == nil { err = checkCollectionRateLimit(e, collection, "create")
// only admins can access if the rule is nil if err != nil {
return NewForbiddenError("Only admins can perform this action.", nil) return err
} }
hasFullManageAccess := requestInfo.Admin != nil requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
// temporary save the record and check it against the create rule hasSuperuserAuth := requestInfo.HasSuperuserAuth()
if requestInfo.Admin == nil && collection.CreateRule != nil { canSkipRuleCheck := hasSuperuserAuth
testRecord := models.NewRecord(collection)
// special case for the first superuser creation
// ---
if !canSkipRuleCheck && collection.Name == core.CollectionNameSuperusers {
total, totalErr := e.App.CountRecords(core.CollectionNameSuperusers)
canSkipRuleCheck = totalErr == nil && total == 0
}
// ---
if !canSkipRuleCheck && collection.CreateRule == nil {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
record := core.NewRecord(collection)
data, err := recordDataFromRequest(e, record)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
}
// replace modifiers fields so that the resolved value is always // replace modifiers fields so that the resolved value is always
// available when accessing requestInfo.Data using just the field name // available when accessing requestInfo.Body
if requestInfo.HasModifierDataKeys() { requestInfo.Body = data
requestInfo.Data = testRecord.ReplaceModifers(requestInfo.Data)
}
testForm := forms.NewRecordUpsert(api.app, testRecord) form := forms.NewRecordUpsert(e.App, record)
testForm.SetFullManageAccess(true) if hasSuperuserAuth {
if err := testForm.LoadRequest(c.Request(), ""); err != nil { form.GrantSuperuserAccess()
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
} }
form.Load(data)
// force unset the verified state to prevent ManageRule misuse var isOptFinalizerCalled bool
if !hasFullManageAccess {
testForm.Verified = false
}
createRuleFunc := func(q *dbx.SelectQuery) error { event := new(core.RecordRequestEvent)
if *collection.CreateRule == "" { event.RequestEvent = e
return nil // no create rule to resolve event.Collection = collection
event.Record = record
hookErr := e.App.OnRecordCreateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
form.SetApp(e.App)
form.SetRecord(e.Record)
// temporary save the record and check it against the create and manage rules
if !canSkipRuleCheck && e.Collection.CreateRule != nil {
// temporary grant manager access level
form.GrantManagerAccess()
// manually unset the verified field to prevent manage API rule misuse in case the rule relies on it
initialVerified := e.Record.Verified()
if initialVerified {
e.Record.SetVerified(false)
}
createRuleFunc := func(q *dbx.SelectQuery) error {
if *e.Collection.CreateRule == "" {
return nil // no create rule to resolve
}
resolver := core.NewRecordFieldResolver(e.App, e.Collection, requestInfo, true)
expr, err := search.FilterData(*e.Collection.CreateRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
return nil
}
testErr := form.DrySubmit(func(txApp core.App, drySavedRecord *core.Record) error {
foundRecord, err := txApp.FindRecordById(drySavedRecord.Collection(), drySavedRecord.Id, createRuleFunc)
if err != nil {
return fmt.Errorf("DrySubmit create rule failure: %w", err)
}
// reset the form access level in case it satisfies the Manage API rule
if !hasAuthManageAccess(txApp, requestInfo, foundRecord) {
form.ResetAccess()
}
return nil
})
if testErr != nil {
return e.BadRequestError("Failed to create record.", testErr)
}
// restore initial verified state (it will be further validated on submit)
if initialVerified != e.Record.Verified() {
e.Record.SetVerified(initialVerified)
}
} }
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true) err := form.Submit()
expr, err := search.FilterData(*collection.CreateRule).BuildExpr(resolver) if err != nil {
return firstApiError(err, e.BadRequestError("Failed to create record.", err))
}
err = EnrichRecord(e.RequestEvent, e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
err = e.JSON(http.StatusOK, e.Record)
if err != nil { if err != nil {
return err return err
} }
resolver.UpdateQuery(q)
q.AndWhere(expr)
return nil
}
testErr := testForm.DrySubmit(func(txDao *daos.Dao) error { if optFinalizer != nil {
foundRecord, err := txDao.FindRecordById(collection.Id, testRecord.Id, createRuleFunc) isOptFinalizerCalled = true
if err != nil { err = optFinalizer()
return fmt.Errorf("DrySubmit create rule failure: %w", err) if err != nil {
return firstApiError(err, e.InternalServerError("", err))
}
} }
hasFullManageAccess = hasAuthManageAccess(txDao, foundRecord, requestInfo)
return nil return nil
}) })
if hookErr != nil {
if testErr != nil { return hookErr
return NewBadRequestError("Failed to create record.", testErr)
} }
}
record := models.NewRecord(collection) // e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
form := forms.NewRecordUpsert(api.app, record) if !isOptFinalizerCalled && optFinalizer != nil {
form.SetFullManageAccess(hasFullManageAccess) if err := optFinalizer(); err != nil {
return firstApiError(err, e.InternalServerError("", err))
// load request }
if err := form.LoadRequest(c.Request(), ""); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.RecordCreateEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
event.UploadedFiles = form.FilesToUpload()
// create the record
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(m *models.Record) error {
event.Record = m
return api.app.OnRecordBeforeCreateRequest().Trigger(event, func(e *core.RecordCreateEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to create record.", err)
}
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
api.app.Logger().Debug(
"Failed to enrich create record",
slog.String("id", e.Record.Id),
slog.String("collectionName", e.Record.Collection().Name),
slog.String("error", err.Error()),
)
}
return api.app.OnRecordAfterCreateRequest().Trigger(event, func(e *core.RecordCreateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Record)
})
})
} }
})
return nil
}
} }
func (api *recordApi) update(c echo.Context) error { func recordUpdate(optFinalizer func() error) func(e *core.RequestEvent) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection) return func(e *core.RequestEvent) error {
if collection == nil { collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
return NewNotFoundError("", "Missing collection context.") if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
err = checkCollectionRateLimit(e, collection, "update")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" {
return e.NotFoundError("", nil)
}
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
if !hasSuperuserAuth && collection.UpdateRule == nil {
return firstApiError(err, e.ForbiddenError("Only superusers can perform this action.", nil))
}
// eager fetch the record so that the modifiers field values can be resolved
record, err := e.App.FindRecordById(collection, recordId)
if err != nil {
return firstApiError(err, e.NotFoundError("", err))
}
data, err := recordDataFromRequest(e, record)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
}
// replace modifiers fields so that the resolved value is always
// available when accessing requestInfo.Body
requestInfo.Body = data
ruleFunc := func(q *dbx.SelectQuery) error {
if !hasSuperuserAuth && collection.UpdateRule != nil && *collection.UpdateRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
// refetch with access checks
record, err = e.App.FindRecordById(collection, recordId, ruleFunc)
if err != nil {
return firstApiError(err, e.NotFoundError("", err))
}
form := forms.NewRecordUpsert(e.App, record)
if hasSuperuserAuth {
form.GrantSuperuserAccess()
}
form.Load(data)
var isOptFinalizerCalled bool
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
hookErr := e.App.OnRecordUpdateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
form.SetApp(e.App)
form.SetRecord(e.Record)
if !form.HasManageAccess() && hasAuthManageAccess(e.App, requestInfo, e.Record) {
form.GrantManagerAccess()
}
err := form.Submit()
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to update record.", err))
}
err = EnrichRecord(e.RequestEvent, e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
err = e.JSON(http.StatusOK, e.Record)
if err != nil {
return err
}
if optFinalizer != nil {
isOptFinalizerCalled = true
err = optFinalizer()
if err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
}
}
return nil
})
if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(); err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
}
}
return nil
} }
}
recordId := c.PathParam("id") func recordDelete(optFinalizer func() error) func(e *core.RequestEvent) error {
if recordId == "" { return func(e *core.RequestEvent) error {
return NewNotFoundError("", nil) collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
} if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
requestInfo := RequestInfo(c) if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
if requestInfo.Admin == nil && collection.UpdateRule == nil { err = checkCollectionRateLimit(e, collection, "delete")
// only admins can access if the rule is nil if err != nil {
return NewForbiddenError("Only admins can perform this action.", nil) return err
} }
// eager fetch the record so that the modifier field values are replaced recordId := e.Request.PathValue("id")
// and available when accessing requestInfo.Data using just the field name if recordId == "" {
if requestInfo.HasModifierDataKeys() { return e.NotFoundError("", nil)
record, err := api.app.Dao().FindRecordById(collection.Id, recordId) }
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule == nil {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
ruleFunc := func(q *dbx.SelectQuery) error {
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule != nil && *collection.DeleteRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
record, err := e.App.FindRecordById(collection, recordId, ruleFunc)
if err != nil || record == nil { if err != nil || record == nil {
return NewNotFoundError("", err) return e.NotFoundError("", err)
} }
requestInfo.Data = record.ReplaceModifers(requestInfo.Data)
}
ruleFunc := func(q *dbx.SelectQuery) error { var isOptFinalizerCalled bool
if requestInfo.Admin == nil && collection.UpdateRule != nil && *collection.UpdateRule != "" {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true) event := new(core.RecordRequestEvent)
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver) event.RequestEvent = e
event.Collection = collection
event.Record = record
hookErr := e.App.OnRecordDeleteRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
if err := e.App.Delete(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err))
}
err = e.NoContent(http.StatusNoContent)
if err != nil { if err != nil {
return err return err
} }
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
// fetch record if optFinalizer != nil {
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc) isOptFinalizerCalled = true
if fetchErr != nil || record == nil { err = optFinalizer()
return NewNotFoundError("", fetchErr) if err != nil {
} return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
form := forms.NewRecordUpsert(api.app, record)
form.SetFullManageAccess(requestInfo.Admin != nil || hasAuthManageAccess(api.app.Dao(), record, requestInfo))
// load request
if err := form.LoadRequest(c.Request(), ""); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.RecordUpdateEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
event.UploadedFiles = form.FilesToUpload()
// update the record
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(m *models.Record) error {
event.Record = m
return api.app.OnRecordBeforeUpdateRequest().Trigger(event, func(e *core.RecordUpdateEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to update record.", err)
} }
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
api.app.Logger().Debug(
"Failed to enrich update record",
slog.String("id", e.Record.Id),
slog.String("collectionName", e.Record.Collection().Name),
slog.String("error", err.Error()),
)
}
return api.app.OnRecordAfterUpdateRequest().Trigger(event, func(e *core.RecordUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Record)
})
})
}
})
}
func (api *recordApi) delete(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", "Missing collection context.")
}
recordId := c.PathParam("id")
if recordId == "" {
return NewNotFoundError("", nil)
}
requestInfo := RequestInfo(c)
if requestInfo.Admin == nil && collection.DeleteRule == nil {
// only admins can access if the rule is nil
return NewForbiddenError("Only admins can perform this action.", nil)
}
ruleFunc := func(q *dbx.SelectQuery) error {
if requestInfo.Admin == nil && collection.DeleteRule != nil && *collection.DeleteRule != "" {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc)
if fetchErr != nil || record == nil {
return NewNotFoundError("", fetchErr)
}
event := new(core.RecordDeleteEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
return api.app.OnRecordBeforeDeleteRequest().Trigger(event, func(e *core.RecordDeleteEvent) error {
// delete the record
if err := api.app.Dao().DeleteRecord(e.Record); err != nil {
return NewBadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err)
}
return api.app.OnRecordAfterDeleteRequest().Trigger(event, func(e *core.RecordDeleteEvent) error {
if e.HttpContext.Response().Committed {
return nil
} }
return e.HttpContext.NoContent(http.StatusNoContent) return nil
}) })
}) if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(); err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
}
}
return nil
}
}
// -------------------------------------------------------------------
func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[string]any, error) {
info, err := e.RequestInfo()
if err != nil {
return nil, err
}
// resolve regular fields
result := record.ReplaceModifiers(info.Body)
// resolve uploaded files
uploadedFiles, err := extractUploadedFiles(e.Request, record.Collection(), "")
if err != nil {
return nil, err
}
if len(uploadedFiles) > 0 {
for k, v := range uploadedFiles {
result[k] = v
}
result = record.ReplaceModifiers(result)
}
isAuth := record.Collection().IsAuth()
// unset hidden fields for non-superusers
if !info.HasSuperuserAuth() {
for _, f := range record.Collection().Fields {
if f.GetHidden() {
// exception for the auth collection "password" field
if isAuth && f.GetName() == core.FieldNamePassword {
continue
}
delete(result, f.GetName())
}
}
}
return result, nil
}
func extractUploadedFiles(request *http.Request, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) {
contentType := request.Header.Get("content-type")
if !strings.HasPrefix(contentType, "multipart/form-data") {
return nil, nil // not multipart/form-data request
}
result := map[string][]*filesystem.File{}
for _, field := range collection.Fields {
if field.Type() != core.FieldTypeFile {
continue
}
baseKey := field.GetName()
keys := []string{
baseKey,
// prepend and append modifiers
"+" + baseKey,
baseKey + "+",
}
for _, k := range keys {
if prefix != "" {
k = prefix + "." + k
}
files, err := FindUploadedFiles(request, k)
if err != nil && !errors.Is(err, http.ErrMissingFile) {
return nil, err
}
if len(files) > 0 {
result[k] = files
}
}
}
return result, nil
} }

View File

@ -0,0 +1,314 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudAuthOriginList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with authOrigins",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"9r2j0m74260ur8i"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without authOrigins",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"9r2j0m74260ur8i"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"fingerprint": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"fingerprint":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"fingerprint":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"id":"9r2j0m74260ur8i"`,
`"fingerprint":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,316 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudExternalAuthList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with externalAuths",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"f1z5b3843pzc964"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without externalAuths",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// users, test2@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"dlmflokuq1xl342"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"provider": "github",
"providerId": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
`"providerId":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"providerId": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"id":"dlmflokuq1xl342"`,
`"providerId":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,388 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudMFAList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with mfas",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"user1_0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without mfas",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFAView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"user1_0"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFADelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFACreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"method": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFAUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"method":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"id":"user1_0"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,388 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudOTPList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with otps",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"user1_0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without otps",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"user1_0"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"password": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"password":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"id":"user1_0"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,371 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudSuperuserList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalPages":1`,
`"totalItems":4`,
`"items":[{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 4,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sywbhecnh46rhm0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 4, // + 3 AuthOrigins
"OnModelDeleteExecute": 4,
"OnModelAfterDeleteSuccess": 4,
"OnRecordDelete": 4,
"OnRecordDeleteExecute": 4,
"OnRecordAfterDeleteSuccess": 4,
},
},
{
Name: "delete the last superuser",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all other superusers
superusers, err := app.FindAllRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{"id": "sywbhecnh46rhm0"}))
if err != nil {
t.Fatal(err)
}
for _, superuser := range superusers {
if err = app.Delete(superuser); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelAfterDeleteError": 1,
"OnRecordDelete": 1,
"OnRecordAfterDeleteError": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"email": "test_new@example.com",
"password": "1234567890",
"passwordConfirm": "1234567890",
"verified": false
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "guest creating first superuser",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all superusers
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
if err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"verified":true`,
},
NotExpectedContent: []string{
// because the action has no auth the email field shouldn't be returned if emailVisibility is not set
`"email"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"email":"test_new@example.com"`,
`"verified":true`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"email": "test_new@example.com",
"verified": true
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"id":"sywbhecnh46rhm0"`,
`"email":"test_new@example.com"`,
`"verified":true`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,121 +1,111 @@
package apis package apis
import ( import (
"database/sql"
"errors"
"fmt" "fmt"
"log"
"log/slog"
"net/http" "net/http"
"strings" "strings"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/search" "github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
) )
const ContextRequestInfoKey = "requestInfo" const (
expandQueryParam = "expand"
fieldsQueryParam = "fields"
)
const expandQueryParam = "expand" // RecordAuthResponse writes standardized json record auth response
const fieldsQueryParam = "fields"
// Deprecated: Use RequestInfo instead.
func RequestData(c echo.Context) *models.RequestInfo {
log.Println("RequestData(c) is deprecated and will be removed in the future! You can replace it with RequestInfo(c).")
return RequestInfo(c)
}
// RequestInfo exports cached common request data fields
// (query, body, logged auth state, etc.) from the provided context.
func RequestInfo(c echo.Context) *models.RequestInfo {
// return cached to avoid copying the body multiple times
if v := c.Get(ContextRequestInfoKey); v != nil {
if data, ok := v.(*models.RequestInfo); ok {
// refresh auth state
data.AuthRecord, _ = c.Get(ContextAuthRecordKey).(*models.Record)
data.Admin, _ = c.Get(ContextAdminKey).(*models.Admin)
return data
}
}
result := &models.RequestInfo{
Context: models.RequestInfoContextDefault,
Method: c.Request().Method,
Query: map[string]any{},
Data: map[string]any{},
Headers: map[string]any{},
}
// extract the first value of all headers and normalizes the keys
// ("X-Token" is converted to "x_token")
for k, v := range c.Request().Header {
if len(v) > 0 {
result.Headers[inflector.Snakecase(k)] = v[0]
}
}
result.AuthRecord, _ = c.Get(ContextAuthRecordKey).(*models.Record)
result.Admin, _ = c.Get(ContextAdminKey).(*models.Admin)
echo.BindQueryParams(c, &result.Query)
rest.BindBody(c, &result.Data)
c.Set(ContextRequestInfoKey, result)
return result
}
// RecordAuthResponse writes standardised json record auth response
// into the specified request context. // into the specified request context.
func RecordAuthResponse( //
app core.App, // The authMethod argument specify the name of the current authentication method (eg. password, oauth2, etc.)
c echo.Context, // that it is used primarily as an auth identifier during MFA and for login alerts.
authRecord *models.Record, //
meta any, // Set authMethod to empty string if you want to ignore the MFA checks and the login alerts
finalizers ...func(token string) error, // (can be also adjusted additionally via the OnRecordAuthRequest hook).
) error { func RecordAuthResponse(e *core.RequestEvent, authRecord *core.Record, authMethod string, meta any) error {
if !authRecord.Verified() && authRecord.Collection().AuthOptions().OnlyVerified { token, tokenErr := authRecord.NewAuthToken()
return NewForbiddenError("Please verify your account first.", nil)
}
token, tokenErr := tokens.NewRecordAuthToken(app, authRecord)
if tokenErr != nil { if tokenErr != nil {
return NewBadRequestError("Failed to create auth token.", tokenErr) return e.InternalServerError("Failed to create auth token.", tokenErr)
} }
event := new(core.RecordAuthEvent) return recordAuthResponse(e, authRecord, token, authMethod, meta)
event.HttpContext = c }
func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token string, authMethod string, meta any) error {
originalRequestInfo, err := e.RequestInfo()
if err != nil {
return err
}
ok, err := e.App.CanAccessRecord(authRecord, originalRequestInfo, authRecord.Collection().AuthRule)
if !ok {
return firstApiError(err, e.ForbiddenError("The request doesn't satisfy the collection requirements to authenticate.", err))
}
event := new(core.RecordAuthRequestEvent)
event.RequestEvent = e
event.Collection = authRecord.Collection() event.Collection = authRecord.Collection()
event.Record = authRecord event.Record = authRecord
event.Token = token event.Token = token
event.Meta = meta event.Meta = meta
event.AuthMethod = authMethod
return app.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthEvent) error { return e.App.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthRequestEvent) error {
if e.HttpContext.Response().Committed { if e.Written() {
return nil return nil
} }
// allow always returning the email address of the authenticated account // MFA
e.Record.IgnoreEmailVisibility(true) // ---
mfaId, err := checkMFA(e.RequestEvent, e.Record, e.AuthMethod)
if err != nil {
return err
}
// expand record relations // require additional authentication
expands := strings.Split(c.QueryParam(expandQueryParam), ",") if mfaId != "" {
if len(expands) > 0 { return e.JSON(http.StatusUnauthorized, map[string]string{
// create a copy of the cached request data and adjust it to the current auth record "mfaId": mfaId,
requestInfo := *RequestInfo(e.HttpContext) })
requestInfo.Admin = nil }
requestInfo.AuthRecord = e.Record // ---
failed := app.Dao().ExpandRecord(
e.Record, // create a shallow copy of the cached request data and adjust it to the current auth record
expands, requestInfo := *originalRequestInfo
expandFetch(app.Dao(), &requestInfo), requestInfo.Auth = e.Record
)
if len(failed) > 0 { err = triggerRecordEnrichHooks(e.App, &requestInfo, []*core.Record{e.Record}, func() error {
app.Logger().Debug("[RecordAuthResponse] Failed to expand relations", slog.Any("errors", failed)) if e.Record.IsSuperuser() {
e.Record.Unhide(e.Record.Collection().Fields.FieldNames()...)
}
// allow always returning the email address of the authenticated model
e.Record.IgnoreEmailVisibility(true)
// expand record relations
expands := strings.Split(e.Request.URL.Query().Get(expandQueryParam), ",")
if len(expands) > 0 {
failed := e.App.ExpandRecord(e.Record, expands, expandFetch(e.App, &requestInfo))
if len(failed) > 0 {
e.App.Logger().Warn("[recordAuthResponse] Failed to expand relations", "error", failed)
}
}
return nil
})
if err != nil {
return err
}
if e.AuthMethod != "" && authRecord.Collection().AuthAlert.Enabled {
if err = authAlert(e.RequestEvent, e.Record); err != nil {
e.App.Logger().Warn("[recordAuthResponse] Failed to send login alert", "error", err)
} }
} }
@ -128,68 +118,254 @@ func RecordAuthResponse(
result["meta"] = e.Meta result["meta"] = e.Meta
} }
for _, f := range finalizers { return e.JSON(http.StatusOK, result)
if err := f(e.Token); err != nil { })
return err }
}
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule.
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
rule := record.Collection().MFA.Rule
if rule == "" {
return true, nil
}
requestInfo, err := e.RequestInfo()
if err != nil {
return false, err
}
var exists bool
query := e.App.RecordQuery(record.Collection()).
Select("(1)").
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
// parse and apply the access rule filter
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
expr, err := search.FilterData(rule).BuildExpr(resolver)
if err != nil {
return false, err
}
resolver.UpdateQuery(query)
err = query.AndWhere(expr).Limit(1).Row(&exists)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return exists, nil
}
// checkMFA handles any MFA auth checks that needs to be performed for the specified request event.
// Returns the mfaId that needs to be written as response to the user.
//
// (note: all auth methods are treated as equal and there is no requirement for "pairing").
func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod string) (string, error) {
if !authRecord.Collection().MFA.Enabled || currentAuthMethod == "" {
return "", nil
}
ok, err := wantsMFA(e, authRecord)
if !ok {
if err != nil {
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
} }
return e.HttpContext.JSON(http.StatusOK, result) return "", nil // no mfa needed for this auth record
}) }
// read the mfaId either from the qyery params or request body
mfaId := e.Request.URL.Query().Get("mfaId")
if mfaId == "" {
// check the body
data := struct {
MfaId string `form:"mfaId" json:"mfaId" xml:"mfaId"`
}{}
if err := e.BindBody(&data); err != nil {
return "", firstApiError(err, e.BadRequestError("Failed to read MFA Id", err))
}
mfaId = data.MfaId
}
// first-time auth
// ---
if mfaId == "" {
mfa := core.NewMFA(e.App)
mfa.SetCollectionRef(authRecord.Collection().Id)
mfa.SetRecordRef(authRecord.Id)
mfa.SetMethod(currentAuthMethod)
if err := e.App.Save(mfa); err != nil {
return "", firstApiError(err, e.InternalServerError("Failed to create MFA record", err))
}
return mfa.Id, nil
}
// second-time auth
// ---
mfa, err := e.App.FindMFAById(mfaId)
deleteMFA := func() {
// try to delete the expired mfa
if mfa != nil {
if deleteErr := e.App.Delete(mfa); deleteErr != nil {
e.App.Logger().Warn("Failed to delete expired MFA record", "error", deleteErr, "mfaId", mfa.Id)
}
}
}
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
deleteMFA()
return "", firstApiError(err, e.BadRequestError("Invalid or expired MFA session.", err))
}
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {
return "", e.BadRequestError("Invalid MFA session.", nil)
}
if mfa.Method() == currentAuthMethod {
return "", e.BadRequestError("A different authentication method is required.", nil)
}
deleteMFA()
return "", nil
} }
// EnrichRecord parses the request context and enrich the provided record: // EnrichRecord parses the request context and enrich the provided record:
// - expands relations (if defaultExpands and/or ?expand query param is set) // - expands relations (if defaultExpands and/or ?expand query param is set)
// - ensures that the emails of the auth record and its expanded auth relations // - ensures that the emails of the auth record and its expanded auth relations
// are visible only for the current logged admin, record owner or record with manage access // are visible only for the current logged superuser, record owner or record with manage access
func EnrichRecord(c echo.Context, dao *daos.Dao, record *models.Record, defaultExpands ...string) error { func EnrichRecord(e *core.RequestEvent, record *core.Record, defaultExpands ...string) error {
return EnrichRecords(c, dao, []*models.Record{record}, defaultExpands...) return EnrichRecords(e, []*core.Record{record}, defaultExpands...)
} }
// EnrichRecords parses the request context and enriches the provided records: // EnrichRecords parses the request context and enriches the provided records:
// - expands relations (if defaultExpands and/or ?expand query param is set) // - expands relations (if defaultExpands and/or ?expand query param is set)
// - ensures that the emails of the auth records and their expanded auth relations // - ensures that the emails of the auth records and their expanded auth relations
// are visible only for the current logged admin, record owner or record with manage access // are visible only for the current logged superuser, record owner or record with manage access
func EnrichRecords(c echo.Context, dao *daos.Dao, records []*models.Record, defaultExpands ...string) error { //
requestInfo := RequestInfo(c) // Note: Expects all records to be from the same collection!
func EnrichRecords(e *core.RequestEvent, records []*core.Record, defaultExpands ...string) error {
if err := autoIgnoreAuthRecordsEmailVisibility(dao, records, requestInfo); err != nil { if len(records) == 0 {
return fmt.Errorf("failed to resolve email visibility: %w", err) return nil
} }
expands := defaultExpands info, err := e.RequestInfo()
if param := c.QueryParam(expandQueryParam); param != "" { if err != nil {
expands = append(expands, strings.Split(param, ",")...) return err
}
if len(expands) == 0 {
return nil // nothing to expand
} }
errs := dao.ExpandRecords(records, expands, expandFetch(dao, requestInfo)) return triggerRecordEnrichHooks(e.App, info, records, func() error {
if len(errs) > 0 { expands := defaultExpands
return fmt.Errorf("failed to expand: %v", errs) if param := e.Request.URL.Query().Get(expandQueryParam); param != "" {
expands = append(expands, strings.Split(param, ",")...)
}
err := defaultEnrichRecords(e.App, info, records, expands...)
if err != nil {
// only log as it is not critical
e.App.Logger().Warn("failed to apply default enriching", "error", err)
}
return nil
})
}
var iterate func(record *core.Record) error
type iterator[T any] struct {
items []T
index int
}
func (ri *iterator[T]) next() T {
var item T
if ri.index < len(ri.items) {
item = ri.items[ri.index]
ri.index++
}
return item
}
func triggerRecordEnrichHooks(app core.App, requestInfo *core.RequestInfo, records []*core.Record, finalizer func() error) error {
it := iterator[*core.Record]{items: records}
enrichHook := app.OnRecordEnrich()
event := new(core.RecordEnrichEvent)
event.App = app
event.RequestInfo = requestInfo
iterate = func(record *core.Record) error {
if record == nil {
return nil
}
event.Record = record
return enrichHook.Trigger(event, func(ee *core.RecordEnrichEvent) error {
next := it.next()
if next == nil {
if finalizer != nil {
return finalizer()
}
return nil
}
event.App = ee.App // in case it was replaced with a transaction
event.Record = next
err := iterate(next)
event.App = app
event.Record = record
return err
})
}
return iterate(it.next())
}
func defaultEnrichRecords(app core.App, requestInfo *core.RequestInfo, records []*core.Record, expands ...string) error {
err := autoResolveRecordsFlags(app, records, requestInfo)
if err != nil {
return fmt.Errorf("failed to resolve records flags: %w", err)
}
if len(expands) > 0 {
expandErrs := app.ExpandRecords(records, expands, expandFetch(app, requestInfo))
if len(expandErrs) > 0 {
errsSlice := make([]error, 0, len(expandErrs))
for key, err := range expandErrs {
errsSlice = append(errsSlice, fmt.Errorf("failed to expand %q: %w", key, err))
}
return fmt.Errorf("failed to expand records: %w", errors.Join(errsSlice...))
}
} }
return nil return nil
} }
// expandFetch is the records fetch function that is used to expand related records. // expandFetch is the records fetch function that is used to expand related records.
func expandFetch( func expandFetch(app core.App, originalRequestInfo *core.RequestInfo) core.ExpandFetchFunc {
dao *daos.Dao, requestInfoClone := *originalRequestInfo
requestInfo *models.RequestInfo, requestInfoPtr := &requestInfoClone
) daos.ExpandFetchFunc { requestInfoPtr.Context = core.RequestInfoContextExpand
return func(relCollection *models.Collection, relIds []string) ([]*models.Record, error) {
records, err := dao.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error { return func(relCollection *core.Collection, relIds []string) ([]*core.Record, error) {
if requestInfo.Admin != nil { records, findErr := app.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
return nil // admins can access everything if requestInfoPtr.Auth != nil && requestInfoPtr.Auth.IsSuperuser() {
return nil // superusers can access everything
} }
if relCollection.ViewRule == nil { if relCollection.ViewRule == nil {
return fmt.Errorf("only admins can view collection %q records", relCollection.Name) return fmt.Errorf("only superusers can view collection %q records", relCollection.Name)
} }
if *relCollection.ViewRule != "" { if *relCollection.ViewRule != "" {
resolver := resolvers.NewRecordFieldResolver(dao, relCollection, requestInfo, true) resolver := core.NewRecordFieldResolver(app, relCollection, requestInfoPtr, true)
expr, err := search.FilterData(*(relCollection.ViewRule)).BuildExpr(resolver) expr, err := search.FilterData(*(relCollection.ViewRule)).BuildExpr(resolver)
if err != nil { if err != nil {
return err return err
@ -200,50 +376,66 @@ func expandFetch(
return nil return nil
}) })
if findErr != nil {
if err == nil && len(records) > 0 { return nil, findErr
autoIgnoreAuthRecordsEmailVisibility(dao, records, requestInfo)
} }
return records, err enrichErr := triggerRecordEnrichHooks(app, requestInfoPtr, records, func() error {
if err := autoResolveRecordsFlags(app, records, requestInfoPtr); err != nil {
// non-critical error
app.Logger().Warn("Failed to apply autoResolveRecordsFlags for the expanded records", "error", err)
}
return nil
})
if enrichErr != nil {
return nil, enrichErr
}
return records, nil
} }
} }
// autoIgnoreAuthRecordsEmailVisibility ignores the email visibility check for // autoResolveRecordsFlags resolves various visibility flags of the provided records.
// the provided record if the current auth model is admin, owner or a "manager".
// //
// Note: Expects all records to be from the same auth collection! // Currently it enables:
func autoIgnoreAuthRecordsEmailVisibility( // - export of hidden fields if the current auth model is a superuser
dao *daos.Dao, // - email export ignoring the emailVisibity checks if the current auth model is superuser, owner or a "manager".
records []*models.Record, //
requestInfo *models.RequestInfo, // Note: Expects all records to be from the same collection!
) error { func autoResolveRecordsFlags(app core.App, records []*core.Record, requestInfo *core.RequestInfo) error {
if len(records) == 0 || !records[0].Collection().IsAuth() { if len(records) == 0 {
return nil // nothing to check return nil // nothing to resolve
} }
if requestInfo.Admin != nil { if requestInfo.HasSuperuserAuth() {
hiddenFields := records[0].Collection().Fields.FieldNames()
for _, rec := range records { for _, rec := range records {
rec.Unhide(hiddenFields...)
rec.IgnoreEmailVisibility(true) rec.IgnoreEmailVisibility(true)
} }
return nil }
// additional emailVisibility checks
// ---------------------------------------------------------------
if !records[0].Collection().IsAuth() {
return nil // not auth collection records
} }
collection := records[0].Collection() collection := records[0].Collection()
mappedRecords := make(map[string]*models.Record, len(records)) mappedRecords := make(map[string]*core.Record, len(records))
recordIds := make([]any, len(records)) recordIds := make([]any, len(records))
for i, rec := range records { for i, rec := range records {
mappedRecords[rec.Id] = rec mappedRecords[rec.Id] = rec
recordIds[i] = rec.Id recordIds[i] = rec.Id
} }
if requestInfo != nil && requestInfo.AuthRecord != nil && mappedRecords[requestInfo.AuthRecord.Id] != nil { if requestInfo.Auth != nil && mappedRecords[requestInfo.Auth.Id] != nil {
mappedRecords[requestInfo.AuthRecord.Id].IgnoreEmailVisibility(true) mappedRecords[requestInfo.Auth.Id].IgnoreEmailVisibility(true)
} }
authOptions := collection.AuthOptions() if collection.ManageRule == nil || *collection.ManageRule == "" {
if authOptions.ManageRule == nil || *authOptions.ManageRule == "" {
return nil // no manage rule to check return nil // no manage rule to check
} }
@ -251,12 +443,12 @@ func autoIgnoreAuthRecordsEmailVisibility(
// --- // ---
managedIds := []string{} managedIds := []string{}
query := dao.RecordQuery(collection). query := app.RecordQuery(collection).
Select(dao.DB().QuoteSimpleColumnName(collection.Name) + ".id"). Select(app.DB().QuoteSimpleColumnName(collection.Name) + ".id").
AndWhere(dbx.In(dao.DB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...)) AndWhere(dbx.In(app.DB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
resolver := resolvers.NewRecordFieldResolver(dao, collection, requestInfo, true) resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
expr, err := search.FilterData(*authOptions.ManageRule).BuildExpr(resolver) expr, err := search.FilterData(*collection.ManageRule).BuildExpr(resolver)
if err != nil { if err != nil {
return err return err
} }
@ -278,30 +470,26 @@ func autoIgnoreAuthRecordsEmailVisibility(
return nil return nil
} }
// hasAuthManageAccess checks whether the client is allowed to have full // hasAuthManageAccess checks whether the client is allowed to have
// [forms.RecordUpsert] auth management permissions // [forms.RecordUpsert] auth management permissions
// (aka. allowing to change system auth fields without oldPassword). // (e.g. allowing to change system auth fields without oldPassword).
func hasAuthManageAccess( func hasAuthManageAccess(app core.App, requestInfo *core.RequestInfo, record *core.Record) bool {
dao *daos.Dao,
record *models.Record,
requestInfo *models.RequestInfo,
) bool {
if !record.Collection().IsAuth() { if !record.Collection().IsAuth() {
return false return false
} }
manageRule := record.Collection().AuthOptions().ManageRule manageRule := record.Collection().ManageRule
if manageRule == nil || *manageRule == "" { if manageRule == nil || *manageRule == "" {
return false // only for admins (manageRule can't be empty) return false // only for superusers (manageRule can't be empty)
} }
if requestInfo == nil || requestInfo.AuthRecord == nil { if requestInfo == nil || requestInfo.Auth == nil {
return false // no auth record return false // no auth record
} }
ruleFunc := func(q *dbx.SelectQuery) error { ruleFunc := func(q *dbx.SelectQuery) error {
resolver := resolvers.NewRecordFieldResolver(dao, record.Collection(), requestInfo, true) resolver := core.NewRecordFieldResolver(app, record.Collection(), requestInfo, true)
expr, err := search.FilterData(*manageRule).BuildExpr(resolver) expr, err := search.FilterData(*manageRule).BuildExpr(resolver)
if err != nil { if err != nil {
return err return err
@ -311,35 +499,118 @@ func hasAuthManageAccess(
return nil return nil
} }
_, findErr := dao.FindRecordById(record.Collection().Id, record.Id, ruleFunc) _, findErr := app.FindRecordById(record.Collection().Id, record.Id, ruleFunc)
return findErr == nil return findErr == nil
} }
var ruleQueryParams = []string{search.FilterQueryParam, search.SortQueryParam} var ruleQueryParams = []string{search.FilterQueryParam, search.SortQueryParam}
var adminOnlyRuleFields = []string{"@collection.", "@request."} var superuserOnlyRuleFields = []string{"@collection.", "@request."}
// @todo consider moving the rules check to the RecordFieldResolver. // checkForSuperuserOnlyRuleFields loosely checks and returns an error if
// // the provided RequestInfo contains rule fields that only the superuser can use.
// checkForAdminOnlyRuleFields loosely checks and returns an error if func checkForSuperuserOnlyRuleFields(requestInfo *core.RequestInfo) error {
// the provided RequestInfo contains rule fields that only the admin can use. if len(requestInfo.Query) == 0 || requestInfo.HasSuperuserAuth() {
func checkForAdminOnlyRuleFields(requestInfo *models.RequestInfo) error { return nil // superuser or nothing to check
if requestInfo.Admin != nil || len(requestInfo.Query) == 0 {
return nil // admin or nothing to check
} }
for _, param := range ruleQueryParams { for _, param := range ruleQueryParams {
v, _ := requestInfo.Query[param].(string) v := requestInfo.Query[param]
if v == "" { if v == "" {
continue continue
} }
for _, field := range adminOnlyRuleFields { for _, field := range superuserOnlyRuleFields {
if strings.Contains(v, field) { if strings.Contains(v, field) {
return NewForbiddenError("Only admins can filter by "+field, nil) return router.NewForbiddenError("Only superusers can filter by "+field, nil)
} }
} }
} }
return nil return nil
} }
// firstApiError returns the first ApiError from the errors list
// (this is used usually to prevent unnecessary wraping and to allow bubling ApiError from nested hooks)
//
// If no ApiError is found, returns a default "Internal server" error.
func firstApiError(errs ...error) *router.ApiError {
var apiErr *router.ApiError
var ok bool
for _, err := range errs {
if err == nil {
continue
}
// quick assert to avoid the reflection checks
apiErr, ok = err.(*router.ApiError)
if ok {
return apiErr
}
// nested/wrapped errors
if errors.As(err, &apiErr) {
return apiErr
}
}
return router.NewInternalServerError("", errors.Join(errs...))
}
// -------------------------------------------------------------------
const maxAuthOrigins = 5
func authAlert(e *core.RequestEvent, authRecord *core.Record) error {
// generating fingerprint
// ---
userAgent := e.Request.UserAgent()
if len(userAgent) > 300 {
userAgent = userAgent[:300]
}
fingerprint := security.MD5(e.RealIP() + userAgent)
// ---
origins, err := e.App.FindAllAuthOriginsByRecord(authRecord)
if err != nil {
return err
}
isFirstLogin := len(origins) == 0
var currentOrigin *core.AuthOrigin
for _, origin := range origins {
if origin.Fingerprint() == fingerprint {
currentOrigin = origin
break
}
}
if currentOrigin == nil {
currentOrigin = core.NewAuthOrigin(e.App)
currentOrigin.SetCollectionRef(authRecord.Collection().Id)
currentOrigin.SetRecordRef(authRecord.Id)
currentOrigin.SetFingerprint(fingerprint)
}
// send email alert for the new origin auth (skip first login)
if !isFirstLogin && currentOrigin.IsNew() && authRecord.Email() != "" {
if err := mails.SendRecordAuthAlert(e.App, authRecord); err != nil {
return err
}
}
// try to keep only up to maxAuthOrigins
// (pop the last used ones; it is not executed in a transaction to avoid unnecessary locks)
if currentOrigin.IsNew() && len(origins) >= maxAuthOrigins {
for i := len(origins) - 1; i >= maxAuthOrigins-1; i-- {
if err := e.App.Delete(origins[i]); err != nil {
// treat as non-critical error, just log for now
e.App.Logger().Warn("Failed to delete old AuthOrigin record", "error", err, "authOriginId", origins[i].Id)
}
}
}
// create/update the origin fingerprint
return e.App.Save(currentOrigin)
}

View File

@ -6,231 +6,742 @@ import (
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/types"
) )
func TestRequestInfo(t *testing.T) { func TestEnrichRecords(t *testing.T) {
t.Parallel()
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/?test=123", strings.NewReader(`{"test":456}`))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set("X-Token-Test", "123")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
dummyRecord := &models.Record{}
dummyRecord.Id = "id1"
c.Set(apis.ContextAuthRecordKey, dummyRecord)
dummyAdmin := &models.Admin{}
dummyAdmin.Id = "id2"
c.Set(apis.ContextAdminKey, dummyAdmin)
result := apis.RequestInfo(c)
if result == nil {
t.Fatal("Expected *models.RequestInfo instance, got nil")
}
if result.Method != http.MethodPost {
t.Fatalf("Expected Method %v, got %v", http.MethodPost, result.Method)
}
rawHeaders, _ := json.Marshal(result.Headers)
expectedHeaders := `{"content_type":"application/json","x_token_test":"123"}`
if v := string(rawHeaders); v != expectedHeaders {
t.Fatalf("Expected Query %v, got %v", expectedHeaders, v)
}
rawQuery, _ := json.Marshal(result.Query)
expectedQuery := `{"test":"123"}`
if v := string(rawQuery); v != expectedQuery {
t.Fatalf("Expected Query %v, got %v", expectedQuery, v)
}
rawData, _ := json.Marshal(result.Data)
expectedData := `{"test":456}`
if v := string(rawData); v != expectedData {
t.Fatalf("Expected Data %v, got %v", expectedData, v)
}
if result.AuthRecord == nil || result.AuthRecord.Id != dummyRecord.Id {
t.Fatalf("Expected AuthRecord %v, got %v", dummyRecord, result.AuthRecord)
}
if result.Admin == nil || result.Admin.Id != dummyAdmin.Id {
t.Fatalf("Expected Admin %v, got %v", dummyAdmin, result.Admin)
}
}
func TestRecordAuthResponse(t *testing.T) {
t.Parallel() t.Parallel()
// mock test data
// ---
app, _ := tests.NewTestApp() app, _ := tests.NewTestApp()
defer app.Cleanup() defer app.Cleanup()
dummyAdmin := &models.Admin{} user, err := app.FindAuthRecordByEmail("users", "test@example.com")
dummyAdmin.Id = "id1"
nonAuthRecord, err := app.Dao().FindRecordById("demo1", "al1h9ijdeojtsjy")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
authRecord, err := app.Dao().FindRecordById("users", "4q1xlclmfloku33") superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
unverifiedAuthRecord, err := app.Dao().FindRecordById("clients", "o1y0dd0spd786md") usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
if err != nil {
t.Fatal(err)
}
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
if err != nil {
t.Fatal(err)
}
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
if err != nil {
t.Fatal(err)
}
// temp update the view rule to ensure that request context is set to "expand"
demo4, err := app.FindCollectionByNameOrId("demo4")
if err != nil {
t.Fatal(err)
}
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
if err := app.Save(demo4); err != nil {
t.Fatal(err)
}
// ---
scenarios := []struct { scenarios := []struct {
name string name string
record *models.Record auth *core.Record
meta any records []*core.Record
expectError bool queryExpand string
expectedContent []string defaultExpands []string
notExpectedContent []string expected []string
expectedEvents map[string]int notExpected []string
}{ }{
// email visibility checks
{ {
name: "non auth record", name: "[emailVisibility] guest",
record: nonAuthRecord, auth: nil,
expectError: true, records: usersRecords,
}, queryExpand: "",
{ defaultExpands: nil,
name: "valid auth record but with unverified email in onlyVerified collection", expected: []string{
record: unverifiedAuthRecord, `"customField":"123"`,
expectError: true, `"test3@example.com"`, // emailVisibility=true
},
{
name: "valid auth record - without meta",
record: authRecord,
expectError: false,
expectedContent: []string{
`"token":"`,
`"record":{`,
`"id":"`,
`"expand":{"rel":{`,
}, },
notExpectedContent: []string{ notExpected: []string{
`"meta":`, `"test@example.com"`,
},
expectedEvents: map[string]int{
"OnRecordAuthRequest": 1,
}, },
}, },
{ {
name: "valid auth record - with meta", name: "[emailVisibility] owner",
record: authRecord, auth: user,
meta: map[string]any{"meta_test": 123}, records: usersRecords,
expectError: false, queryExpand: "",
expectedContent: []string{ defaultExpands: nil,
`"token":"`, expected: []string{
`"record":{`, `"customField":"123"`,
`"id":"`, `"test3@example.com"`, // emailVisibility=true
`"expand":{"rel":{`, `"test@example.com"`, // owner
`"meta":{"meta_test":123`,
}, },
expectedEvents: map[string]int{ },
"OnRecordAuthRequest": 1, {
name: "[emailVisibility] manager",
auth: user,
records: nologinRecords,
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`,
`"test@example.com"`,
},
},
{
name: "[emailVisibility] superuser",
auth: superuser,
records: nologinRecords,
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`,
`"test@example.com"`,
},
},
{
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
auth: user,
records: demo1Records,
queryExpand: "",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"expand":{"rel_many"`,
`"expand":{}`,
`"test@example.com"`,
},
notExpected: []string{
`"id":"bgs820n361vj1qd"`,
`"id":"oap640cot4yru2s"`,
},
},
{
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
auth: superuser,
records: demo1Records,
queryExpand: "",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"test@example.com"`,
`"expand":{"rel_many"`,
`"id":"bgs820n361vj1qd"`,
`"id":"oap640cot4yru2s"`,
},
notExpected: []string{
`"expand":{}`,
},
},
// expand checks
{
name: "[expand] guest (query)",
auth: nil,
records: usersRecords,
queryExpand: "rel",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"expand":{"rel"`,
`"id":"llvuca81nly1qls"`,
`"id":"0yxhwia2amd8gec"`,
},
notExpected: []string{
`"expand":{}`,
},
},
{
name: "[expand] guest (default expands)",
auth: nil,
records: usersRecords,
queryExpand: "",
defaultExpands: []string{"rel"},
expected: []string{
`"customField":"123"`,
`"expand":{"rel"`,
`"id":"llvuca81nly1qls"`,
`"id":"0yxhwia2amd8gec"`,
},
},
{
name: "[expand] @request.context=expand check",
auth: nil,
records: demo5Records,
queryExpand: "rel_one",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"expand":{}`,
`"expand":{"`,
`"rel_many":[{`,
`"rel_one":{`,
`"id":"i9naidtvr6qsgb4"`,
`"id":"qzaqccwrmva4o1n"`,
}, },
}, },
} }
for _, s := range scenarios { for _, s := range scenarios {
app.ResetEventCalls() t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
e := echo.New() app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
req := httptest.NewRequest(http.MethodGet, "/?expand=rel", nil) e.Record.WithCustomData(true)
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) e.Record.Set("customField", "123")
rec := httptest.NewRecorder() return e.Next()
c := e.NewContext(req, rec) })
c.Set(apis.ContextAdminKey, dummyAdmin)
responseErr := apis.RecordAuthResponse(app, c, s.record, s.meta) req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
rec := httptest.NewRecorder()
hasErr := responseErr != nil requestEvent := new(core.RequestEvent)
if hasErr != s.expectError { requestEvent.App = app
t.Fatalf("[%s] Expected hasErr to be %v, got %v (%v)", s.name, s.expectError, hasErr, responseErr) requestEvent.Request = req
} requestEvent.Response = rec
requestEvent.Auth = s.auth
if len(app.EventCalls) != len(s.expectedEvents) { err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
t.Fatalf("[%s] Expected events \n%v, \ngot \n%v", s.name, s.expectedEvents, app.EventCalls) if err != nil {
} t.Fatal(err)
for k, v := range s.expectedEvents {
if app.EventCalls[k] != v {
t.Fatalf("[%s] Expected event %s to be called %d times, got %d", s.name, k, v, app.EventCalls[k])
} }
}
if hasErr { raw, err := json.Marshal(s.records)
continue if err != nil {
} t.Fatal(err)
response := rec.Body.String()
for _, v := range s.expectedContent {
if !strings.Contains(response, v) {
t.Fatalf("[%s] Missing %v in response \n%v", s.name, v, response)
} }
} rawStr := string(raw)
for _, v := range s.notExpectedContent { for _, str := range s.expected {
if strings.Contains(response, v) { if !strings.Contains(rawStr, str) {
t.Fatalf("[%s] Unexpected %v in response \n%v", s.name, v, response) t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
}
} }
}
for _, str := range s.notExpected {
if strings.Contains(rawStr, str) {
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
}
}
})
} }
} }
func TestEnrichRecords(t *testing.T) { func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
t.Parallel()
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/?expand=rel_many", nil)
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
dummyAdmin := &models.Admin{}
dummyAdmin.Id = "test_id"
c.Set(apis.ContextAdminKey, dummyAdmin)
app, _ := tests.NewTestApp() app, _ := tests.NewTestApp()
defer app.Cleanup() defer app.Cleanup()
records, err := app.Dao().FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"}) event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = httptest.NewRecorder()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
apis.EnrichRecords(c, app.Dao(), records, "rel_one") scenarios := []struct {
name string
rule *string
expectError bool
}{
{
"admin only rule",
nil,
true,
},
{
"empty rule",
types.Pointer(""),
false,
},
{
"false rule",
types.Pointer("1=2"),
true,
},
{
"true rule",
types.Pointer("1=1"),
false,
},
}
for _, record := range records { for _, s := range scenarios {
expand := record.Expand() t.Run(s.name, func(t *testing.T) {
if len(expand) == 0 { user.Collection().AuthRule = s.rule
t.Fatalf("Expected non-empty expand, got nil for record %v", record)
}
if len(record.GetStringSlice("rel_one")) != 0 { err := apis.RecordAuthResponse(event, user, "", nil)
if _, ok := expand["rel_one"]; !ok {
t.Fatalf("Expected rel_one to be expanded for record %v, got \n%v", record, expand) hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
} }
}
if len(record.GetStringSlice("rel_many")) != 0 { // in all cases login alert shouldn't be send because of the empty auth method
if _, ok := expand["rel_many"]; !ok { if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected rel_many to be expanded for record %v, got \n%v", record, expand) t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
} }
}
if !hasErr {
return
}
apiErr, ok := err.(*router.ApiError)
if !ok || apiErr == nil {
t.Fatalf("Expected ApiError, got %v", apiErr)
}
if apiErr.Status != http.StatusForbidden {
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
}
})
} }
} }
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
scenarios := []struct {
name string
devices []string // mock existing device fingerprints
expectDevices []string
enabled bool
expectEmail bool
}{
{
name: "first login",
devices: nil,
expectDevices: []string{testFingerprint},
enabled: true,
expectEmail: false,
},
{
name: "existing device",
devices: []string{"1", testFingerprint},
expectDevices: []string{"1", testFingerprint},
enabled: true,
expectEmail: false,
},
{
name: "new device (< 5)",
devices: []string{"1", "2"},
expectDevices: []string{"1", "2", testFingerprint},
enabled: true,
expectEmail: true,
},
{
name: "new device (>= 5)",
devices: []string{"1", "2", "3", "4", "5"},
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
enabled: true,
expectEmail: true,
},
{
name: "with disabled auth alert collection flag",
devices: []string{"1", "2"},
expectDevices: []string{"1", "2"},
enabled: false,
expectEmail: false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = httptest.NewRecorder()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user.Collection().MFA.Enabled = false
user.Collection().AuthRule = types.Pointer("")
user.Collection().AuthAlert.Enabled = s.enabled
// ensure that there are no other auth origins
err = app.DeleteAllAuthOriginsByRecord(user)
if err != nil {
t.Fatal(err)
}
// insert the mock devices
for _, fingerprint := range s.devices {
d := core.NewAuthOrigin(app)
d.SetCollectionRef(user.Collection().Id)
d.SetRecordRef(user.Id)
d.SetFingerprint(fingerprint)
if err = app.Save(d); err != nil {
t.Fatal(err)
}
}
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Failed to resolve auth response: %v", err)
}
var expectTotalSend int
if s.expectEmail {
expectTotalSend = 1
}
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
}
devices, err := app.FindAllAuthOriginsByRecord(user)
if err != nil {
t.Fatalf("Failed to retrieve auth origins: %v", err)
}
if len(devices) != len(s.expectDevices) {
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
}
for _, fingerprint := range s.expectDevices {
var exists bool
fingerprints := make([]string, 0, len(devices))
for _, d := range devices {
if d.Fingerprint() == fingerprint {
exists = true
break
}
fingerprints = append(fingerprints, d.Fingerprint())
}
if !exists {
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
}
}
})
}
}
func TestRecordAuthResponseMFACheck(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = rec
resetMFAs := func(authRecord *core.Record) {
// ensure that mfa is enabled
user.Collection().MFA.Enabled = true
user.Collection().MFA.Duration = 5
user.Collection().MFA.Rule = ""
mfas, err := app.FindAllMFAsByRecord(authRecord)
if err != nil {
t.Fatalf("Failed to retrieve mfas: %v", err)
}
for _, mfa := range mfas {
if err := app.Delete(mfa); err != nil {
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
}
}
// reset response
rec = httptest.NewRecorder()
event.Response = rec
}
totalMFAs := func(authRecord *core.Record) int {
mfas, err := app.FindAllMFAsByRecord(authRecord)
if err != nil {
t.Fatalf("Failed to retrieve mfas: %v", err)
}
return len(mfas)
}
t.Run("no collection MFA enabled", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Enabled = false
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("no explicit auth method", func(t *testing.T) {
resetMFAs(user)
err = apis.RecordAuthResponse(event, user, "", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Rule = "1=2"
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Rule = "1=1"
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if !strings.Contains(body, "mfaId") {
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected a single mfa record to be created, got %d", total)
}
})
t.Run("mfa first-time", func(t *testing.T) {
resetMFAs(user)
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if !strings.Contains(body, "mfaId") {
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected a single mfa record to be created, got %d", total)
}
})
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
}
})
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
}
})
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
event.Request.Header.Add("content-type", "application/json")
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
}
})
t.Run("missing mfa", func(t *testing.T) {
resetMFAs(user)
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected 0 mfa records, got %d", total)
}
})
t.Run("expired mfa", func(t *testing.T) {
resetMFAs(user)
// create a dummy expired mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if totalMFAs(user) != 0 {
t.Fatal("Expected the expired mfa record to be deleted")
}
})
t.Run("mfa for different auth record", func(t *testing.T) {
resetMFAs(user)
// create a dummy expired mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user2.Collection().Id)
mfa.SetRecordRef(user2.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no user mfas, got %d", total)
}
if total := totalMFAs(user2); total != 1 {
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
}
})
}

View File

@ -3,6 +3,7 @@ package apis
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -12,14 +13,10 @@ import (
"time" "time"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/migrations" "github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/migrations/logs"
"github.com/pocketbase/pocketbase/tools/list" "github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/migrate" "github.com/pocketbase/pocketbase/ui"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
) )
@ -29,10 +26,16 @@ type ServeConfig struct {
// ShowStartBanner indicates whether to show or hide the server start console message. // ShowStartBanner indicates whether to show or hide the server start console message.
ShowStartBanner bool ShowStartBanner bool
// HttpAddr is the TCP address to listen for the HTTP server (eg. `127.0.0.1:80`). // DashboardPath specifies the route path to the superusers dashboard interface
// (default to "/_/{path...}").
//
// Note: Must include the "{path...}" wildcard parameter.
DashboardPath string
// HttpAddr is the TCP address to listen for the HTTP server (eg. "127.0.0.1:80").
HttpAddr string HttpAddr string
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. `127.0.0.1:443`). // HttpsAddr is the TCP address to listen for the HTTPS server (eg. "127.0.0.1:443").
HttpsAddr string HttpsAddr string
// Optional domains list to use when issuing the TLS certificate. // Optional domains list to use when issuing the TLS certificate.
@ -58,36 +61,43 @@ type ServeConfig struct {
// HttpAddr: "127.0.0.1:8080", // HttpAddr: "127.0.0.1:8080",
// ShowStartBanner: false, // ShowStartBanner: false,
// }) // })
func Serve(app core.App, config ServeConfig) (*http.Server, error) { func Serve(app core.App, config ServeConfig) error {
if len(config.AllowedOrigins) == 0 { if len(config.AllowedOrigins) == 0 {
config.AllowedOrigins = []string{"*"} config.AllowedOrigins = []string{"*"}
} }
if config.DashboardPath == "" {
config.DashboardPath = "/_/{path...}"
} else if !strings.HasSuffix(config.DashboardPath, "{path...}") {
return errors.New("invalid dashboard path - missing {path...} wildcard")
}
// ensure that the latest migrations are applied before starting the server // ensure that the latest migrations are applied before starting the server
if err := runMigrations(app); err != nil { err := app.RunAllMigrations()
return nil, err
}
// reload app settings in case a new default value was set with a migration
// (or if this is the first time the init migration was executed)
if err := app.RefreshSettings(); err != nil {
color.Yellow("=====================================")
color.Yellow("WARNING: Settings load error! \n%v", err)
color.Yellow("Fallback to the application defaults.")
color.Yellow("=====================================")
}
router, err := InitApi(app)
if err != nil { if err != nil {
return nil, err return err
} }
// configure cors pbRouter, err := NewRouter(app)
router.Use(middleware.CORSWithConfig(middleware.CORSConfig{ if err != nil {
Skipper: middleware.DefaultSkipper, return err
AllowOrigins: config.AllowedOrigins, }
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
})) pbRouter.Bind(&hook.Handler[*core.RequestEvent]{
Id: DefaultCorsMiddlewareId,
Func: CORSWithConfig(CORSConfig{
AllowOrigins: config.AllowedOrigins,
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}),
Priority: DefaultCorsMiddlewarePriority,
})
pbRouter.BindFunc(installerRedirect(app, config.DashboardPath))
pbRouter.GET(config.DashboardPath, Static(ui.DistDirFS, false)).
BindFunc(dashboardRemoveInstallerParam()).
BindFunc(dashboardCacheControl()).
BindFunc(Gzip())
// start http server // start http server
// --- // ---
@ -118,25 +128,12 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
// implicit www->non-www redirect(s) // implicit www->non-www redirect(s)
if len(wwwRedirects) > 0 { if len(wwwRedirects) > 0 {
router.Pre(func(next echo.HandlerFunc) echo.HandlerFunc { pbRouter.Bind(wwwRedirect(wwwRedirects))
return func(c echo.Context) error {
host := c.Request().Host
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, wwwRedirects) {
return c.Redirect(
http.StatusTemporaryRedirect,
(c.Scheme() + "://" + host[4:] + c.Request().RequestURI),
)
}
return next(c)
}
})
} }
certManager := &autocert.Manager{ certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(filepath.Join(app.DataDir(), ".autocert_cache")), Cache: autocert.DirCache(filepath.Join(app.DataDir(), core.LocalAutocertCacheDirName)),
HostPolicy: autocert.HostWhitelist(hostNames...), HostPolicy: autocert.HostWhitelist(hostNames...),
} }
@ -151,24 +148,96 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
GetCertificate: certManager.GetCertificate, GetCertificate: certManager.GetCertificate,
NextProtos: []string{acme.ALPNProto}, NextProtos: []string{acme.ALPNProto},
}, },
ReadTimeout: 10 * time.Minute, // higher defaults to accommodate large file uploads/downloads
WriteTimeout: 3 * time.Minute,
ReadTimeout: 3 * time.Minute,
ReadHeaderTimeout: 30 * time.Second, ReadHeaderTimeout: 30 * time.Second,
// WriteTimeout: 60 * time.Second, // breaks sse! Addr: mainAddr,
Handler: router,
Addr: mainAddr,
BaseContext: func(l net.Listener) context.Context { BaseContext: func(l net.Listener) context.Context {
return baseCtx return baseCtx
}, },
ErrorLog: log.New(&serverErrorLogWriter{app: app}, "", 0),
} }
serveEvent := &core.ServeEvent{ serveEvent := new(core.ServeEvent)
App: app, serveEvent.App = app
Router: router, serveEvent.Router = pbRouter
Server: server, serveEvent.Server = server
CertManager: certManager, serveEvent.CertManager = certManager
}
if err := app.OnBeforeServe().Trigger(serveEvent); err != nil { var listener net.Listener
return nil, err
// graceful shutdown
// ---------------------------------------------------------------
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
// Note that the WaitGroup would do nothing if the app.OnTerminate() hook isn't triggered.
var wg sync.WaitGroup
// try to gracefully shutdown the server on app termination
app.OnTerminate().Bind(&hook.Handler[*core.TerminateEvent]{
Id: "pbGracefulShutdown",
Func: func(te *core.TerminateEvent) error {
cancelBaseCtx()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
wg.Add(1)
_ = server.Shutdown(ctx)
if te.IsRestart {
// wait for execve and other handlers up to 3 seconds before exit
time.AfterFunc(3*time.Second, func() {
wg.Done()
})
} else {
wg.Done()
}
return te.Next()
},
Priority: -9999,
})
// wait for the graceful shutdown to complete before exit
defer func() {
wg.Wait()
if listener != nil {
_ = listener.Close()
}
}()
// ---------------------------------------------------------------
// trigger the OnServe hook and start the tcp listener
serveHookErr := app.OnServe().Trigger(serveEvent, func(e *core.ServeEvent) error {
handler, err := e.Router.BuildMux()
if err != nil {
return err
}
e.Server.Handler = handler
addr := e.Server.Addr
// fallback similar to the std Server.ListenAndServe/ListenAndServeTLS
if addr == "" {
if config.HttpsAddr != "" {
addr = ":https"
} else {
addr = ":http"
}
}
var lnErr error
listener, lnErr = net.Listen("tcp", addr)
return lnErr
})
if serveHookErr != nil {
return serveHookErr
} }
if config.ShowStartBanner { if config.ShowStartBanner {
@ -198,80 +267,32 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
regular.Printf("└─ Admin UI: %s\n", color.CyanString("%s://%s/_/", schema, addr)) regular.Printf("└─ Admin UI: %s\n", color.CyanString("%s://%s/_/", schema, addr))
} }
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately. var serveErr error
// Note that the WaitGroup would not do anything if the app.OnTerminate() hook isn't triggered.
var wg sync.WaitGroup
// try to gracefully shutdown the server on app termination
app.OnTerminate().Add(func(e *core.TerminateEvent) error {
cancelBaseCtx()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
wg.Add(1)
server.Shutdown(ctx)
if e.IsRestart {
// wait for execve and other handlers up to 5 seconds before exit
time.AfterFunc(5*time.Second, func() {
wg.Done()
})
} else {
wg.Done()
}
return nil
})
// wait for the graceful shutdown to complete before exit
defer wg.Wait()
// ---
// @todo consider removing the server return value because it is
// not really useful when combined with the blocking serve calls
// ---
// start HTTPS server
if config.HttpsAddr != "" { if config.HttpsAddr != "" {
// if httpAddr is set, start an HTTP server to redirect the traffic to the HTTPS version
if config.HttpAddr != "" { if config.HttpAddr != "" {
// start an additional HTTP server for redirecting the traffic to the HTTPS version
go http.ListenAndServe(config.HttpAddr, certManager.HTTPHandler(nil)) go http.ListenAndServe(config.HttpAddr, certManager.HTTPHandler(nil))
} }
return server, server.ListenAndServeTLS("", "") // start HTTPS server
serveErr = server.ServeTLS(listener, "", "")
} else {
// OR start HTTP server
serveErr = server.Serve(listener)
} }
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
// OR start HTTP server return serveErr
return server, server.ListenAndServe()
}
type migrationsConnection struct {
DB *dbx.DB
MigrationsList migrate.MigrationsList
}
func runMigrations(app core.App) error {
connections := []migrationsConnection{
{
DB: app.DB(),
MigrationsList: migrations.AppMigrations,
},
{
DB: app.LogsDB(),
MigrationsList: logs.LogsMigrations,
},
}
for _, c := range connections {
runner, err := migrate.NewRunner(c.DB, c.MigrationsList)
if err != nil {
return err
}
if _, err := runner.Up(); err != nil {
return err
}
} }
return nil return nil
} }
type serverErrorLogWriter struct {
app core.App
}
func (s *serverErrorLogWriter) Write(p []byte) (int, error) {
s.app.Logger().Debug(strings.TrimSpace(string(p)))
return len(p), nil
}

View File

@ -4,136 +4,121 @@ import (
"net/http" "net/http"
validation "github.com/go-ozzo/ozzo-validation/v4" validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms" "github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models/settings" "github.com/pocketbase/pocketbase/tools/router"
) )
// bindSettingsApi registers the settings api endpoints. // bindSettingsApi registers the settings api endpoints.
func bindSettingsApi(app core.App, rg *echo.Group) { func bindSettingsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
api := settingsApi{app: app} subGroup := rg.Group("/settings").Bind(RequireSuperuserAuth())
subGroup.GET("", settingsList)
subGroup := rg.Group("/settings", ActivityLogger(app), RequireAdminAuth()) subGroup.PATCH("", settingsSet)
subGroup.GET("", api.list) subGroup.POST("/test/s3", settingsTestS3)
subGroup.PATCH("", api.set) subGroup.POST("/test/email", settingsTestEmail)
subGroup.POST("/test/s3", api.testS3) subGroup.POST("/apple/generate-client-secret", settingsGenerateAppleClientSecret)
subGroup.POST("/test/email", api.testEmail)
subGroup.POST("/apple/generate-client-secret", api.generateAppleClientSecret)
} }
type settingsApi struct { func settingsList(e *core.RequestEvent) error {
app core.App clone, err := e.App.Settings().Clone()
}
func (api *settingsApi) list(c echo.Context) error {
settings, err := api.app.Settings().RedactClone()
if err != nil { if err != nil {
return NewBadRequestError("", err) return e.InternalServerError("", err)
} }
event := new(core.SettingsListEvent) event := new(core.SettingsListRequestEvent)
event.HttpContext = c event.RequestEvent = e
event.RedactedSettings = settings event.Settings = clone
return api.app.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListEvent) error { return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error {
if e.HttpContext.Response().Committed { return e.JSON(http.StatusOK, e.Settings)
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.RedactedSettings)
}) })
} }
func (api *settingsApi) set(c echo.Context) error { func settingsSet(e *core.RequestEvent) error {
form := forms.NewSettingsUpsert(api.app) event := new(core.SettingsUpdateRequestEvent)
event.RequestEvent = e
// load request if clone, err := e.App.Settings().Clone(); err == nil {
if err := c.Bind(form); err != nil { event.OldSettings = clone
return NewBadRequestError("An error occurred while loading the submitted data.", err) } else {
return e.BadRequestError("", err)
} }
event := new(core.SettingsUpdateEvent) if clone, err := e.App.Settings().Clone(); err == nil {
event.HttpContext = c event.NewSettings = clone
event.OldSettings = api.app.Settings() } else {
return e.BadRequestError("", err)
}
// update the settings if err := e.BindBody(&event.NewSettings); err != nil {
return form.Submit(func(next forms.InterceptorNextFunc[*settings.Settings]) forms.InterceptorNextFunc[*settings.Settings] { return e.BadRequestError("An error occurred while loading the submitted data.", err)
return func(s *settings.Settings) error { }
event.NewSettings = s
return api.app.OnSettingsBeforeUpdateRequest().Trigger(event, func(e *core.SettingsUpdateEvent) error { return e.App.OnSettingsUpdateRequest().Trigger(event, func(e *core.SettingsUpdateRequestEvent) error {
if err := next(e.NewSettings); err != nil { err := e.App.Save(e.NewSettings)
return NewBadRequestError("An error occurred while submitting the form.", err) if err != nil {
} return e.BadRequestError("An error occurred while saving the new settings.", err)
return api.app.OnSettingsAfterUpdateRequest().Trigger(event, func(e *core.SettingsUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
redactedSettings, err := api.app.Settings().RedactClone()
if err != nil {
return NewBadRequestError("", err)
}
return e.HttpContext.JSON(http.StatusOK, redactedSettings)
})
})
} }
appSettings, err := e.App.Settings().Clone()
if err != nil {
return e.InternalServerError("Failed to clone app settings.", err)
}
return e.JSON(http.StatusOK, appSettings)
}) })
} }
func (api *settingsApi) testS3(c echo.Context) error { func settingsTestS3(e *core.RequestEvent) error {
form := forms.NewTestS3Filesystem(api.app) form := forms.NewTestS3Filesystem(e.App)
// load request // load request
if err := c.Bind(form); err != nil { if err := e.BindBody(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err) return e.BadRequestError("An error occurred while loading the submitted data.", err)
} }
// send // send
if err := form.Submit(); err != nil { if err := form.Submit(); err != nil {
// form error // form error
if fErr, ok := err.(validation.Errors); ok { if fErr, ok := err.(validation.Errors); ok {
return NewBadRequestError("Failed to test the S3 filesystem.", fErr) return e.BadRequestError("Failed to test the S3 filesystem.", fErr)
} }
// mailer error // mailer error
return NewBadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil) return e.BadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
} }
return c.NoContent(http.StatusNoContent) return e.NoContent(http.StatusNoContent)
} }
func (api *settingsApi) testEmail(c echo.Context) error { func settingsTestEmail(e *core.RequestEvent) error {
form := forms.NewTestEmailSend(api.app) form := forms.NewTestEmailSend(e.App)
// load request // load request
if err := c.Bind(form); err != nil { if err := e.BindBody(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err) return e.BadRequestError("An error occurred while loading the submitted data.", err)
} }
// send // send
if err := form.Submit(); err != nil { if err := form.Submit(); err != nil {
// form error // form error
if fErr, ok := err.(validation.Errors); ok { if fErr, ok := err.(validation.Errors); ok {
return NewBadRequestError("Failed to send the test email.", fErr) return e.BadRequestError("Failed to send the test email.", fErr)
} }
// mailer error // mailer error
return NewBadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil) return e.BadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
} }
return c.NoContent(http.StatusNoContent) return e.NoContent(http.StatusNoContent)
} }
func (api *settingsApi) generateAppleClientSecret(c echo.Context) error { func settingsGenerateAppleClientSecret(e *core.RequestEvent) error {
form := forms.NewAppleClientSecretCreate(api.app) form := forms.NewAppleClientSecretCreate(e.App)
// load request // load request
if err := c.Bind(form); err != nil { if err := e.BindBody(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err) return e.BadRequestError("An error occurred while loading the submitted data.", err)
} }
// generate // generate
@ -141,14 +126,14 @@ func (api *settingsApi) generateAppleClientSecret(c echo.Context) error {
if err != nil { if err != nil {
// form error // form error
if fErr, ok := err.(validation.Errors); ok { if fErr, ok := err.(validation.Errors); ok {
return NewBadRequestError("Invalid client secret data.", fErr) return e.BadRequestError("Invalid client secret data.", fErr)
} }
// secret generation error // secret generation error
return NewBadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil) return e.BadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
} }
return c.JSON(http.StatusOK, map[string]any{ return e.JSON(http.StatusOK, map[string]string{
"secret": secret, "secret": secret,
}) })
} }

View File

@ -6,14 +6,11 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
) )
@ -24,26 +21,28 @@ func TestSettingsList(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/settings", URL: "/api/settings",
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/settings", URL: "/api/settings",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin", Name: "authorized as superuser",
Method: http.MethodGet, Method: http.MethodGet,
Url: "/api/settings", URL: "/api/settings",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
@ -52,44 +51,10 @@ func TestSettingsList(t *testing.T) {
`"smtp":{`, `"smtp":{`,
`"s3":{`, `"s3":{`,
`"backups":{`, `"backups":{`,
`"adminAuthToken":{`, `"batch":{`,
`"adminPasswordResetToken":{`,
`"adminFileToken":{`,
`"recordAuthToken":{`,
`"recordPasswordResetToken":{`,
`"recordEmailChangeToken":{`,
`"recordVerificationToken":{`,
`"recordFileToken":{`,
`"emailAuth":{`,
`"googleAuth":{`,
`"facebookAuth":{`,
`"githubAuth":{`,
`"gitlabAuth":{`,
`"twitterAuth":{`,
`"discordAuth":{`,
`"microsoftAuth":{`,
`"spotifyAuth":{`,
`"kakaoAuth":{`,
`"twitchAuth":{`,
`"stravaAuth":{`,
`"giteeAuth":{`,
`"livechatAuth":{`,
`"giteaAuth":{`,
`"oidcAuth":{`,
`"oidc2Auth":{`,
`"oidc3Auth":{`,
`"appleAuth":{`,
`"instagramAuth":{`,
`"vkAuth":{`,
`"yandexAuth":{`,
`"patreonAuth":{`,
`"mailcowAuth":{`,
`"bitbucketAuth":{`,
`"planningcenterAuth":{`,
`"secret":"******"`,
`"clientSecret":"******"`,
}, },
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"*": 0,
"OnSettingsListRequest": 1, "OnSettingsListRequest": 1,
}, },
}, },
@ -103,35 +68,41 @@ func TestSettingsList(t *testing.T) {
func TestSettingsSet(t *testing.T) { func TestSettingsSet(t *testing.T) {
t.Parallel() t.Parallel()
validData := `{"meta":{"appName":"update_test"}}` validData := `{
"meta":{"appName":"update_test"},
"s3":{"secret": "s3_secret"},
"backups":{"s3":{"secret":"backups_s3_secret"}}
}`
scenarios := []tests.ApiScenario{ scenarios := []tests.ApiScenario{
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPatch, Method: http.MethodPatch,
Url: "/api/settings", URL: "/api/settings",
Body: strings.NewReader(validData), Body: strings.NewReader(validData),
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodPatch, Method: http.MethodPatch,
Url: "/api/settings", URL: "/api/settings",
Body: strings.NewReader(validData), Body: strings.NewReader(validData),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin submitting empty data", Name: "authorized as superuser submitting empty data",
Method: http.MethodPatch, Method: http.MethodPatch,
Url: "/api/settings", URL: "/api/settings",
Body: strings.NewReader(``), Body: strings.NewReader(``),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
@ -140,71 +111,46 @@ func TestSettingsSet(t *testing.T) {
`"smtp":{`, `"smtp":{`,
`"s3":{`, `"s3":{`,
`"backups":{`, `"backups":{`,
`"adminAuthToken":{`, `"batch":{`,
`"adminPasswordResetToken":{`,
`"adminFileToken":{`,
`"recordAuthToken":{`,
`"recordPasswordResetToken":{`,
`"recordEmailChangeToken":{`,
`"recordVerificationToken":{`,
`"recordFileToken":{`,
`"emailAuth":{`,
`"googleAuth":{`,
`"facebookAuth":{`,
`"githubAuth":{`,
`"gitlabAuth":{`,
`"discordAuth":{`,
`"microsoftAuth":{`,
`"spotifyAuth":{`,
`"kakaoAuth":{`,
`"twitchAuth":{`,
`"stravaAuth":{`,
`"giteeAuth":{`,
`"livechatAuth":{`,
`"giteaAuth":{`,
`"oidcAuth":{`,
`"oidc2Auth":{`,
`"oidc3Auth":{`,
`"appleAuth":{`,
`"instagramAuth":{`,
`"vkAuth":{`,
`"yandexAuth":{`,
`"patreonAuth":{`,
`"mailcowAuth":{`,
`"bitbucketAuth":{`,
`"planningcenterAuth":{`,
`"secret":"******"`,
`"clientSecret":"******"`,
`"appName":"acme_test"`,
}, },
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1, "*": 0,
"OnModelAfterUpdate": 1, "OnSettingsUpdateRequest": 1,
"OnSettingsBeforeUpdateRequest": 1, "OnModelUpdate": 1,
"OnSettingsAfterUpdateRequest": 1, "OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnSettingsReload": 1,
}, },
}, },
{ {
Name: "authorized as admin submitting invalid data", Name: "authorized as superuser submitting invalid data",
Method: http.MethodPatch, Method: http.MethodPatch,
Url: "/api/settings", URL: "/api/settings",
Body: strings.NewReader(`{"meta":{"appName":""}}`), Body: strings.NewReader(`{"meta":{"appName":""}}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{ ExpectedContent: []string{
`"data":{`, `"data":{`,
`"meta":{"appName":{"code":"validation_required"`, `"meta":{"appName":{"code":"validation_required"`,
}, },
ExpectedEvents: map[string]int{
"*": 0,
"OnModelUpdate": 1,
"OnModelAfterUpdateError": 1,
"OnModelValidate": 1,
"OnSettingsUpdateRequest": 1,
},
}, },
{ {
Name: "authorized as admin submitting valid data", Name: "authorized as superuser submitting valid data",
Method: http.MethodPatch, Method: http.MethodPatch,
Url: "/api/settings", URL: "/api/settings",
Body: strings.NewReader(validData), Body: strings.NewReader(validData),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
@ -213,71 +159,21 @@ func TestSettingsSet(t *testing.T) {
`"smtp":{`, `"smtp":{`,
`"s3":{`, `"s3":{`,
`"backups":{`, `"backups":{`,
`"adminAuthToken":{`, `"batch":{`,
`"adminPasswordResetToken":{`,
`"adminFileToken":{`,
`"recordAuthToken":{`,
`"recordPasswordResetToken":{`,
`"recordEmailChangeToken":{`,
`"recordVerificationToken":{`,
`"recordFileToken":{`,
`"emailAuth":{`,
`"googleAuth":{`,
`"facebookAuth":{`,
`"githubAuth":{`,
`"gitlabAuth":{`,
`"twitterAuth":{`,
`"discordAuth":{`,
`"microsoftAuth":{`,
`"spotifyAuth":{`,
`"kakaoAuth":{`,
`"twitchAuth":{`,
`"stravaAuth":{`,
`"giteeAuth":{`,
`"livechatAuth":{`,
`"giteaAuth":{`,
`"oidcAuth":{`,
`"oidc2Auth":{`,
`"oidc3Auth":{`,
`"appleAuth":{`,
`"instagramAuth":{`,
`"vkAuth":{`,
`"yandexAuth":{`,
`"patreonAuth":{`,
`"mailcowAuth":{`,
`"bitbucketAuth":{`,
`"planningcenterAuth":{`,
`"secret":"******"`,
`"clientSecret":"******"`,
`"appName":"update_test"`, `"appName":"update_test"`,
}, },
NotExpectedContent: []string{
"secret",
"password",
},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1, "*": 0,
"OnModelAfterUpdate": 1, "OnSettingsUpdateRequest": 1,
"OnSettingsBeforeUpdateRequest": 1, "OnModelUpdate": 1,
"OnSettingsAfterUpdateRequest": 1, "OnModelUpdateExecute": 1,
}, "OnModelAfterUpdateSuccess": 1,
}, "OnModelValidate": 1,
{ "OnSettingsReload": 1,
Name: "OnSettingsAfterUpdateRequest error response",
Method: http.MethodPatch,
Url: "/api/settings",
Body: strings.NewReader(validData),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnSettingsAfterUpdateRequest().Add(func(e *core.SettingsUpdateEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnSettingsBeforeUpdateRequest": 1,
"OnSettingsAfterUpdateRequest": 1,
}, },
}, },
} }
@ -294,59 +190,64 @@ func TestSettingsTestS3(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/s3", URL: "/api/settings/test/s3",
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/s3", URL: "/api/settings/test/s3",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (missing body + no s3)", Name: "authorized as superuser (missing body + no s3)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/s3", URL: "/api/settings/test/s3",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{ ExpectedContent: []string{
`"data":{`, `"data":{`,
`"filesystem":{`, `"filesystem":{`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (invalid filesystem)", Name: "authorized as superuser (invalid filesystem)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/s3", URL: "/api/settings/test/s3",
Body: strings.NewReader(`{"filesystem":"invalid"}`), Body: strings.NewReader(`{"filesystem":"invalid"}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{ ExpectedContent: []string{
`"data":{`, `"data":{`,
`"filesystem":{`, `"filesystem":{`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (valid filesystem and no s3)", Name: "authorized as superuser (valid filesystem and no s3)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/s3", URL: "/api/settings/test/s3",
Body: strings.NewReader(`{"filesystem":"storage"}`), Body: strings.NewReader(`{"filesystem":"storage"}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{ ExpectedContent: []string{
`"data":{}`, `"data":{}`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
} }
@ -362,156 +263,199 @@ func TestSettingsTestEmail(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/email", URL: "/api/settings/test/email",
Body: strings.NewReader(`{ Body: strings.NewReader(`{
"template": "verification", "template": "verification",
"email": "test@example.com" "email": "test@example.com"
}`), }`),
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/email", URL: "/api/settings/test/email",
Body: strings.NewReader(`{ Body: strings.NewReader(`{
"template": "verification", "template": "verification",
"email": "test@example.com" "email": "test@example.com"
}`), }`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (invalid body)", Name: "authorized as superuser (invalid body)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/email", URL: "/api/settings/test/email",
Body: strings.NewReader(`{`), Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (empty json)", Name: "authorized as superuser (empty json)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/email", URL: "/api/settings/test/email",
Body: strings.NewReader(`{}`), Body: strings.NewReader(`{}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{ ExpectedContent: []string{
`"email":{"code":"validation_required"`, `"email":{"code":"validation_required"`,
`"template":{"code":"validation_required"`, `"template":{"code":"validation_required"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (verifiation template)", Name: "authorized as superuser (verifiation template)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/email", URL: "/api/settings/test/email",
Body: strings.NewReader(`{ Body: strings.NewReader(`{
"template": "verification", "template": "verification",
"email": "test@example.com" "email": "test@example.com"
}`), }`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend != 1 { if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend) t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
} }
if len(app.TestMailer.LastMessage.To) != 1 { if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To) t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
} }
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" { if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address) t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
} }
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Verify") { if !strings.Contains(app.TestMailer.LastMessage().HTML, "Verify") {
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML) t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedContent: []string{}, ExpectedContent: []string{},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnMailerBeforeRecordVerificationSend": 1, "*": 0,
"OnMailerAfterRecordVerificationSend": 1, "OnMailerSend": 1,
"OnMailerRecordVerificationSend": 1,
}, },
}, },
{ {
Name: "authorized as admin (password reset template)", Name: "authorized as superuser (password reset template)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/email", URL: "/api/settings/test/email",
Body: strings.NewReader(`{ Body: strings.NewReader(`{
"template": "password-reset", "template": "password-reset",
"email": "test@example.com" "email": "test@example.com"
}`), }`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend != 1 { if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend) t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
} }
if len(app.TestMailer.LastMessage.To) != 1 { if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To) t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
} }
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" { if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address) t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
} }
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Reset password") { if !strings.Contains(app.TestMailer.LastMessage().HTML, "Reset password") {
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML) t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedContent: []string{}, ExpectedContent: []string{},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnMailerBeforeRecordResetPasswordSend": 1, "*": 0,
"OnMailerAfterRecordResetPasswordSend": 1, "OnMailerSend": 1,
"OnMailerRecordPasswordResetSend": 1,
}, },
}, },
{ {
Name: "authorized as admin (email change)", Name: "authorized as superuser (email change)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/test/email", URL: "/api/settings/test/email",
Body: strings.NewReader(`{ Body: strings.NewReader(`{
"template": "email-change", "template": "email-change",
"email": "test@example.com" "email": "test@example.com"
}`), }`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) { AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend != 1 { if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend) t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
} }
if len(app.TestMailer.LastMessage.To) != 1 { if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To) t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
} }
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" { if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address) t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
} }
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Confirm new email") { if !strings.Contains(app.TestMailer.LastMessage().HTML, "Confirm new email") {
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML) t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
} }
}, },
ExpectedStatus: 204, ExpectedStatus: 204,
ExpectedContent: []string{}, ExpectedContent: []string{},
ExpectedEvents: map[string]int{ ExpectedEvents: map[string]int{
"OnMailerBeforeRecordChangeEmailSend": 1, "*": 0,
"OnMailerAfterRecordChangeEmailSend": 1, "OnMailerSend": 1,
"OnMailerRecordEmailChangeSend": 1,
},
},
{
Name: "authorized as superuser (otp)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "otp",
"email": "test@example.com"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[otp] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[otp] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[otp] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage().HTML, "one-time password") {
t.Fatalf("[otp] Expected to sent OTP email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordOTPSend": 1,
}, },
}, },
} }
@ -545,38 +489,41 @@ func TestGenerateAppleClientSecret(t *testing.T) {
{ {
Name: "unauthorized", Name: "unauthorized",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret", URL: "/api/settings/apple/generate-client-secret",
ExpectedStatus: 401, ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as auth record", Name: "authorized as regular user",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret", URL: "/api/settings/apple/generate-client-secret",
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
}, },
ExpectedStatus: 401, ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (invalid body)", Name: "authorized as superuser (invalid body)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret", URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{`), Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`}, ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (empty json)", Name: "authorized as superuser (empty json)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret", URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{}`), Body: strings.NewReader(`{}`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{ ExpectedContent: []string{
@ -586,11 +533,12 @@ func TestGenerateAppleClientSecret(t *testing.T) {
`"privateKey":{"code":"validation_required"`, `"privateKey":{"code":"validation_required"`,
`"duration":{"code":"validation_required"`, `"duration":{"code":"validation_required"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (invalid data)", Name: "authorized as superuser (invalid data)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret", URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{ Body: strings.NewReader(`{
"clientId": "", "clientId": "",
"teamId": "123456789", "teamId": "123456789",
@ -598,8 +546,8 @@ func TestGenerateAppleClientSecret(t *testing.T) {
"privateKey": "invalid", "privateKey": "invalid",
"duration": -1 "duration": -1
}`), }`),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 400, ExpectedStatus: 400,
ExpectedContent: []string{ ExpectedContent: []string{
@ -609,11 +557,12 @@ func TestGenerateAppleClientSecret(t *testing.T) {
`"privateKey":{"code":"validation_match_invalid"`, `"privateKey":{"code":"validation_match_invalid"`,
`"duration":{"code":"validation_min_greater_equal_than_required"`, `"duration":{"code":"validation_min_greater_equal_than_required"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
{ {
Name: "authorized as admin (valid data)", Name: "authorized as superuser (valid data)",
Method: http.MethodPost, Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret", URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(fmt.Sprintf(`{ Body: strings.NewReader(fmt.Sprintf(`{
"clientId": "123", "clientId": "123",
"teamId": "1234567890", "teamId": "1234567890",
@ -621,13 +570,14 @@ func TestGenerateAppleClientSecret(t *testing.T) {
"privateKey": %q, "privateKey": %q,
"duration": 1 "duration": 1
}`, privatePem)), }`, privatePem)),
RequestHeaders: map[string]string{ Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
}, },
ExpectedStatus: 200, ExpectedStatus: 200,
ExpectedContent: []string{ ExpectedContent: []string{
`"secret":"`, `"secret":"`,
}, },
ExpectedEvents: map[string]int{"*": 0},
}, },
} }

View File

@ -1,141 +0,0 @@
package cmd
import (
"errors"
"fmt"
"github.com/fatih/color"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models"
"github.com/spf13/cobra"
)
// NewAdminCommand creates and returns new command for managing
// admin accounts (create, update, delete).
func NewAdminCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "admin",
Short: "Manages admin accounts",
}
command.AddCommand(adminCreateCommand(app))
command.AddCommand(adminUpdateCommand(app))
command.AddCommand(adminDeleteCommand(app))
return command
}
func adminCreateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "create",
Example: "admin create test@example.com 1234567890",
Short: "Creates a new admin account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
if len(args[1]) < 8 {
return errors.New("The password must be at least 8 chars long.")
}
admin := &models.Admin{}
admin.Email = args[0]
admin.SetPassword(args[1])
if !app.Dao().HasTable(admin.TableName()) {
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
}
if err := app.Dao().SaveAdmin(admin); err != nil {
return fmt.Errorf("Failed to create new admin account: %v", err)
}
color.Green("Successfully created new admin %s!", admin.Email)
return nil
},
}
return command
}
func adminUpdateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "update",
Example: "admin update test@example.com 1234567890",
Short: "Changes the password of a single admin account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
if len(args[1]) < 8 {
return errors.New("The new password must be at least 8 chars long.")
}
if !app.Dao().HasTable((&models.Admin{}).TableName()) {
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
}
admin, err := app.Dao().FindAdminByEmail(args[0])
if err != nil {
return fmt.Errorf("Admin with email %s doesn't exist.", args[0])
}
admin.SetPassword(args[1])
if err := app.Dao().SaveAdmin(admin); err != nil {
return fmt.Errorf("Failed to change admin %s password: %v", admin.Email, err)
}
color.Green("Successfully changed admin %s password!", admin.Email)
return nil
},
}
return command
}
func adminDeleteCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "delete",
Example: "admin delete test@example.com",
Short: "Deletes an existing admin account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Invalid or missing email address.")
}
if !app.Dao().HasTable((&models.Admin{}).TableName()) {
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
}
admin, err := app.Dao().FindAdminByEmail(args[0])
if err != nil {
color.Yellow("Admin %s is already deleted.", args[0])
return nil
}
if err := app.Dao().DeleteAdmin(admin); err != nil {
return fmt.Errorf("Failed to delete admin %s: %v", admin.Email, err)
}
color.Green("Successfully deleted admin %s!", admin.Email)
return nil
},
}
return command
}

View File

@ -1,221 +0,0 @@
package cmd_test
import (
"testing"
"github.com/pocketbase/pocketbase/cmd"
"github.com/pocketbase/pocketbase/tests"
)
func TestAdminCreateCommand(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"duplicated email",
"test@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test_new@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
command := cmd.NewAdminCommand(app)
command.SetArgs([]string{"create", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
}
if hasErr {
continue
}
// check whether the admin account was actually created
admin, err := app.Dao().FindAdminByEmail(s.email)
if err != nil {
t.Errorf("[%s] Failed to fetch created admin %s: %v", s.name, s.email, err)
} else if !admin.ValidatePassword(s.password) {
t.Errorf("[%s] Expected the admin password to match", s.name)
}
}
}
func TestAdminUpdateCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"nonexisting admin",
"test_missing@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
command := cmd.NewAdminCommand(app)
command.SetArgs([]string{"update", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
}
if hasErr {
continue
}
// check whether the admin password was actually changed
admin, err := app.Dao().FindAdminByEmail(s.email)
if err != nil {
t.Errorf("[%s] Failed to fetch admin %s: %v", s.name, s.email, err)
} else if !admin.ValidatePassword(s.password) {
t.Errorf("[%s] Expected the admin password to match", s.name)
}
}
}
func TestAdminDeleteCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
expectError bool
}{
{
"empty email",
"",
true,
},
{
"invalid email",
"invalid",
true,
},
{
"nonexisting admin",
"test_missing@example.com",
false,
},
{
"existing admin",
"test@example.com",
false,
},
}
for _, s := range scenarios {
command := cmd.NewAdminCommand(app)
command.SetArgs([]string{"delete", s.email})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
}
if hasErr {
continue
}
// check whether the admin account was actually deleted
if _, err := app.Dao().FindAdminByEmail(s.email); err == nil {
t.Errorf("[%s] Expected the admin account to be deleted", s.name)
}
}
}

View File

@ -15,6 +15,7 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
var allowedOrigins []string var allowedOrigins []string
var httpAddr string var httpAddr string
var httpsAddr string var httpsAddr string
var dashboardPath string
command := &cobra.Command{ command := &cobra.Command{
Use: "serve [domain(s)]", Use: "serve [domain(s)]",
@ -36,9 +37,10 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
} }
} }
_, err := apis.Serve(app, apis.ServeConfig{ err := apis.Serve(app, apis.ServeConfig{
HttpAddr: httpAddr, HttpAddr: httpAddr,
HttpsAddr: httpsAddr, HttpsAddr: httpsAddr,
DashboardPath: dashboardPath,
ShowStartBanner: showStartBanner, ShowStartBanner: showStartBanner,
AllowedOrigins: allowedOrigins, AllowedOrigins: allowedOrigins,
CertificateDomains: args, CertificateDomains: args,
@ -73,5 +75,12 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
"TCP address to listen for the HTTPS server\n(if domain args are specified - default to 0.0.0.0:443, otherwise - default to empty string, aka. no TLS)\nThe incoming HTTP traffic also will be auto redirected to the HTTPS version", "TCP address to listen for the HTTPS server\n(if domain args are specified - default to 0.0.0.0:443, otherwise - default to empty string, aka. no TLS)\nThe incoming HTTP traffic also will be auto redirected to the HTTPS version",
) )
command.PersistentFlags().StringVar(
&dashboardPath,
"dashboard",
"/_/{path...}",
"The route path to the superusers dashboard; must include the '{path...}' wildcard parameter",
)
return command return command
} }

166
cmd/superuser.go Normal file
View File

@ -0,0 +1,166 @@
package cmd
import (
"errors"
"fmt"
"github.com/fatih/color"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/spf13/cobra"
)
// NewSuperuserCommand creates and returns new command for managing
// superuser accounts (create, update, delete).
func NewSuperuserCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "superuser",
Short: "Manages superuser accounts",
}
command.AddCommand(superuserUpsertCommand(app))
command.AddCommand(superuserCreateCommand(app))
command.AddCommand(superuserUpdateCommand(app))
command.AddCommand(superuserDeleteCommand(app))
return command
}
func superuserUpsertCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "upsert",
Example: "superuser upsert test@example.com 1234567890",
Short: "Creates, or updates if email exists, a single superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
return fmt.Errorf("Failed to fetch %q collection: %w.", core.CollectionNameSuperusers, err)
}
superuser, err := app.FindAuthRecordByEmail(superusersCol, args[0])
if err != nil {
superuser = core.NewRecord(superusersCol)
}
superuser.SetEmail(args[0])
superuser.SetPassword(args[1])
if err := app.Save(superuser); err != nil {
return fmt.Errorf("Failed to upsert superuser account: %w.", err)
}
color.Green("Successfully saved superuser %q!", superuser.Email())
return nil
},
}
return command
}
func superuserCreateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "create",
Example: "superuser create test@example.com 1234567890",
Short: "Creates a new superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
return fmt.Errorf("Failed to fetch %q collection: %w.", core.CollectionNameSuperusers, err)
}
superuser := core.NewRecord(superusersCol)
superuser.SetEmail(args[0])
superuser.SetPassword(args[1])
if err := app.Save(superuser); err != nil {
return fmt.Errorf("Failed to create new superuser account: %w.", err)
}
color.Green("Successfully created new superuser %q!", superuser.Email())
return nil
},
}
return command
}
func superuserUpdateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "update",
Example: "superuser update test@example.com 1234567890",
Short: "Changes the password of a single superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
if err != nil {
return fmt.Errorf("Superuser with email %q doesn't exist.", args[0])
}
superuser.SetPassword(args[1])
if err := app.Save(superuser); err != nil {
return fmt.Errorf("Failed to change superuser %q password: %w.", superuser.Email(), err)
}
color.Green("Successfully changed superuser %q password!", superuser.Email())
return nil
},
}
return command
}
func superuserDeleteCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "delete",
Example: "superuser delete test@example.com",
Short: "Deletes an existing superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Invalid or missing email address.")
}
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
if err != nil {
color.Yellow("Superuser %q is missing or already deleted.", args[0])
return nil
}
if err := app.Delete(superuser); err != nil {
return fmt.Errorf("Failed to delete superuser %q: %w.", superuser.Email(), err)
}
color.Green("Successfully deleted superuser %q!", superuser.Email())
return nil
},
}
return command
}

310
cmd/superuser_test.go Normal file
View File

@ -0,0 +1,310 @@
package cmd_test
import (
"testing"
"github.com/pocketbase/pocketbase/cmd"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestSuperuserUpsertCommand(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"existing user",
"test@example.com",
"1234567890!",
false,
},
{
"new user",
"test_new@example.com",
"1234567890!",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"upsert", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
// check whether the superuser account was actually upserted
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
if err != nil {
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
} else if !superuser.ValidatePassword(s.password) {
t.Fatal("Expected the superuser password to match")
}
})
}
}
func TestSuperuserCreateCommand(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"duplicated email",
"test@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test_new@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"create", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
// check whether the superuser account was actually created
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
if err != nil {
t.Fatalf("Failed to fetch created superuser %s: %v", s.email, err)
} else if !superuser.ValidatePassword(s.password) {
t.Fatal("Expected the superuser password to match")
}
})
}
}
func TestSuperuserUpdateCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"nonexisting superuser",
"test_missing@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"update", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
// check whether the superuser password was actually changed
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
if err != nil {
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
} else if !superuser.ValidatePassword(s.password) {
t.Fatal("Expected the superuser password to match")
}
})
}
}
func TestSuperuserDeleteCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
expectError bool
}{
{
"empty email",
"",
true,
},
{
"invalid email",
"invalid",
true,
},
{
"nonexisting superuser",
"test_missing@example.com",
false,
},
{
"existing superuser",
"test@example.com",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"delete", s.email})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if _, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email); err == nil {
t.Fatal("Expected the superuser account to be deleted")
}
})
}
}

File diff suppressed because it is too large Load Diff

239
core/auth_origin_model.go Normal file
View File

@ -0,0 +1,239 @@
package core
import (
"context"
"errors"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/types"
)
const CollectionNameAuthOrigins = "_authOrigins"
var (
_ Model = (*AuthOrigin)(nil)
_ PreValidator = (*AuthOrigin)(nil)
_ RecordProxy = (*AuthOrigin)(nil)
)
// AuthOrigin defines a Record proxy for working with the authOrigins collection.
type AuthOrigin struct {
*Record
}
// NewAuthOrigin instantiates and returns a new blank *AuthOrigin model.
//
// Example usage:
//
// origin := core.NewOrigin(app)
// origin.SetRecordRef(user.Id)
// origin.SetCollectionRef(user.Collection().Id)
// origin.SetFingerprint("...")
// app.Save(origin)
func NewAuthOrigin(app App) *AuthOrigin {
m := &AuthOrigin{}
c, err := app.FindCachedCollectionByNameOrId(CollectionNameAuthOrigins)
if err != nil {
// this is just to make tests easier since authOrigins is a system collection and it is expected to be always accessible
// (note: the loaded record is further checked on AuthOrigin.PreValidate())
c = NewBaseCollection("@___invalid___")
}
m.Record = NewRecord(c)
return m
}
// PreValidate implements the [PreValidator] interface and checks
// whether the proxy is properly loaded.
func (m *AuthOrigin) PreValidate(ctx context.Context, app App) error {
if m.Record == nil || m.Record.Collection().Name != CollectionNameAuthOrigins {
return errors.New("missing or invalid AuthOrigin ProxyRecord")
}
return nil
}
// ProxyRecord returns the proxied Record model.
func (m *AuthOrigin) ProxyRecord() *Record {
return m.Record
}
// SetProxyRecord loads the specified record model into the current proxy.
func (m *AuthOrigin) SetProxyRecord(record *Record) {
m.Record = record
}
// CollectionRef returns the "collectionRef" field value.
func (m *AuthOrigin) CollectionRef() string {
return m.GetString("collectionRef")
}
// SetCollectionRef updates the "collectionRef" record field value.
func (m *AuthOrigin) SetCollectionRef(collectionId string) {
m.Set("collectionRef", collectionId)
}
// RecordRef returns the "recordRef" record field value.
func (m *AuthOrigin) RecordRef() string {
return m.GetString("recordRef")
}
// SetRecordRef updates the "recordRef" record field value.
func (m *AuthOrigin) SetRecordRef(recordId string) {
m.Set("recordRef", recordId)
}
// Fingerprint returns the "fingerprint" record field value.
func (m *AuthOrigin) Fingerprint() string {
return m.GetString("fingerprint")
}
// SetFingerprint updates the "fingerprint" record field value.
func (m *AuthOrigin) SetFingerprint(fingerprint string) {
m.Set("fingerprint", fingerprint)
}
// Created returns the "created" record field value.
func (m *AuthOrigin) Created() types.DateTime {
return m.GetDateTime("created")
}
// Updated returns the "updated" record field value.
func (m *AuthOrigin) Updated() types.DateTime {
return m.GetDateTime("updated")
}
func (app *BaseApp) registerAuthOriginHooks() {
recordRefHooks[*AuthOrigin](app, CollectionNameAuthOrigins, CollectionTypeAuth)
// delete existing auth origins on password change
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
err := e.Next()
if err != nil || !e.Record.Collection().IsAuth() {
return err
}
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
new := e.Record.GetString(FieldNamePassword + ":hash")
if old != new {
err = e.App.DeleteAllAuthOriginsByRecord(e.Record)
if err != nil {
e.App.Logger().Warn(
"Failed to delete all previous auth origin fingerprints",
"error", err,
"recordId", e.Record.Id,
"collectionId", e.Record.Collection().Id,
)
}
}
return nil
},
Priority: 99,
})
}
// -------------------------------------------------------------------
// recordRefHooks registers common hooks that are usually used with record proxies
// that have polymorphic record relations (aka. "collectionRef" and "recordRef" fields).
func recordRefHooks[T RecordProxy](app App, collectionName string, optCollectionTypes ...string) {
app.OnRecordValidate(collectionName).Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
collectionId := e.Record.GetString("collectionRef")
err := validation.Validate(collectionId, validation.Required, validation.By(validateCollectionId(e.App, optCollectionTypes...)))
if err != nil {
return validation.Errors{"collectionRef": err}
}
recordId := e.Record.GetString("recordRef")
err = validation.Validate(recordId, validation.Required, validation.By(validateRecordId(e.App, collectionId)))
if err != nil {
return validation.Errors{"recordRef": err}
}
return e.Next()
},
Priority: 99,
})
// delete on collection ref delete
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
Func: func(e *CollectionEvent) error {
if e.Collection.Name == collectionName || (len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Collection.Type)) {
return e.Next()
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
if err := e.Next(); err != nil {
return err
}
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{"collectionRef": e.Collection.Id})
if err != nil {
return err
}
for _, mfa := range rels {
if err := txApp.Delete(mfa); err != nil {
return err
}
}
return nil
})
e.App = originalApp
return txErr
},
Priority: 99,
})
// delete on record ref delete
app.OnRecordDeleteExecute().Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
if e.Record.Collection().Name == collectionName ||
(len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Record.Collection().Type)) {
return e.Next()
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
if err := e.Next(); err != nil {
return err
}
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{
"collectionRef": e.Record.Collection().Id,
"recordRef": e.Record.Id,
})
if err != nil {
return err
}
for _, rel := range rels {
if err := txApp.Delete(rel); err != nil {
return err
}
}
return nil
})
e.App = originalApp
return txErr
},
Priority: 99,
})
}

View File

@ -0,0 +1,332 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNewAuthOrigin(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if origin.Collection().Name != core.CollectionNameAuthOrigins {
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameAuthOrigins, origin.Collection().Name)
}
}
func TestAuthOriginProxyRecord(t *testing.T) {
t.Parallel()
record := core.NewRecord(core.NewBaseCollection("test"))
record.Id = "test_id"
origin := core.AuthOrigin{}
origin.SetProxyRecord(record)
if origin.ProxyRecord() == nil || origin.ProxyRecord().Id != record.Id {
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, origin.ProxyRecord())
}
}
func TestAuthOriginRecordRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetRecordRef(testValue)
if v := origin.RecordRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("recordRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginCollectionRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetCollectionRef(testValue)
if v := origin.CollectionRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("collectionRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginFingerprint(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetFingerprint(testValue)
if v := origin.Fingerprint(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("fingerprint"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginCreated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if v := origin.Created().String(); v != "" {
t.Fatalf("Expected empty created, got %q", v)
}
now := types.NowDateTime()
origin.SetRaw("created", now)
if v := origin.Created().String(); v != now.String() {
t.Fatalf("Expected %q created, got %q", now.String(), v)
}
}
func TestAuthOriginUpdated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if v := origin.Updated().String(); v != "" {
t.Fatalf("Expected empty updated, got %q", v)
}
now := types.NowDateTime()
origin.SetRaw("updated", now)
if v := origin.Updated().String(); v != now.String() {
t.Fatalf("Expected %q updated, got %q", now.String(), v)
}
}
func TestAuthOriginPreValidate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
originsCol, err := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
if err != nil {
t.Fatal(err)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
t.Run("no proxy record", func(t *testing.T) {
origin := &core.AuthOrigin{}
if err := app.Validate(origin); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("non-AuthOrigin collection", func(t *testing.T) {
origin := &core.AuthOrigin{}
origin.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
origin.SetRecordRef(user.Id)
origin.SetCollectionRef(user.Collection().Id)
origin.SetFingerprint("abc")
if err := app.Validate(origin); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("AuthOrigin collection", func(t *testing.T) {
origin := &core.AuthOrigin{}
origin.SetProxyRecord(core.NewRecord(originsCol))
origin.SetRecordRef(user.Id)
origin.SetCollectionRef(user.Collection().Id)
origin.SetFingerprint("abc")
if err := app.Validate(origin); err != nil {
t.Fatalf("Expected nil validation error, got %v", err)
}
})
}
func TestAuthOriginValidateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
origin func() *core.AuthOrigin
expectErrors []string
}{
{
"empty",
func() *core.AuthOrigin {
return core.NewAuthOrigin(app)
},
[]string{"collectionRef", "recordRef", "fingerprint"},
},
{
"non-auth collection",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(demo1.Collection().Id)
origin.SetRecordRef(demo1.Id)
origin.SetFingerprint("abc")
return origin
},
[]string{"collectionRef"},
},
{
"missing record id",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(user.Collection().Id)
origin.SetRecordRef("missing")
origin.SetFingerprint("abc")
return origin
},
[]string{"recordRef"},
},
{
"valid ref",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(user.Collection().Id)
origin.SetRecordRef(user.Id)
origin.SetFingerprint("abc")
return origin
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := app.Validate(s.origin())
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestAuthOriginPasswordChangeDeletion(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// no auth origin associated with it
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{user1, nil},
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{client1, []string{"9r2j0m74260ur8i"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
deletedIds := []string{}
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
s.record.SetPassword("new_password")
err := app.Save(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}

101
core/auth_origin_query.go Normal file
View File

@ -0,0 +1,101 @@
package core
import (
"errors"
"github.com/pocketbase/dbx"
)
// FindAllAuthOriginsByRecord returns all AuthOrigin models linked to the provided auth record (in DESC order).
func (app *BaseApp) FindAllAuthOriginsByRecord(authRecord *Record) ([]*AuthOrigin, error) {
result := []*AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAllAuthOriginsByCollection returns all AuthOrigin models linked to the provided collection (in DESC order).
func (app *BaseApp) FindAllAuthOriginsByCollection(collection *Collection) ([]*AuthOrigin, error) {
result := []*AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAuthOriginById returns a single AuthOrigin model by its id.
func (app *BaseApp) FindAuthOriginById(id string) (*AuthOrigin, error) {
result := &AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{"id": id}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAuthOriginByRecordAndFingerprint returns a single AuthOrigin model
// by its authRecord relation and fingerprint.
func (app *BaseApp) FindAuthOriginByRecordAndFingerprint(authRecord *Record, fingerprint string) (*AuthOrigin, error) {
result := &AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
"fingerprint": fingerprint,
}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// DeleteAllAuthOriginsByRecord deletes all AuthOrigin models associated with the provided record.
//
// Returns a combined error with the failed deletes.
func (app *BaseApp) DeleteAllAuthOriginsByRecord(authRecord *Record) error {
models, err := app.FindAllAuthOriginsByRecord(authRecord)
if err != nil {
return err
}
var errs []error
for _, m := range models {
if err := app.Delete(m); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}

View File

@ -0,0 +1,268 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestFindAllAuthOriginsByRecord(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
expected []string
}{
{demo1, nil},
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{superuser4, nil},
{client1, []string{"9r2j0m74260ur8i"}},
}
for _, s := range scenarios {
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
result, err := app.FindAllAuthOriginsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAllAuthOriginsByCollection(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
clients, err := app.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
collection *core.Collection
expected []string
}{
{demo1, nil},
{superusers, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib", "5f29jy38bf5zm3f"}},
{clients, []string{"9r2j0m74260ur8i"}},
}
for _, s := range scenarios {
t.Run(s.collection.Name, func(t *testing.T) {
result, err := app.FindAllAuthOriginsByCollection(s.collection)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAuthOriginById(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
id string
expectError bool
}{
{"", true},
{"84nmscqy84lsi1t", true}, // non-origin id
{"9r2j0m74260ur8i", false},
}
for _, s := range scenarios {
t.Run(s.id, func(t *testing.T) {
result, err := app.FindAuthOriginById(s.id)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Id != s.id {
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
}
})
}
}
func TestFindAuthOriginByRecordAndFingerprint(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
fingerprint string
expectError bool
}{
{demo1, "6afbfe481c31c08c55a746cccb88ece0", true},
{superuser2, "", true},
{superuser2, "abc", true},
{superuser2, "22bbbcbed36e25321f384ccf99f60057", false}, // fingerprint from different origin
{superuser2, "6afbfe481c31c08c55a746cccb88ece0", false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Id, s.fingerprint), func(t *testing.T) {
result, err := app.FindAuthOriginByRecordAndFingerprint(s.record, s.fingerprint)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Fingerprint() != s.fingerprint {
t.Fatalf("Expected origin with fingerprint %q, got %q", s.fingerprint, result.Fingerprint())
}
if result.RecordRef() != s.record.Id || result.CollectionRef() != s.record.Collection().Id {
t.Fatalf("Expected record %q (%q), got %q (%q)", s.record.Id, s.record.Collection().Id, result.RecordRef(), result.CollectionRef())
}
})
}
}
func TestDeleteAllAuthOriginsByRecord(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{demo1, nil}, // non-auth record
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{superuser4, nil},
{client1, []string{"9r2j0m74260ur8i"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
deletedIds := []string{}
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
err := app.DeleteAllAuthOriginsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -12,20 +12,16 @@ import (
"sort" "sort"
"time" "time"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/archive" "github.com/pocketbase/pocketbase/tools/archive"
"github.com/pocketbase/pocketbase/tools/cron"
"github.com/pocketbase/pocketbase/tools/filesystem" "github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/inflector" "github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/osutils" "github.com/pocketbase/pocketbase/tools/osutils"
"github.com/pocketbase/pocketbase/tools/security" "github.com/pocketbase/pocketbase/tools/security"
) )
// Deprecated: Replaced with StoreKeyActiveBackup. const (
const CacheKeyActiveBackup string = "@activeBackup" StoreKeyActiveBackup = "@activeBackup"
)
const StoreKeyActiveBackup string = "@activeBackup"
// CreateBackup creates a new backup of the current app pb_data directory. // CreateBackup creates a new backup of the current app pb_data directory.
// //
@ -50,61 +46,67 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
return errors.New("try again later - another backup/restore operation has already been started") return errors.New("try again later - another backup/restore operation has already been started")
} }
if name == "" {
name = app.generateBackupName("pb_backup_")
}
app.Store().Set(StoreKeyActiveBackup, name) app.Store().Set(StoreKeyActiveBackup, name)
defer app.Store().Remove(StoreKeyActiveBackup) defer app.Store().Remove(StoreKeyActiveBackup)
// root dir entries to exclude from the backup generation event := new(BackupEvent)
exclude := []string{LocalBackupsDirName, LocalTempDirName} event.App = app
event.Context = ctx
event.Name = name
// default root dir entries to exclude from the backup generation
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
// make sure that the special temp directory exists return app.OnBackupCreate().Trigger(event, func(e *BackupEvent) error {
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors // generate a default name if missing
localTempDir := filepath.Join(app.DataDir(), LocalTempDirName) if e.Name == "" {
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil { e.Name = generateBackupName(e.App, "pb_backup_")
return fmt.Errorf("failed to create a temp dir: %w", err) }
}
// Archive pb_data in a temp directory, exluding the "backups" and the temp dirs. // make sure that the special temp directory exists
// // note: it needs to be inside the current pb_data to avoid "cross-device link" errors
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection). localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
// --- if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(4)) return fmt.Errorf("failed to create a temp dir: %w", err)
createErr := app.Dao().RunInTransaction(func(dataTXDao *daos.Dao) error { }
return app.LogsDao().RunInTransaction(func(logsTXDao *daos.Dao) error {
// @todo consider experimenting with temp switching the readonly pragma after the db interface change // archive pb_data in a temp directory, exluding the "backups" and the temp dirs
return archive.Create(app.DataDir(), tempPath, exclude...) //
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection).
// ---
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(6))
createErr := e.App.RunInTransaction(func(txApp App) error {
return txApp.AuxRunInTransaction(func(txApp App) error {
return archive.Create(txApp.DataDir(), tempPath, e.Exclude...)
})
}) })
if createErr != nil {
return createErr
}
defer os.Remove(tempPath)
// persist the backup in the backups filesystem
// ---
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(e.Context)
file, err := filesystem.NewFileFromPath(tempPath)
if err != nil {
return err
}
file.OriginalName = e.Name
file.Name = file.OriginalName
if err := fsys.UploadFile(file, file.Name); err != nil {
return err
}
return nil
}) })
if createErr != nil {
return createErr
}
defer os.Remove(tempPath)
// Persist the backup in the backups filesystem.
// ---
fsys, err := app.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(ctx)
file, err := filesystem.NewFileFromPath(tempPath)
if err != nil {
return err
}
file.OriginalName = name
file.Name = file.OriginalName
if err := fsys.UploadFile(file, file.Name); err != nil {
return err
}
return nil
} }
// RestoreBackup restores the backup with the specified name and restarts // RestoreBackup restores the backup with the specified name and restarts
@ -136,10 +138,6 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
// If a failure occure during the restore process the dir changes are reverted. // If a failure occure during the restore process the dir changes are reverted.
// If for whatever reason the revert is not possible, it panics. // If for whatever reason the revert is not possible, it panics.
func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error { func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
if runtime.GOOS == "windows" {
return errors.New("restore is not supported on windows")
}
if app.Store().Has(StoreKeyActiveBackup) { if app.Store().Has(StoreKeyActiveBackup) {
return errors.New("try again later - another backup/restore operation has already been started") return errors.New("try again later - another backup/restore operation has already been started")
} }
@ -147,131 +145,129 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
app.Store().Set(StoreKeyActiveBackup, name) app.Store().Set(StoreKeyActiveBackup, name)
defer app.Store().Remove(StoreKeyActiveBackup) defer app.Store().Remove(StoreKeyActiveBackup)
fsys, err := app.NewBackupsFilesystem() event := new(BackupEvent)
if err != nil { event.App = app
return err event.Context = ctx
} event.Name = name
defer fsys.Close() // default root dir entries to exclude from the backup restore
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
fsys.SetContext(ctx) return app.OnBackupRestore().Trigger(event, func(e *BackupEvent) error {
if runtime.GOOS == "windows" {
// fetch the backup file in a temp location return errors.New("restore is not supported on Windows")
br, err := fsys.GetFile(name)
if err != nil {
return err
}
defer br.Close()
// make sure that the special temp directory exists
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
localTempDir := filepath.Join(app.DataDir(), LocalTempDirName)
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create a temp dir: %w", err)
}
// create a temp zip file from the blob.Reader and try to extract it
tempZip, err := os.CreateTemp(localTempDir, "pb_restore_zip")
if err != nil {
return err
}
defer os.Remove(tempZip.Name())
if _, err := io.Copy(tempZip, br); err != nil {
return err
}
extractedDataDir := filepath.Join(localTempDir, "pb_restore_"+security.PseudorandomString(4))
defer os.RemoveAll(extractedDataDir)
if err := archive.Extract(tempZip.Name(), extractedDataDir); err != nil {
return err
}
// ensure that a database file exists
extractedDB := filepath.Join(extractedDataDir, "data.db")
if _, err := os.Stat(extractedDB); err != nil {
return fmt.Errorf("data.db file is missing or invalid: %w", err)
}
// remove the extracted zip file since we no longer need it
// (this is in case the app restarts and the defer calls are not called)
if err := os.Remove(tempZip.Name()); err != nil {
app.Logger().Debug(
"[RestoreBackup] Failed to remove the temp zip backup file",
slog.String("file", tempZip.Name()),
slog.String("error", err.Error()),
)
}
// root dir entries to exclude from the backup restore
exclude := []string{LocalBackupsDirName, LocalTempDirName}
// move the current pb_data content to a special temp location
// that will hold the old data between dirs replace
// (the temp dir will be automatically removed on the next app start)
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(4))
if err := osutils.MoveDirContent(app.DataDir(), oldTempDataDir, exclude...); err != nil {
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
}
// move the extracted archive content to the app's pb_data
if err := osutils.MoveDirContent(extractedDataDir, app.DataDir(), exclude...); err != nil {
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
}
revertDataDirChanges := func() error {
if err := osutils.MoveDirContent(app.DataDir(), extractedDataDir, exclude...); err != nil {
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
} }
if err := osutils.MoveDirContent(oldTempDataDir, app.DataDir(), exclude...); err != nil { fsys, err := e.App.NewBackupsFilesystem()
return fmt.Errorf("failed to revert old pb_data dir change: %w", err) if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(e.Context)
// fetch the backup file in a temp location
br, err := fsys.GetFile(name)
if err != nil {
return err
}
defer br.Close()
// make sure that the special temp directory exists
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create a temp dir: %w", err)
} }
return nil // create a temp zip file from the blob.Reader and try to extract it
} tempZip, err := os.CreateTemp(localTempDir, "pb_restore_zip")
if err != nil {
return err
}
defer os.Remove(tempZip.Name())
// restart the app if _, err := io.Copy(tempZip, br); err != nil {
if err := app.Restart(); err != nil { return err
if revertErr := revertDataDirChanges(); revertErr != nil {
panic(revertErr)
} }
return fmt.Errorf("failed to restart the app process: %w", err) extractedDataDir := filepath.Join(localTempDir, "pb_restore_"+security.PseudorandomString(4))
} defer os.RemoveAll(extractedDataDir)
if err := archive.Extract(tempZip.Name(), extractedDataDir); err != nil {
return err
}
return nil // ensure that a database file exists
} extractedDB := filepath.Join(extractedDataDir, "data.db")
if _, err := os.Stat(extractedDB); err != nil {
return fmt.Errorf("data.db file is missing or invalid: %w", err)
}
// initAutobackupHooks registers the autobackup app serve hooks. // remove the extracted zip file since we no longer need it
func (app *BaseApp) initAutobackupHooks() error { // (this is in case the app restarts and the defer calls are not called)
c := cron.New() if err := os.Remove(tempZip.Name()); err != nil {
isServe := false e.App.Logger().Debug(
"[RestoreBackup] Failed to remove the temp zip backup file",
loadJob := func() { slog.String("file", tempZip.Name()),
c.Stop()
// make sure that app.Settings() is always up to date
//
// @todo remove with the refactoring as core.App and daos.Dao will be one.
if err := app.RefreshSettings(); err != nil {
app.Logger().Debug(
"[Backup cron] Failed to get the latest app settings",
slog.String("error", err.Error()), slog.String("error", err.Error()),
) )
} }
// move the current pb_data content to a special temp location
// that will hold the old data between dirs replace
// (the temp dir will be automatically removed on the next app start)
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(4))
if err := osutils.MoveDirContent(e.App.DataDir(), oldTempDataDir, e.Exclude...); err != nil {
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
}
// move the extracted archive content to the app's pb_data
if err := osutils.MoveDirContent(extractedDataDir, e.App.DataDir(), e.Exclude...); err != nil {
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
}
revertDataDirChanges := func() error {
if err := osutils.MoveDirContent(e.App.DataDir(), extractedDataDir, e.Exclude...); err != nil {
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
}
if err := osutils.MoveDirContent(oldTempDataDir, e.App.DataDir(), e.Exclude...); err != nil {
return fmt.Errorf("failed to revert old pb_data dir change: %w", err)
}
return nil
}
// restart the app
if err := e.App.Restart(); err != nil {
if revertErr := revertDataDirChanges(); revertErr != nil {
panic(revertErr)
}
return fmt.Errorf("failed to restart the app process: %w", err)
}
return nil
})
}
// registerAutobackupHooks registers the autobackup app serve hooks.
func (app *BaseApp) registerAutobackupHooks() {
const jobId = "__auto_pb_backup__"
loadJob := func() {
rawSchedule := app.Settings().Backups.Cron rawSchedule := app.Settings().Backups.Cron
if rawSchedule == "" || !isServe || !app.IsBootstrapped() { if rawSchedule == "" {
app.Cron().Remove(jobId)
return return
} }
c.Add("@autobackup", rawSchedule, func() { app.Cron().Add(jobId, rawSchedule, func() {
const autoPrefix = "@auto_pb_backup_" const autoPrefix = "@auto_pb_backup_"
name := app.generateBackupName(autoPrefix) name := generateBackupName(app, autoPrefix)
if err := app.CreateBackup(context.Background(), name); err != nil { if err := app.CreateBackup(context.Background(), name); err != nil {
app.Logger().Debug( app.Logger().Error(
"[Backup cron] Failed to create backup", "[Backup cron] Failed to create backup",
slog.String("name", name), slog.String("name", name),
slog.String("error", err.Error()), slog.String("error", err.Error()),
@ -286,7 +282,7 @@ func (app *BaseApp) initAutobackupHooks() error {
fsys, err := app.NewBackupsFilesystem() fsys, err := app.NewBackupsFilesystem()
if err != nil { if err != nil {
app.Logger().Debug( app.Logger().Error(
"[Backup cron] Failed to initialize the backup filesystem", "[Backup cron] Failed to initialize the backup filesystem",
slog.String("error", err.Error()), slog.String("error", err.Error()),
) )
@ -296,7 +292,7 @@ func (app *BaseApp) initAutobackupHooks() error {
files, err := fsys.List(autoPrefix) files, err := fsys.List(autoPrefix)
if err != nil { if err != nil {
app.Logger().Debug( app.Logger().Error(
"[Backup cron] Failed to list autogenerated backups", "[Backup cron] Failed to list autogenerated backups",
slog.String("error", err.Error()), slog.String("error", err.Error()),
) )
@ -317,7 +313,7 @@ func (app *BaseApp) initAutobackupHooks() error {
for _, f := range toRemove { for _, f := range toRemove {
if err := fsys.Delete(f.Key); err != nil { if err := fsys.Delete(f.Key); err != nil {
app.Logger().Debug( app.Logger().Error(
"[Backup cron] Failed to remove old autogenerated backup", "[Backup cron] Failed to remove old autogenerated backup",
slog.String("key", f.Key), slog.String("key", f.Key),
slog.String("error", err.Error()), slog.String("error", err.Error()),
@ -325,29 +321,11 @@ func (app *BaseApp) initAutobackupHooks() error {
} }
} }
}) })
// restart the ticker
c.Start()
} }
// load on app serve app.OnBootstrap().BindFunc(func(e *BootstrapEvent) error {
app.OnBeforeServe().Add(func(e *ServeEvent) error { if err := e.Next(); err != nil {
isServe = true return err
loadJob()
return nil
})
// stop the ticker on app termination
app.OnTerminate().Add(func(e *TerminateEvent) error {
c.Stop()
return nil
})
// reload on app settings change
app.OnModelAfterUpdate((&models.Param{}).TableName()).Add(func(e *ModelEvent) error {
p := e.Model.(*models.Param)
if p == nil || p.Key != models.ParamAppSettings {
return nil
} }
loadJob() loadJob()
@ -355,10 +333,18 @@ func (app *BaseApp) initAutobackupHooks() error {
return nil return nil
}) })
return nil app.OnSettingsReload().BindFunc(func(e *SettingsReloadEvent) error {
if err := e.Next(); err != nil {
return err
}
loadJob()
return nil
})
} }
func (app *BaseApp) generateBackupName(prefix string) string { func generateBackupName(app App, prefix string) string {
appName := inflector.Snakecase(app.Settings().Meta.AppName) appName := inflector.Snakecase(app.Settings().Meta.AppName)
if len(appName) > 50 { if len(appName) > 50 {
appName = appName[:50] appName = appName[:50]

View File

@ -128,9 +128,9 @@ func verifyBackupContent(app core.App, path string) error {
"data.db", "data.db",
"data.db-shm", "data.db-shm",
"data.db-wal", "data.db-wal",
"logs.db", "aux.db",
"logs.db-shm", "aux.db-shm",
"logs.db-wal", "aux.db-wal",
".gitignore", ".gitignore",
} }

View File

@ -1,63 +0,0 @@
package core_test
import (
"testing"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestBaseAppRefreshSettings(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
// cleanup all stored settings
if _, err := app.DB().NewQuery("DELETE from _params;").Execute(); err != nil {
t.Fatalf("Failed to delete all test settings: %v", err)
}
// check if the new settings are saved in the db
app.ResetEventCalls()
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh the settings after delete: %v", err)
}
testEventCalls(t, app, map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
})
param, err := app.Dao().FindParamByKey(models.ParamAppSettings)
if err != nil {
t.Fatalf("Expected new settings to be persisted, got %v", err)
}
// change the db entry and refresh the app settings (ensure that there was no db update)
param.Value = types.JsonRaw([]byte(`{"example": 123}`))
if err := app.Dao().SaveParam(param.Key, param.Value); err != nil {
t.Fatalf("Failed to update the test settings: %v", err)
}
app.ResetEventCalls()
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh the app settings: %v", err)
}
testEventCalls(t, app, nil)
// try to refresh again without doing any changes
app.ResetEventCalls()
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh the app settings without change: %v", err)
}
testEventCalls(t, app, nil)
}
func testEventCalls(t *testing.T, app *tests.TestApp, events map[string]int) {
if len(events) != len(app.EventCalls) {
t.Fatalf("Expected events doesn't match: \n%v, \ngot \n%v", events, app.EventCalls)
}
for name, total := range events {
if v, ok := app.EventCalls[name]; !ok || v != total {
t.Fatalf("Expected events doesn't exist or match: \n%v, \ngot \n%v", events, app.EventCalls)
}
}
}

View File

@ -1,59 +1,56 @@
package core package core_test
import ( import (
"context" "context"
"database/sql"
"fmt"
"log/slog" "log/slog"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
"github.com/pocketbase/dbx" _ "unsafe"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/migrations" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/migrations/logs" "github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/logger" "github.com/pocketbase/pocketbase/tools/logger"
"github.com/pocketbase/pocketbase/tools/mailer" "github.com/pocketbase/pocketbase/tools/mailer"
"github.com/pocketbase/pocketbase/tools/migrate"
"github.com/pocketbase/pocketbase/tools/types"
) )
func TestNewBaseApp(t *testing.T) { func TestNewBaseApp(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/" const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir) defer os.RemoveAll(testDataDir)
app := NewBaseApp(BaseAppConfig{ app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir, DataDir: testDataDir,
EncryptionEnv: "test_env", EncryptionEnv: "test_env",
IsDev: true, IsDev: true,
}) })
if app.dataDir != testDataDir { if app.DataDir() != testDataDir {
t.Fatalf("expected dataDir %q, got %q", testDataDir, app.dataDir) t.Fatalf("expected DataDir %q, got %q", testDataDir, app.DataDir())
} }
if app.encryptionEnv != "test_env" { if app.EncryptionEnv() != "test_env" {
t.Fatalf("expected encryptionEnv test_env, got %q", app.dataDir) t.Fatalf("expected EncryptionEnv test_env, got %q", app.EncryptionEnv())
} }
if !app.isDev { if !app.IsDev() {
t.Fatalf("expected isDev true, got %v", app.isDev) t.Fatalf("expected IsDev true, got %v", app.IsDev())
} }
if app.store == nil { if app.Store() == nil {
t.Fatal("expected store to be set, got nil") t.Fatal("expected Store to be set, got nil")
} }
if app.settings == nil { if app.Settings() == nil {
t.Fatal("expected settings to be set, got nil") t.Fatal("expected Settings to be set, got nil")
} }
if app.subscriptionsBroker == nil { if app.SubscriptionsBroker() == nil {
t.Fatal("expected subscriptionsBroker to be set, got nil") t.Fatal("expected SubscriptionsBroker to be set, got nil")
}
if app.Cron() == nil {
t.Fatal("expected Cron to be set, got nil")
} }
} }
@ -61,9 +58,8 @@ func TestBaseAppBootstrap(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/" const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir) defer os.RemoveAll(testDataDir)
app := NewBaseApp(BaseAppConfig{ app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir, DataDir: testDataDir,
EncryptionEnv: "pb_test_env",
}) })
defer app.ResetBootstrapState() defer app.ResetBootstrapState()
@ -83,72 +79,59 @@ func TestBaseAppBootstrap(t *testing.T) {
t.Fatal("Expected test data directory to be created.") t.Fatal("Expected test data directory to be created.")
} }
if app.dao == nil { type nilCheck struct {
t.Fatal("Expected app.dao to be initialized, got nil.") name string
value any
expectNil bool
} }
if app.dao.BeforeCreateFunc == nil { runNilChecks := func(checks []nilCheck) {
t.Fatal("Expected app.dao.BeforeCreateFunc to be set, got nil.") for _, check := range checks {
t.Run(check.name, func(t *testing.T) {
isNil := check.value == nil
if isNil != check.expectNil {
t.Fatalf("Expected isNil %v, got %v", check.expectNil, isNil)
}
})
}
} }
if app.dao.AfterCreateFunc == nil { nilChecksBeforeReset := []nilCheck{
t.Fatal("Expected app.dao.AfterCreateFunc to be set, got nil.") {"[before] concurrentDB", app.DB(), false},
{"[before] nonconcurrentDB", app.NonconcurrentDB(), false},
{"[before] auxConcurrentDB", app.AuxDB(), false},
{"[before] auxNonconcurrentDB", app.AuxNonconcurrentDB(), false},
{"[before] settings", app.Settings(), false},
{"[before] logger", app.Logger(), false},
{"[before] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
} }
if app.dao.BeforeUpdateFunc == nil { runNilChecks(nilChecksBeforeReset)
t.Fatal("Expected app.dao.BeforeUpdateFunc to be set, got nil.")
}
if app.dao.AfterUpdateFunc == nil {
t.Fatal("Expected app.dao.AfterUpdateFunc to be set, got nil.")
}
if app.dao.BeforeDeleteFunc == nil {
t.Fatal("Expected app.dao.BeforeDeleteFunc to be set, got nil.")
}
if app.dao.AfterDeleteFunc == nil {
t.Fatal("Expected app.dao.AfterDeleteFunc to be set, got nil.")
}
if app.logsDao == nil {
t.Fatal("Expected app.logsDao to be initialized, got nil.")
}
if app.settings == nil {
t.Fatal("Expected app.settings to be initialized, got nil.")
}
if app.logger == nil {
t.Fatal("Expected app.logger to be initialized, got nil.")
}
if _, ok := app.logger.Handler().(*logger.BatchHandler); !ok {
t.Fatal("Expected app.logger handler to be initialized.")
}
// reset // reset
if err := app.ResetBootstrapState(); err != nil { if err := app.ResetBootstrapState(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if app.dao != nil { nilChecksAfterReset := []nilCheck{
t.Fatalf("Expected app.dao to be nil, got %v.", app.dao) {"[after] concurrentDB", app.DB(), true},
{"[after] nonconcurrentDB", app.NonconcurrentDB(), true},
{"[after] auxConcurrentDB", app.AuxDB(), true},
{"[after] auxNonconcurrentDB", app.AuxNonconcurrentDB(), true},
{"[after] settings", app.Settings(), false},
{"[after] logger", app.Logger(), false},
{"[after] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
} }
if app.logsDao != nil { runNilChecks(nilChecksAfterReset)
t.Fatalf("Expected app.logsDao to be nil, got %v.", app.logsDao)
}
} }
func TestBaseAppGetters(t *testing.T) { func TestNewBaseAppIsTransactional(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/" const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir) defer os.RemoveAll(testDataDir)
app := NewBaseApp(BaseAppConfig{ app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir, DataDir: testDataDir,
EncryptionEnv: "pb_test_env",
IsDev: true,
}) })
defer app.ResetBootstrapState() defer app.ResetBootstrapState()
@ -156,81 +139,58 @@ func TestBaseAppGetters(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if app.dao != app.Dao() { if app.IsTransactional() {
t.Fatalf("Expected app.Dao %v, got %v", app.Dao(), app.dao) t.Fatalf("Didn't expect the app to be transactional")
} }
if app.dao.ConcurrentDB() != app.DB() { app.RunInTransaction(func(txApp core.App) error {
t.Fatalf("Expected app.DB %v, got %v", app.DB(), app.dao.ConcurrentDB()) if !txApp.IsTransactional() {
} t.Fatalf("Expected the app to be transactional")
}
if app.logsDao != app.LogsDao() { return nil
t.Fatalf("Expected app.LogsDao %v, got %v", app.LogsDao(), app.logsDao) })
}
if app.logsDao.ConcurrentDB() != app.LogsDB() {
t.Fatalf("Expected app.LogsDB %v, got %v", app.LogsDB(), app.logsDao.ConcurrentDB())
}
if app.dataDir != app.DataDir() {
t.Fatalf("Expected app.DataDir %v, got %v", app.DataDir(), app.dataDir)
}
if app.encryptionEnv != app.EncryptionEnv() {
t.Fatalf("Expected app.EncryptionEnv %v, got %v", app.EncryptionEnv(), app.encryptionEnv)
}
if app.isDev != app.IsDev() {
t.Fatalf("Expected app.IsDev %v, got %v", app.IsDev(), app.isDev)
}
if app.settings != app.Settings() {
t.Fatalf("Expected app.Settings %v, got %v", app.Settings(), app.settings)
}
if app.store != app.Store() {
t.Fatalf("Expected app.Store %v, got %v", app.Store(), app.store)
}
if app.logger != app.Logger() {
t.Fatalf("Expected app.Logger %v, got %v", app.Logger(), app.logger)
}
if app.subscriptionsBroker != app.SubscriptionsBroker() {
t.Fatalf("Expected app.SubscriptionsBroker %v, got %v", app.SubscriptionsBroker(), app.subscriptionsBroker)
}
if app.onBeforeServe != app.OnBeforeServe() || app.OnBeforeServe() == nil {
t.Fatalf("Getter app.OnBeforeServe does not match or nil (%v vs %v)", app.OnBeforeServe(), app.onBeforeServe)
}
} }
func TestBaseAppNewMailClient(t *testing.T) { func TestBaseAppNewMailClient(t *testing.T) {
app, cleanup, err := initTestBaseApp() const testDataDir = "./pb_base_app_test_data_dir/"
if err != nil { defer os.RemoveAll(testDataDir)
t.Fatal(err)
} app := core.NewBaseApp(core.BaseAppConfig{
defer cleanup() DataDir: testDataDir,
EncryptionEnv: "pb_test_env",
})
defer app.ResetBootstrapState()
client1 := app.NewMailClient() client1 := app.NewMailClient()
if val, ok := client1.(*mailer.Sendmail); !ok { m1, ok := client1.(*mailer.Sendmail)
t.Fatalf("Expected mailer.Sendmail instance, got %v", val) if !ok {
t.Fatalf("Expected mailer.Sendmail instance, got %v", m1)
}
if m1.OnSend() == nil || m1.OnSend().Length() == 0 {
t.Fatal("Expected OnSend hook to be registered")
} }
app.Settings().Smtp.Enabled = true app.Settings().SMTP.Enabled = true
client2 := app.NewMailClient() client2 := app.NewMailClient()
if val, ok := client2.(*mailer.SmtpClient); !ok { m2, ok := client2.(*mailer.SMTPClient)
t.Fatalf("Expected mailer.SmtpClient instance, got %v", val) if !ok {
t.Fatalf("Expected mailer.SMTPClient instance, got %v", m2)
}
if m2.OnSend() == nil || m2.OnSend().Length() == 0 {
t.Fatal("Expected OnSend hook to be registered")
} }
} }
func TestBaseAppNewFilesystem(t *testing.T) { func TestBaseAppNewFilesystem(t *testing.T) {
app, cleanup, err := initTestBaseApp() const testDataDir = "./pb_base_app_test_data_dir/"
if err != nil { defer os.RemoveAll(testDataDir)
t.Fatal(err)
} app := core.NewBaseApp(core.BaseAppConfig{
defer cleanup() DataDir: testDataDir,
})
defer app.ResetBootstrapState()
// local // local
local, localErr := app.NewFilesystem() local, localErr := app.NewFilesystem()
@ -253,11 +213,13 @@ func TestBaseAppNewFilesystem(t *testing.T) {
} }
func TestBaseAppNewBackupsFilesystem(t *testing.T) { func TestBaseAppNewBackupsFilesystem(t *testing.T) {
app, cleanup, err := initTestBaseApp() const testDataDir = "./pb_base_app_test_data_dir/"
if err != nil { defer os.RemoveAll(testDataDir)
t.Fatal(err)
} app := core.NewBaseApp(core.BaseAppConfig{
defer cleanup() DataDir: testDataDir,
})
defer app.ResetBootstrapState()
// local // local
local, localErr := app.NewBackupsFilesystem() local, localErr := app.NewBackupsFilesystem()
@ -280,18 +242,22 @@ func TestBaseAppNewBackupsFilesystem(t *testing.T) {
} }
func TestBaseAppLoggerWrites(t *testing.T) { func TestBaseAppLoggerWrites(t *testing.T) {
app, cleanup, err := initTestBaseApp() t.Parallel()
if err != nil {
app, _ := tests.NewTestApp()
defer app.Cleanup()
// reset
if err := app.DeleteOldLogs(time.Now()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer cleanup()
const logsThreshold = 200 const logsThreshold = 200
totalLogs := func(app App, t *testing.T) int { totalLogs := func(app core.App, t *testing.T) int {
var total int var total int
err := app.LogsDao().LogQuery().Select("count(*)").Row(&total) err := app.LogQuery().Select("count(*)").Row(&total)
if err != nil { if err != nil {
t.Fatalf("Failed to fetch total logs: %v", err) t.Fatalf("Failed to fetch total logs: %v", err)
} }
@ -338,106 +304,9 @@ func TestBaseAppLoggerWrites(t *testing.T) {
t.Fatalf("Expected %d logs, got %d", logsThreshold+1, total) t.Fatalf("Expected %d logs, got %d", logsThreshold+1, total)
} }
}) })
t.Run("test batch logs delete", func(t *testing.T) {
app.Settings().Logs.MaxDays = 2
deleteQueries := 0
// reset
app.Store().Set("lastLogsDeletedAt", time.Now())
if err := app.LogsDao().DeleteOldLogs(time.Now()); err != nil {
t.Fatal(err)
}
db := app.LogsDao().NonconcurrentDB().(*dbx.DB)
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
if strings.Contains(sql, "DELETE") {
deleteQueries++
}
}
// trigger batch write (A)
expectedLogs := logsThreshold
for i := 0; i < expectedLogs; i++ {
app.Logger().Error("testA")
}
if total := totalLogs(app, t); total != expectedLogs {
t.Fatalf("[batch write A] Expected %d logs, got %d", expectedLogs, total)
}
// mark the A inserted logs as 2-day expired
aExpiredDate, err := types.ParseDateTime(time.Now().AddDate(0, 0, -2))
if err != nil {
t.Fatal(err)
}
_, err = app.LogsDao().NonconcurrentDB().NewQuery("UPDATE _logs SET created={:date}, updated={:date}").Bind(dbx.Params{
"date": aExpiredDate.String(),
}).Execute()
if err != nil {
t.Fatalf("Failed to mock logs timestamp fields: %v", err)
}
// simulate recently deleted logs
app.Store().Set("lastLogsDeletedAt", time.Now().Add(-5*time.Hour))
// trigger batch write (B)
for i := 0; i < logsThreshold; i++ {
app.Logger().Error("testB")
}
expectedLogs = 2 * logsThreshold
// note: even though there are expired logs it shouldn't perform the delete operation because of the lastLogsDeledAt time
if total := totalLogs(app, t); total != expectedLogs {
t.Fatalf("[batch write B] Expected %d logs, got %d", expectedLogs, total)
}
// mark the B inserted logs as 1-day expired to ensure that they will not be deleted
bExpiredDate, err := types.ParseDateTime(time.Now().AddDate(0, 0, -1))
if err != nil {
t.Fatal(err)
}
_, err = app.LogsDao().NonconcurrentDB().NewQuery("UPDATE _logs SET created={:date}, updated={:date} where message='testB'").Bind(dbx.Params{
"date": bExpiredDate.String(),
}).Execute()
if err != nil {
t.Fatalf("Failed to mock logs timestamp fields: %v", err)
}
// should trigger delete on the next batch write
app.Store().Set("lastLogsDeletedAt", time.Now().Add(-6*time.Hour))
// trigger batch write (C)
for i := 0; i < logsThreshold; i++ {
app.Logger().Error("testC")
}
expectedLogs = 2 * logsThreshold // only B and C logs should remain
if total := totalLogs(app, t); total != expectedLogs {
t.Fatalf("[batch write C] Expected %d logs, got %d", expectedLogs, total)
}
if deleteQueries != 1 {
t.Fatalf("Expected DeleteOldLogs to be called %d, got %d", 1, deleteQueries)
}
})
} }
func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) { func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Fatal(err)
}
defer cleanup()
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
scenarios := []struct { scenarios := []struct {
name string name string
isDev bool isDev bool
@ -469,173 +338,35 @@ func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
for _, s := range scenarios { for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) { t.Run(s.name, func(t *testing.T) {
app.isDev = s.isDev const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
IsDev: s.isDev,
})
defer app.ResetBootstrapState()
if err := app.Bootstrap(); err != nil {
t.Fatal(err)
}
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
app.Settings().Logs.MinLevel = s.level app.Settings().Logs.MinLevel = s.level
if err := app.Dao().SaveSettings(app.Settings()); err != nil { if err := app.Save(app.Settings()); err != nil {
t.Fatalf("Failed to save settings: %v", err) t.Fatalf("Failed to save settings: %v", err)
} }
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh app settings: %v", err)
}
for level, enabled := range s.expectations { for level, enabled := range s.expectations {
if v := handler.Enabled(nil, slog.Level(level)); v != enabled { if v := handler.Enabled(context.Background(), slog.Level(level)); v != enabled {
t.Fatalf("Expected level %d Enabled() to be %v, got %v", level, enabled, v) t.Fatalf("Expected level %d Enabled() to be %v, got %v", level, enabled, v)
} }
} }
}) })
} }
} }
func TestBaseAppLoggerLevelDevPrint(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Fatal(err)
}
defer cleanup()
testLogLevel := 4
app.Settings().Logs.MinLevel = testLogLevel
if err := app.Dao().SaveSettings(app.Settings()); err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
isDev bool
levels []int
printedLevels []int
persistedLevels []int
}{
{
"dev mode",
true,
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{testLogLevel, testLogLevel + 1},
},
{
"nondev mode",
false,
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{},
[]int{testLogLevel, testLogLevel + 1},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
var printedLevels []int
var persistedLevels []int
app.isDev = s.isDev
// trigger slog handler min level refresh
if err := app.RefreshSettings(); err != nil {
t.Fatal(err)
}
// track printed logs
originalPrintLog := printLog
defer func() {
printLog = originalPrintLog
}()
printLog = func(log *logger.Log) {
printedLevels = append(printedLevels, int(log.Level))
}
// track persisted logs
app.LogsDao().AfterCreateFunc = func(eventDao *daos.Dao, m models.Model) error {
l, ok := m.(*models.Log)
if ok {
persistedLevels = append(persistedLevels, l.Level)
}
return nil
}
// write and persist logs
for _, l := range s.levels {
app.Logger().Log(nil, slog.Level(l), "test")
}
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
if err := handler.WriteAll(nil); err != nil {
t.Fatalf("Failed to write all logs: %v", err)
}
// check persisted log levels
if len(s.persistedLevels) != len(persistedLevels) {
t.Fatalf("Expected persisted levels \n%v\ngot\n%v", s.persistedLevels, persistedLevels)
}
for _, l := range persistedLevels {
if !list.ExistInSlice(l, s.persistedLevels) {
t.Fatalf("Missing expected persisted level %v in %v", l, persistedLevels)
}
}
// check printed log levels
if len(s.printedLevels) != len(printedLevels) {
t.Fatalf("Expected printed levels \n%v\ngot\n%v", s.printedLevels, printedLevels)
}
for _, l := range printedLevels {
if !list.ExistInSlice(l, s.printedLevels) {
t.Fatalf("Missing expected printed level %v in %v", l, printedLevels)
}
}
})
}
}
// -------------------------------------------------------------------
// note: make sure to call `defer cleanup()` when the app is no longer needed.
func initTestBaseApp() (app *BaseApp, cleanup func(), err error) {
testDataDir, err := os.MkdirTemp("", "test_base_app")
if err != nil {
return nil, nil, err
}
cleanup = func() {
os.RemoveAll(testDataDir)
}
app = NewBaseApp(BaseAppConfig{
DataDir: testDataDir,
})
initErr := func() error {
if err := app.Bootstrap(); err != nil {
return fmt.Errorf("bootstrap error: %w", err)
}
logsRunner, err := migrate.NewRunner(app.LogsDB(), logs.LogsMigrations)
if err != nil {
return fmt.Errorf("logsRunner error: %w", err)
}
if _, err := logsRunner.Up(); err != nil {
return fmt.Errorf("logsRunner migrations execution error: %w", err)
}
dataRunner, err := migrate.NewRunner(app.DB(), migrations.AppMigrations)
if err != nil {
return fmt.Errorf("logsRunner error: %w", err)
}
if _, err := dataRunner.Up(); err != nil {
return fmt.Errorf("dataRunner migrations execution error: %w", err)
}
return nil
}()
if initErr != nil {
cleanup()
return nil, nil, initErr
}
return app, cleanup, nil
}

194
core/collection_import.go Normal file
View File

@ -0,0 +1,194 @@
package core
import (
"cmp"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/spf13/cast"
)
// ImportCollectionsByMarshaledJSON is the same as [ImportCollections]
// but accept marshaled json array as import data (usually used for the autogenerated snapshots).
func (app *BaseApp) ImportCollectionsByMarshaledJSON(rawSliceOfMaps []byte, deleteMissing bool) error {
data := []map[string]any{}
err := json.Unmarshal(rawSliceOfMaps, &data)
if err != nil {
return err
}
return app.ImportCollections(data, deleteMissing)
}
// ImportCollections imports the provided collections data in a single transaction.
//
// For existing matching collections, the imported data is unmarshaled on top of the existing model.
//
// NB! If deleteMissing is true, ALL NON-SYSTEM COLLECTIONS AND SCHEMA FIELDS,
// that are not present in the imported configuration, WILL BE DELETED
// (this includes their related records data).
func (app *BaseApp) ImportCollections(toImport []map[string]any, deleteMissing bool) error {
if len(toImport) == 0 {
// prevent accidentally deleting all collections
return errors.New("no collections to import")
}
importedCollections := make([]*Collection, len(toImport))
mappedImported := make(map[string]*Collection, len(toImport))
// normalize imported collections data to ensure that all
// collection fields are present and properly initialized
for i, data := range toImport {
var imported *Collection
identifier := cast.ToString(data["id"])
if identifier == "" {
identifier = cast.ToString(data["name"])
}
existing, err := app.FindCollectionByNameOrId(identifier)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if existing != nil {
// refetch for deep copy
imported, err = app.FindCollectionByNameOrId(existing.Id)
if err != nil {
return err
}
// ensure that the fields will be cleared
if data["fields"] == nil && deleteMissing {
data["fields"] = []map[string]any{}
}
rawData, err := json.Marshal(data)
if err != nil {
return err
}
// load the imported data
err = json.Unmarshal(rawData, imported)
if err != nil {
return err
}
// extend with the existing fields if necessary
for _, f := range existing.Fields {
if !f.GetSystem() && deleteMissing {
continue
}
if imported.Fields.GetById(f.GetId()) == nil {
imported.Fields.Add(f)
}
}
} else {
imported = &Collection{}
rawData, err := json.Marshal(data)
if err != nil {
return err
}
// load the imported data
err = json.Unmarshal(rawData, imported)
if err != nil {
return err
}
}
imported.IntegrityChecks(false)
importedCollections[i] = imported
mappedImported[imported.Id] = imported
}
// reorder views last since the view query could depend on some of the other collections
slices.SortStableFunc(importedCollections, func(a, b *Collection) int {
cmpA := -1
if a.IsView() {
cmpA = 1
}
cmpB := -1
if b.IsView() {
cmpB = 1
}
res := cmp.Compare(cmpA, cmpB)
if res == 0 {
res = a.Created.Compare(b.Created)
if res == 0 {
res = a.Updated.Compare(b.Updated)
}
}
return res
})
return app.RunInTransaction(func(txApp App) error {
existingCollections := []*Collection{}
if err := txApp.CollectionQuery().OrderBy("updated ASC").All(&existingCollections); err != nil {
return err
}
mappedExisting := make(map[string]*Collection, len(existingCollections))
for _, existing := range existingCollections {
existing.IntegrityChecks(false)
mappedExisting[existing.Id] = existing
}
// delete old collections not available in the new configuration
// (before saving the imports in case a deleted collection name is being reused)
if deleteMissing {
for _, existing := range existingCollections {
if mappedImported[existing.Id] != nil || existing.System {
continue // exist or system
}
// delete collection
if err := txApp.Delete(existing); err != nil {
return err
}
}
}
// upsert imported collections
for _, imported := range importedCollections {
if err := txApp.SaveNoValidate(imported); err != nil {
return fmt.Errorf("failed to save collection %q: %w", imported.Name, err)
}
}
// run validations
for _, imported := range importedCollections {
original := mappedExisting[imported.Id]
if original == nil {
original = imported
}
validator := newCollectionValidator(
context.Background(),
txApp,
imported,
original,
)
if err := validator.run(); err != nil {
// serialize the validation error(s)
serializedErr, _ := json.MarshalIndent(err, "", " ")
return validation.Errors{"collections": validation.NewError(
"validation_collections_import_failure",
fmt.Sprintf("Data validations failed for collection %q (%s):\n%s", imported.Name, imported.Id, serializedErr),
)}
}
}
return nil
})
}

View File

@ -0,0 +1,476 @@
package core_test
import (
"encoding/json"
"strings"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestImportCollections(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
var regularCollections []*core.Collection
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(&regularCollections)
if err != nil {
t.Fatal(err)
}
var systemCollections []*core.Collection
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
if err != nil {
t.Fatal(err)
}
totalRegularCollections := len(regularCollections)
totalSystemCollections := len(systemCollections)
totalCollections := totalRegularCollections + totalSystemCollections
scenarios := []struct {
name string
data []map[string]any
deleteMissing bool
expectError bool
expectCollectionsCount int
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
}{
{
name: "empty collections",
data: []map[string]any{},
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import (with missing system fields)",
data: []map[string]any{
{"name": "import_test1", "type": "auth"},
{
"name": "import_test2", "fields": []map[string]any{
{"name": "test", "type": "text"},
},
},
},
deleteMissing: false,
expectError: false,
expectCollectionsCount: totalCollections + 2,
},
{
name: "minimal collection import (trigger collection model validations)",
data: []map[string]any{
{"name": ""},
{
"name": "import_test2", "fields": []map[string]any{
{"name": "test", "type": "text"},
},
},
},
deleteMissing: false,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import (trigger field settings validation)",
data: []map[string]any{
{"name": "import_test", "fields": []map[string]any{{"name": "test", "type": "text", "min": -1}}},
},
deleteMissing: false,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "new + update + delete (system collections delete should be ignored)",
data: []map[string]any{
{
"id": "wsmn24bux7wo113",
"name": "demo",
"fields": []map[string]any{
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": nil,
"pattern": "",
},
},
"indexes": []string{},
},
{
"name": "import1",
"fields": []map[string]any{
{
"name": "active",
"type": "bool",
},
},
},
},
deleteMissing: true,
expectError: false,
expectCollectionsCount: totalSystemCollections + 2,
},
{
name: "test with deleteMissing: false",
data: []map[string]any{
{
// "id": "wsmn24bux7wo113", // test update with only name as identifier
"name": "demo1",
"fields": []map[string]any{
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": nil,
"pattern": "",
},
{
"id": "_2hlxbmp",
"name": "field_with_duplicate_id",
"type": "text",
"system": false,
"required": true,
"unique": false,
"min": 4,
"max": nil,
"pattern": "",
},
{
"id": "abcd_import",
"name": "new_field",
"type": "text",
},
},
},
{
"name": "new_import",
"fields": []map[string]any{
{
"id": "abcd_import",
"name": "active",
"type": "bool",
},
},
},
},
deleteMissing: false,
expectError: false,
expectCollectionsCount: totalCollections + 1,
afterTestFunc: func(testApp *tests.TestApp, resultCollections []*core.Collection) {
expectedCollectionFields := map[string]int{
core.CollectionNameAuthOrigins: 6,
"nologin": 10,
"demo1": 18,
"demo2": 5,
"demo3": 5,
"demo4": 16,
"demo5": 9,
"new_import": 2,
}
for name, expectedCount := range expectedCollectionFields {
collection, err := testApp.FindCollectionByNameOrId(name)
if err != nil {
t.Fatal(err)
}
if totalFields := len(collection.Fields); totalFields != expectedCount {
t.Errorf("Expected %d %q fields, got %d", expectedCount, collection.Name, totalFields)
}
}
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollections(s.data, s.deleteMissing)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
// check collections count
collections := []*core.Collection{}
if err := testApp.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
if len(collections) != s.expectCollectionsCount {
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
}
if s.afterTestFunc != nil {
s.afterTestFunc(testApp, collections)
}
})
}
}
func TestImportCollectionsByMarshaledJSON(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
var regularCollections []*core.Collection
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(&regularCollections)
if err != nil {
t.Fatal(err)
}
var systemCollections []*core.Collection
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
if err != nil {
t.Fatal(err)
}
totalRegularCollections := len(regularCollections)
totalSystemCollections := len(systemCollections)
totalCollections := totalRegularCollections + totalSystemCollections
scenarios := []struct {
name string
data string
deleteMissing bool
expectError bool
expectCollectionsCount int
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
}{
{
name: "invalid json array",
data: `{"test":123}`,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "new + update + delete (system collections delete should be ignored)",
data: `[
{
"id": "wsmn24bux7wo113",
"name": "demo",
"fields": [
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": null,
"pattern": ""
}
],
"indexes": []
},
{
"name": "import1",
"fields": [
{
"name": "active",
"type": "bool"
}
]
}
]`,
deleteMissing: true,
expectError: false,
expectCollectionsCount: totalSystemCollections + 2,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollectionsByMarshaledJSON([]byte(s.data), s.deleteMissing)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
// check collections count
collections := []*core.Collection{}
if err := testApp.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
if len(collections) != s.expectCollectionsCount {
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
}
if s.afterTestFunc != nil {
s.afterTestFunc(testApp, collections)
}
})
}
}
func TestImportCollectionsUpdateRules(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
data map[string]any
deleteMissing bool
}{
{
"extend existing by name (without deleteMissing)",
map[string]any{"name": "clients", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
false,
},
{
"extend existing by id (without deleteMissing)",
map[string]any{"id": "v851q4r790rhknl", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
false,
},
{
"extend with delete missing",
map[string]any{
"id": "v851q4r790rhknl",
"authToken": map[string]any{"duration": 100},
"fields": []map[string]any{{"name": "test", "type": "text"}},
"passwordAuth": map[string]any{"identityFields": []string{"email"}},
"indexes": []string{
// min required system fields indexes
"CREATE UNIQUE INDEX `_v851q4r790rhknl_email_idx` ON `clients` (email) WHERE email != ''",
"CREATE UNIQUE INDEX `_v851q4r790rhknl_tokenKey_idx` ON `clients` (tokenKey)",
},
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
beforeCollection, err := testApp.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
err = testApp.ImportCollections([]map[string]any{s.data}, s.deleteMissing)
if err != nil {
t.Fatal(err)
}
afterCollection, err := testApp.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
if afterCollection.AuthToken.Duration != 100 {
t.Fatalf("Expected AuthToken duration to be %d, got %d", 100, afterCollection.AuthToken.Duration)
}
if beforeCollection.AuthToken.Secret != afterCollection.AuthToken.Secret {
t.Fatalf("Expected AuthToken secrets to remain the same, got\n%q\nVS\n%q", beforeCollection.AuthToken.Secret, afterCollection.AuthToken.Secret)
}
if beforeCollection.Name != afterCollection.Name {
t.Fatalf("Expected Name to remain the same, got\n%q\nVS\n%q", beforeCollection.Name, afterCollection.Name)
}
if beforeCollection.Id != afterCollection.Id {
t.Fatalf("Expected Id to remain the same, got\n%q\nVS\n%q", beforeCollection.Id, afterCollection.Id)
}
if !s.deleteMissing {
totalExpectedFields := len(beforeCollection.Fields) + 1
if v := len(afterCollection.Fields); v != totalExpectedFields {
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
}
if afterCollection.Fields.GetByName("test") == nil {
t.Fatalf("Missing new field %q", "test")
}
// ensure that the old fields still exist
oldFields := beforeCollection.Fields.FieldNames()
for _, name := range oldFields {
if afterCollection.Fields.GetByName(name) == nil {
t.Fatalf("Missing expected old field %q", name)
}
}
} else {
totalExpectedFields := 1
for _, f := range beforeCollection.Fields {
if f.GetSystem() {
totalExpectedFields++
}
}
if v := len(afterCollection.Fields); v != totalExpectedFields {
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
}
if afterCollection.Fields.GetByName("test") == nil {
t.Fatalf("Missing new field %q", "test")
}
// ensure that the old system fields still exist
for _, f := range beforeCollection.Fields {
if f.GetSystem() && afterCollection.Fields.GetByName(f.GetName()) == nil {
t.Fatalf("Missing expected old field %q", f.GetName())
}
}
}
})
}
}
func TestImportCollectionsCreateRules(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollections([]map[string]any{
{"name": "new_test", "type": "auth", "authToken": map[string]any{"duration": 123}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
}, false)
if err != nil {
t.Fatal(err)
}
collection, err := testApp.FindCollectionByNameOrId("new_test")
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(collection)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
expectedParts := []string{
`"name":"new_test"`,
`"fields":[`,
`"name":"id"`,
`"name":"email"`,
`"name":"tokenKey"`,
`"name":"password"`,
`"name":"test"`,
`"indexes":[`,
`CREATE UNIQUE INDEX`,
`"duration":123`,
}
for _, part := range expectedParts {
if !strings.Contains(rawStr, part) {
t.Errorf("Missing %q in\n%s", part, rawStr)
}
}
}

949
core/collection_model.go Normal file
View File

@ -0,0 +1,949 @@
package core
import (
"encoding/json"
"fmt"
"strings"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
var (
_ Model = (*Collection)(nil)
_ DBExporter = (*Collection)(nil)
_ FilesManager = (*Collection)(nil)
)
const (
CollectionTypeBase = "base"
CollectionTypeAuth = "auth"
CollectionTypeView = "view"
)
const systemHookIdCollection = "__pbCollectionSystemHook__"
func (app *BaseApp) registerCollectionHooks() {
app.OnModelValidate().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionValidate().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelCreate().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionCreate().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelCreateExecute().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionCreateExecute().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterCreateSuccess().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionAfterCreateSuccess().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterCreateError().Bind(&hook.Handler[*ModelErrorEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelErrorEvent) error {
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
return me.App.OnCollectionAfterCreateError().Trigger(ce, func(ce *CollectionErrorEvent) error {
syncModelErrorEventWithCollectionErrorEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelUpdate().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionUpdate().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelUpdateExecute().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionUpdateExecute().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionAfterUpdateSuccess().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterUpdateError().Bind(&hook.Handler[*ModelErrorEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelErrorEvent) error {
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
return me.App.OnCollectionAfterUpdateError().Trigger(ce, func(ce *CollectionErrorEvent) error {
syncModelErrorEventWithCollectionErrorEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelDelete().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionDelete().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelDeleteExecute().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionDeleteExecute().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionAfterDeleteSuccess().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterDeleteError().Bind(&hook.Handler[*ModelErrorEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelErrorEvent) error {
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
return me.App.OnCollectionAfterDeleteError().Trigger(ce, func(ce *CollectionErrorEvent) error {
syncModelErrorEventWithCollectionErrorEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
// --------------------------------------------------------------
app.OnCollectionValidate().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionValidate,
Priority: 99,
})
app.OnCollectionCreate().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSave,
Priority: -99,
})
app.OnCollectionUpdate().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSave,
Priority: -99,
})
app.OnCollectionCreateExecute().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSaveExecute,
// execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
Priority: 99,
})
app.OnCollectionUpdateExecute().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSaveExecute,
Priority: 99, // execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
})
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionDeleteExecute,
Priority: 99, // execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
})
// reload cache on failure
// ---
onErrorReloadCachedCollections := func(ce *CollectionErrorEvent) error {
if err := ce.App.ReloadCachedCollections(); err != nil {
ce.App.Logger().Warn("Failed to reload collections cache", "error", err)
}
return ce.Next()
}
app.OnCollectionAfterCreateError().Bind(&hook.Handler[*CollectionErrorEvent]{
Id: systemHookIdCollection,
Func: onErrorReloadCachedCollections,
Priority: -99,
})
app.OnCollectionAfterUpdateError().Bind(&hook.Handler[*CollectionErrorEvent]{
Id: systemHookIdCollection,
Func: onErrorReloadCachedCollections,
Priority: -99,
})
app.OnCollectionAfterDeleteError().Bind(&hook.Handler[*CollectionErrorEvent]{
Id: systemHookIdCollection,
Func: onErrorReloadCachedCollections,
Priority: -99,
})
// ---
app.OnBootstrap().Bind(&hook.Handler[*BootstrapEvent]{
Id: systemHookIdCollection,
Func: func(e *BootstrapEvent) error {
if err := e.Next(); err != nil {
return err
}
if err := e.App.ReloadCachedCollections(); err != nil {
return fmt.Errorf("failed to load collections cache: %w", err)
}
return nil
},
Priority: 99, // execute as latest as possible
})
}
// @todo experiment eventually replacing the rules *string with a struct?
type baseCollection struct {
BaseModel
disableIntegrityChecks bool
ListRule *string `db:"listRule" json:"listRule" form:"listRule"`
ViewRule *string `db:"viewRule" json:"viewRule" form:"viewRule"`
CreateRule *string `db:"createRule" json:"createRule" form:"createRule"`
UpdateRule *string `db:"updateRule" json:"updateRule" form:"updateRule"`
DeleteRule *string `db:"deleteRule" json:"deleteRule" form:"deleteRule"`
// RawOptions represents the raw serialized collection option loaded from the DB.
// NB! This field shouldn't be modified manually. It is automatically updated
// with the collection type specific option before save.
RawOptions types.JSONRaw `db:"options" json:"-" xml:"-" form:"-"`
Name string `db:"name" json:"name" form:"name"`
Type string `db:"type" json:"type" form:"type"`
Fields FieldsList `db:"fields" json:"fields" form:"fields"`
Indexes types.JSONArray[string] `db:"indexes" json:"indexes" form:"indexes"`
System bool `db:"system" json:"system" form:"system"`
Created types.DateTime `db:"created" json:"created"`
Updated types.DateTime `db:"updated" json:"updated"`
}
// Collection defines the table, fields and various options related to a set of records.
type Collection struct {
baseCollection
collectionAuthOptions
collectionViewOptions
}
// NewCollection initializes and returns a new Collection model with the specified type and name.
func NewCollection(typ, name string) *Collection {
switch typ {
case CollectionTypeAuth:
return NewAuthCollection(name)
case CollectionTypeView:
return NewViewCollection(name)
default:
return NewBaseCollection(name)
}
}
// NewBaseCollection initializes and returns a new "base" Collection model.
func NewBaseCollection(name string) *Collection {
m := &Collection{}
m.Name = name
m.Type = CollectionTypeBase
m.initDefaultId()
m.initDefaultFields()
return m
}
// NewViewCollection initializes and returns a new "view" Collection model.
func NewViewCollection(name string) *Collection {
m := &Collection{}
m.Name = name
m.Type = CollectionTypeView
m.initDefaultId()
m.initDefaultFields()
return m
}
// NewAuthCollection initializes and returns a new "auth" Collection model.
func NewAuthCollection(name string) *Collection {
m := &Collection{}
m.Name = name
m.Type = CollectionTypeAuth
m.initDefaultId()
m.initDefaultFields()
m.setDefaultAuthOptions()
return m
}
// TableName returns the Collection model SQL table name.
func (m *Collection) TableName() string {
return "_collections"
}
// BaseFilesPath returns the storage dir path used by the collection.
func (m *Collection) BaseFilesPath() string {
return m.Id
}
// IsBase checks if the current collection has "base" type.
func (m *Collection) IsBase() bool {
return m.Type == CollectionTypeBase
}
// IsAuth checks if the current collection has "auth" type.
func (m *Collection) IsAuth() bool {
return m.Type == CollectionTypeAuth
}
// IsView checks if the current collection has "view" type.
func (m *Collection) IsView() bool {
return m.Type == CollectionTypeView
}
// IntegrityChecks toggles the current collection integrity checks (ex. checking references on delete).
func (m *Collection) IntegrityChecks(enable bool) {
m.disableIntegrityChecks = !enable
}
// PostScan implements the [dbx.PostScanner] interface to auto unmarshal
// the raw serialized options into the concrete type specific fields.
func (m *Collection) PostScan() error {
if err := m.BaseModel.PostScan(); err != nil {
return err
}
return m.unmarshalRawOptions()
}
func (m *Collection) unmarshalRawOptions() error {
raw, err := m.RawOptions.MarshalJSON()
if err != nil {
return nil
}
switch m.Type {
case CollectionTypeView:
return json.Unmarshal(raw, &m.collectionViewOptions)
case CollectionTypeAuth:
return json.Unmarshal(raw, &m.collectionAuthOptions)
}
return nil
}
// UnmarshalJSON implements the [json.Unmarshaler] interface.
//
// For new/"blank" Collection models it replaces the model with a factory
// instance and then unmarshal the provided data one on top of it.
func (m *Collection) UnmarshalJSON(b []byte) error {
type alias *Collection
// initialize the default fields
// (e.g. in case the collection was NOT created using the designated factories)
if m.IsNew() && m.Type == "" {
minimal := &struct {
Type string `json:"type"`
Name string `json:"name"`
}{}
if err := json.Unmarshal(b, minimal); err != nil {
return err
}
blank := NewCollection(minimal.Type, minimal.Name)
*m = *blank
}
return json.Unmarshal(b, alias(m))
}
// MarshalJSON implements the [json.Marshaler] interface.
//
// Note that non-type related fields are ignored from the serialization
// (ex. for "view" colections the "auth" fields are skipped).
func (m Collection) MarshalJSON() ([]byte, error) {
switch m.Type {
case CollectionTypeView:
return json.Marshal(struct {
baseCollection
collectionViewOptions
}{m.baseCollection, m.collectionViewOptions})
case CollectionTypeAuth:
alias := struct {
baseCollection
collectionAuthOptions
}{m.baseCollection, m.collectionAuthOptions}
// ensure that it is always returned as array
if alias.OAuth2.Providers == nil {
alias.OAuth2.Providers = []OAuth2ProviderConfig{}
}
// hide secret keys from the serialization
alias.AuthToken.Secret = ""
alias.FileToken.Secret = ""
alias.PasswordResetToken.Secret = ""
alias.EmailChangeToken.Secret = ""
alias.VerificationToken.Secret = ""
for i := range alias.OAuth2.Providers {
alias.OAuth2.Providers[i].ClientSecret = ""
}
return json.Marshal(alias)
default:
return json.Marshal(m.baseCollection)
}
}
// String returns a string representation of the current collection.
func (m Collection) String() string {
raw, _ := json.Marshal(m)
return string(raw)
}
// DBExport prepares and exports the current collection data for db persistence.
func (m *Collection) DBExport(app App) (map[string]any, error) {
result := map[string]any{
"id": m.Id,
"type": m.Type,
"listRule": m.ListRule,
"viewRule": m.ViewRule,
"createRule": m.CreateRule,
"updateRule": m.UpdateRule,
"deleteRule": m.DeleteRule,
"name": m.Name,
"fields": m.Fields,
"indexes": m.Indexes,
"system": m.System,
"created": m.Created,
"updated": m.Updated,
"options": `{}`,
}
switch m.Type {
case CollectionTypeView:
if raw, err := types.ParseJSONRaw(m.collectionViewOptions); err == nil {
result["options"] = raw
} else {
return nil, err
}
case CollectionTypeAuth:
if raw, err := types.ParseJSONRaw(m.collectionAuthOptions); err == nil {
result["options"] = raw
} else {
return nil, err
}
}
return result, nil
}
// GetIndex returns s single Collection index expression by its name.
func (m *Collection) GetIndex(name string) string {
for _, idx := range m.Indexes {
if strings.EqualFold(dbutils.ParseIndex(idx).IndexName, name) {
return idx
}
}
return ""
}
// AddIndex adds a new index into the current collection.
//
// If the collection has an existing index matching the new name it will be replaced with the new one.
func (m *Collection) AddIndex(name string, unique bool, columnsExpr string, optWhereExpr string) {
m.RemoveIndex(name)
var idx strings.Builder
idx.WriteString("CREATE ")
if unique {
idx.WriteString("UNIQUE ")
}
idx.WriteString("INDEX `")
idx.WriteString(name)
idx.WriteString("` ")
idx.WriteString("ON `")
idx.WriteString(m.Name)
idx.WriteString("` (")
idx.WriteString(columnsExpr)
idx.WriteString(")")
if optWhereExpr != "" {
idx.WriteString(" WHERE ")
idx.WriteString(optWhereExpr)
}
m.Indexes = append(m.Indexes, idx.String())
}
// RemoveIndex removes a single index with the specified name from the current collection.
func (m *Collection) RemoveIndex(name string) {
for i, idx := range m.Indexes {
if strings.EqualFold(dbutils.ParseIndex(idx).IndexName, name) {
m.Indexes = append(m.Indexes[:i], m.Indexes[i+1:]...)
return
}
}
}
// delete hook
// -------------------------------------------------------------------
func onCollectionDeleteExecute(e *CollectionEvent) error {
if e.Collection.System {
return fmt.Errorf("[%s] system collections cannot be deleted", e.Collection.Name)
}
defer func() {
if err := e.App.ReloadCachedCollections(); err != nil {
e.App.Logger().Warn("Failed to reload collections cache", "error", err)
}
}()
if !e.Collection.disableIntegrityChecks {
// ensure that there aren't any existing references.
// note: the select is outside of the transaction to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
references, err := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
if err != nil {
return fmt.Errorf("[%s] failed to check collection references: %w", e.Collection.Name, err)
}
if total := len(references); total > 0 {
names := make([]string, 0, len(references))
for ref := range references {
names = append(names, ref.Name)
}
return fmt.Errorf("[%s] failed to delete due to existing relation references: %s", e.Collection.Name, strings.Join(names, ", "))
}
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
// delete the related view or records table
if e.Collection.IsView() {
if err := txApp.DeleteView(e.Collection.Name); err != nil {
return err
}
} else {
if err := txApp.DeleteTable(e.Collection.Name); err != nil {
return err
}
}
if !e.Collection.disableIntegrityChecks {
// trigger views resave to check for dependencies
if err := resaveViewsWithChangedFields(txApp, e.Collection.Id); err != nil {
return fmt.Errorf("[%s] failed to delete due to existing view dependency: %w", e.Collection.Name, err)
}
}
// delete
return e.Next()
})
e.App = originalApp
return txErr
}
// save hook
// -------------------------------------------------------------------
func (c *Collection) initDefaultId() {
if c.Id == "" && c.Name != "" {
c.Id = "_pbc_" + crc32Checksum(c.Name)
}
}
func (c *Collection) savePrepare() error {
if c.Type == "" {
c.Type = CollectionTypeBase
}
if c.IsNew() {
c.initDefaultId()
c.Created = types.NowDateTime()
}
c.Updated = types.NowDateTime()
// recreate the fields list to ensure that all normalizations
// like default field id are applied
c.Fields = NewFieldsList(c.Fields...)
c.initDefaultFields()
if c.IsAuth() {
c.unsetMissingOAuth2MappedFields()
}
return nil
}
func onCollectionSave(e *CollectionEvent) error {
if err := e.Collection.savePrepare(); err != nil {
return err
}
return e.Next()
}
func onCollectionSaveExecute(e *CollectionEvent) error {
defer func() {
if err := e.App.ReloadCachedCollections(); err != nil {
e.App.Logger().Warn("Failed to reload collections cache", "error", err)
}
}()
var oldCollection *Collection
if !e.Collection.IsNew() {
var err error
oldCollection, err = e.App.FindCachedCollectionByNameOrId(e.Collection.Id)
if err != nil {
return err
}
// invalidate previously issued auth tokens on auth rule change
if oldCollection.AuthRule != e.Collection.AuthRule &&
cast.ToString(oldCollection.AuthRule) != cast.ToString(e.Collection.AuthRule) {
e.Collection.AuthToken.Secret = security.RandomString(50)
}
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
isView := e.Collection.IsView()
// ensures that the view collection shema is properly loaded
if isView {
query := e.Collection.ViewQuery
// generate collection fields list from the query
viewFields, err := e.App.CreateViewFields(query)
if err != nil {
return err
}
// delete old renamed view
if oldCollection != nil {
if err := e.App.DeleteView(oldCollection.Name); err != nil {
return err
}
}
// wrap view query if necessary
query, err = normalizeViewQueryId(e.App, query)
if err != nil {
return fmt.Errorf("failed to normalize view query id: %w", err)
}
// (re)create the view
if err := e.App.SaveView(e.Collection.Name, query); err != nil {
return err
}
// updates newCollection.Fields based on the generated view table info and query
e.Collection.Fields = viewFields
}
// save the Collection model
if err := e.Next(); err != nil {
return err
}
// sync the changes with the related records table
if !isView {
if err := e.App.SyncRecordTableSchema(e.Collection, oldCollection); err != nil {
// note: don't wrap to allow propagating indexes validation.Errors
return err
}
}
return nil
})
e.App = originalApp
if txErr != nil {
return txErr
}
// trigger an update for all views with changed fields as a result of the current collection save
// (ignoring view errors to allow users to update the query from the UI)
resaveViewsWithChangedFields(e.App, e.Collection.Id)
return nil
}
func (m *Collection) initDefaultFields() {
switch m.Type {
case CollectionTypeBase:
m.initIdField()
case CollectionTypeAuth:
m.initIdField()
m.initPasswordField()
m.initTokenKeyField()
m.initEmailField()
m.initEmailVisibilityField()
m.initVerifiedField()
case CollectionTypeView:
// view fields are autogenerated
}
}
func (m *Collection) initIdField() {
field, _ := m.Fields.GetByName(FieldNameId).(*TextField)
if field == nil {
// create default field
field = &TextField{
Name: FieldNameId,
System: true,
PrimaryKey: true,
Required: true,
Min: 15,
Max: 15,
Pattern: `^[a-z0-9]+$`,
AutogeneratePattern: `[a-z0-9]{15}`,
}
// prepend it
m.Fields = NewFieldsList(append([]Field{field}, m.Fields...)...)
} else {
// enforce system defaults
field.System = true
field.Required = true
field.PrimaryKey = true
field.Hidden = false
}
}
func (m *Collection) initPasswordField() {
field, _ := m.Fields.GetByName(FieldNamePassword).(*PasswordField)
if field == nil {
// load default field
m.Fields.Add(&PasswordField{
Name: FieldNamePassword,
System: true,
Hidden: true,
Required: true,
Min: 8,
})
} else {
// enforce system defaults
field.System = true
field.Hidden = true
field.Required = true
}
}
func (m *Collection) initTokenKeyField() {
field, _ := m.Fields.GetByName(FieldNameTokenKey).(*TextField)
if field == nil {
// load default field
m.Fields.Add(&TextField{
Name: FieldNameTokenKey,
System: true,
Hidden: true,
Min: 30,
Max: 60,
Required: true,
AutogeneratePattern: `[a-zA-Z0-9]{50}`,
})
} else {
// enforce system defaults
field.System = true
field.Hidden = true
field.Required = true
}
// ensure that there is a unique index for the field
if !dbutils.HasSingleColumnUniqueIndex(FieldNameTokenKey, m.Indexes) {
m.Indexes = append(m.Indexes, fmt.Sprintf(
"CREATE UNIQUE INDEX `%s` ON `%s` (`%s`)",
m.fieldIndexName(FieldNameTokenKey),
m.Name,
FieldNameTokenKey,
))
}
}
func (m *Collection) initEmailField() {
field, _ := m.Fields.GetByName(FieldNameEmail).(*EmailField)
if field == nil {
// load default field
m.Fields.Add(&EmailField{
Name: FieldNameEmail,
System: true,
Required: true,
})
} else {
// enforce system defaults
field.System = true
field.Hidden = false // managed by the emailVisibility flag
}
// ensure that there is a unique index for the email field
if !dbutils.HasSingleColumnUniqueIndex(FieldNameEmail, m.Indexes) {
m.Indexes = append(m.Indexes, fmt.Sprintf(
"CREATE UNIQUE INDEX `%s` ON `%s` (`%s`) WHERE `%s` != ''",
m.fieldIndexName(FieldNameEmail),
m.Name,
FieldNameEmail,
FieldNameEmail,
))
}
}
func (m *Collection) initEmailVisibilityField() {
field, _ := m.Fields.GetByName(FieldNameEmailVisibility).(*BoolField)
if field == nil {
// load default field
m.Fields.Add(&BoolField{
Name: FieldNameEmailVisibility,
System: true,
})
} else {
// enforce system defaults
field.System = true
}
}
func (m *Collection) initVerifiedField() {
field, _ := m.Fields.GetByName(FieldNameVerified).(*BoolField)
if field == nil {
// load default field
m.Fields.Add(&BoolField{
Name: FieldNameVerified,
System: true,
})
} else {
// enforce system defaults
field.System = true
}
}
func (m *Collection) fieldIndexName(field string) string {
name := "idx_" + field + "_"
if m.Id != "" {
name += m.Id
} else if m.Name != "" {
name += m.Name
} else {
name += security.PseudorandomString(10)
}
if len(name) > 64 {
return name[:64]
}
return name
}

View File

@ -0,0 +1,535 @@
package core
import (
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
func (m *Collection) unsetMissingOAuth2MappedFields() {
if !m.IsAuth() {
return
}
if m.OAuth2.MappedFields.Id != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Id) == nil {
m.OAuth2.MappedFields.Id = ""
}
}
if m.OAuth2.MappedFields.Name != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Name) == nil {
m.OAuth2.MappedFields.Name = ""
}
}
if m.OAuth2.MappedFields.Username != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Username) == nil {
m.OAuth2.MappedFields.Username = ""
}
}
if m.OAuth2.MappedFields.AvatarURL != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.AvatarURL) == nil {
m.OAuth2.MappedFields.AvatarURL = ""
}
}
}
func (m *Collection) setDefaultAuthOptions() {
m.collectionAuthOptions = collectionAuthOptions{
VerificationTemplate: defaultVerificationTemplate,
ResetPasswordTemplate: defaultResetPasswordTemplate,
ConfirmEmailChangeTemplate: defaultConfirmEmailChangeTemplate,
AuthRule: types.Pointer(""),
AuthAlert: AuthAlertConfig{
Enabled: true,
EmailTemplate: defaultAuthAlertTemplate,
},
PasswordAuth: PasswordAuthConfig{
Enabled: true,
IdentityFields: []string{FieldNameEmail},
},
MFA: MFAConfig{
Enabled: false,
Duration: 1800, // 30min
},
OTP: OTPConfig{
Enabled: false,
Duration: 180, // 3min
Length: 8,
EmailTemplate: defaultOTPTemplate,
},
AuthToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 604800, // 7 days
},
PasswordResetToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 1800, // 30min
},
EmailChangeToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 1800, // 30min
},
VerificationToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 259200, // 3days
},
FileToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 180, // 3min
},
}
}
var _ optionsValidator = (*collectionAuthOptions)(nil)
// collectionAuthOptions defines the options for the "auth" type collection.
type collectionAuthOptions struct {
// AuthRule could be used to specify additional record constraints
// applied after record authentication and right before returning the
// auth token response to the client.
//
// For example, to allow only verified users you could set it to
// "verified = true".
//
// Set it to empty string to allow any Auth collection record to authenticate.
//
// Set it to nil to disallow authentication altogether for the collection
// (that includes password, OAuth2, etc.).
AuthRule *string `form:"authRule" json:"authRule"`
// ManageRule gives admin-like permissions to allow fully managing
// the auth record(s), eg. changing the password without requiring
// to enter the old one, directly updating the verified state and email, etc.
//
// This rule is executed in addition to the Create and Update API rules.
ManageRule *string `form:"manageRule" json:"manageRule"`
// AuthAlert defines options related to the auth alerts on new device login.
AuthAlert AuthAlertConfig `form:"authAlert" json:"authAlert"`
// OAuth2 specifies whether OAuth2 auth is enabled for the collection
// and which OAuth2 providers are allowed.
OAuth2 OAuth2Config `form:"oauth2" json:"oauth2"`
PasswordAuth PasswordAuthConfig `form:"passwordAuth" json:"passwordAuth"`
MFA MFAConfig `form:"mfa" json:"mfa"`
OTP OTPConfig `form:"otp" json:"otp"`
// Various token configurations
// ---
AuthToken TokenConfig `form:"authToken" json:"authToken"`
PasswordResetToken TokenConfig `form:"passwordResetToken" json:"passwordResetToken"`
EmailChangeToken TokenConfig `form:"emailChangeToken" json:"emailChangeToken"`
VerificationToken TokenConfig `form:"verificationToken" json:"verificationToken"`
FileToken TokenConfig `form:"fileToken" json:"fileToken"`
// default email templates
// ---
VerificationTemplate EmailTemplate `form:"verificationTemplate" json:"verificationTemplate"`
ResetPasswordTemplate EmailTemplate `form:"resetPasswordTemplate" json:"resetPasswordTemplate"`
ConfirmEmailChangeTemplate EmailTemplate `form:"confirmEmailChangeTemplate" json:"confirmEmailChangeTemplate"`
}
func (o *collectionAuthOptions) validate(cv *collectionValidator) error {
err := validation.ValidateStruct(o,
validation.Field(
&o.AuthRule,
validation.By(cv.checkRule),
validation.By(cv.ensureNoSystemRuleChange(cv.original.AuthRule)),
),
validation.Field(
&o.ManageRule,
validation.NilOrNotEmpty,
validation.By(cv.checkRule),
validation.By(cv.ensureNoSystemRuleChange(cv.original.ManageRule)),
),
validation.Field(&o.AuthAlert),
validation.Field(&o.PasswordAuth),
validation.Field(&o.OAuth2),
validation.Field(&o.OTP),
validation.Field(&o.MFA),
validation.Field(&o.AuthToken),
validation.Field(&o.PasswordResetToken),
validation.Field(&o.EmailChangeToken),
validation.Field(&o.VerificationToken),
validation.Field(&o.FileToken),
validation.Field(&o.VerificationTemplate, validation.Required),
validation.Field(&o.ResetPasswordTemplate, validation.Required),
validation.Field(&o.ConfirmEmailChangeTemplate, validation.Required),
)
if err != nil {
return err
}
if o.MFA.Enabled {
// if MFA is enabled require at least 2 auth methods
//
// @todo maybe consider disabling the check because if custom auth methods
// are registered it may fail since we don't have mechanism to detect them at the moment
authsEnabled := 0
if o.PasswordAuth.Enabled {
authsEnabled++
}
if o.OAuth2.Enabled {
authsEnabled++
}
if o.OTP.Enabled {
authsEnabled++
}
if authsEnabled < 2 {
return validation.Errors{
"mfa": validation.Errors{
"enabled": validation.NewError("validation_mfa_not_enough_auths", "MFA requires at least 2 auth methods to be enabled."),
},
}
}
if o.MFA.Rule != "" {
mfaRuleValidators := []validation.RuleFunc{
cv.checkRule,
cv.ensureNoSystemRuleChange(&cv.original.MFA.Rule),
}
for _, validator := range mfaRuleValidators {
err := validator(&o.MFA.Rule)
if err != nil {
return validation.Errors{
"mfa": validation.Errors{
"rule": err,
},
}
}
}
}
}
// extra check to ensure that only unique identity fields are used
if o.PasswordAuth.Enabled {
err = validation.Validate(o.PasswordAuth.IdentityFields, validation.By(cv.checkFieldsForUniqueIndex))
if err != nil {
return validation.Errors{
"passwordAuth": validation.Errors{
"identityFields": err,
},
}
}
}
return nil
}
// -------------------------------------------------------------------
type EmailTemplate struct {
Subject string `form:"subject" json:"subject"`
Body string `form:"body" json:"body"`
}
// Validate makes EmailTemplate validatable by implementing [validation.Validatable] interface.
func (t EmailTemplate) Validate() error {
return validation.ValidateStruct(&t,
validation.Field(&t.Subject, validation.Required),
validation.Field(&t.Body, validation.Required),
)
}
// Resolve replaces the placeholder parameters in the current email
// template and returns its components as ready-to-use strings.
func (t EmailTemplate) Resolve(placeholders map[string]any) (subject, body string) {
body = t.Body
subject = t.Subject
for k, v := range placeholders {
vStr := cast.ToString(v)
// replace subject placeholder params (if any)
subject = strings.ReplaceAll(subject, k, vStr)
// replace body placeholder params (if any)
body = strings.ReplaceAll(body, k, vStr)
}
return subject, body
}
// -------------------------------------------------------------------
type AuthAlertConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
}
// Validate makes AuthAlertConfig validatable by implementing [validation.Validatable] interface.
func (c AuthAlertConfig) Validate() error {
return validation.ValidateStruct(&c,
// note: for now always run the email template validations even
// if not enabled since it could be used separately
validation.Field(&c.EmailTemplate),
)
}
// -------------------------------------------------------------------
type TokenConfig struct {
Secret string `form:"secret" json:"secret,omitempty"`
// Duration specifies how long an issued token to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
}
// Validate makes TokenConfig validatable by implementing [validation.Validatable] interface.
func (c TokenConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Secret, validation.Required, validation.Length(30, 255)),
validation.Field(&c.Duration, validation.Required, validation.Min(10), validation.Max(94670856)), // ~3y max
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c TokenConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type OTPConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// Duration specifies how long the OTP to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
// Length specifies the auto generated password length.
Length int `form:"length" json:"length"`
// EmailTemplate is the default OTP email template that will be send to the auth record.
//
// In addition to the system placeholders you can also make use of
// [core.EmailPlaceholderOTPId] and [core.EmailPlaceholderOTP].
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
}
// Validate makes OTPConfig validatable by implementing [validation.Validatable] interface.
func (c OTPConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
validation.Field(&c.Length, validation.When(c.Enabled, validation.Required, validation.Min(4))),
// note: for now always run the email template validations even
// if not enabled since it could be used separately
validation.Field(&c.EmailTemplate),
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c OTPConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type MFAConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// Duration specifies how long an issued MFA to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
// Rule is an optional field to restrict MFA only for the records that satisfy the rule.
//
// Leave it empty to enable MFA for everyone.
Rule string `form:"rule" json:"rule"`
}
// Validate makes MFAConfig validatable by implementing [validation.Validatable] interface.
func (c MFAConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c MFAConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type PasswordAuthConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// IdentityFields is a list of field names that could be used as
// identity during password authentication.
//
// Usually only fields that has single column UNIQUE index are accepted as values.
IdentityFields []string `form:"identityFields" json:"identityFields"`
}
// Validate makes PasswordAuthConfig validatable by implementing [validation.Validatable] interface.
func (c PasswordAuthConfig) Validate() error {
// strip duplicated values
c.IdentityFields = list.ToUniqueStringSlice(c.IdentityFields)
if !c.Enabled {
return nil // no need to validate
}
return validation.ValidateStruct(&c,
validation.Field(&c.IdentityFields, validation.Required),
)
}
// -------------------------------------------------------------------
type OAuth2KnownFields struct {
Id string `form:"id" json:"id"`
Name string `form:"name" json:"name"`
Username string `form:"username" json:"username"`
AvatarURL string `form:"avatarURL" json:"avatarURL"`
}
type OAuth2Config struct {
Providers []OAuth2ProviderConfig `form:"providers" json:"providers"`
MappedFields OAuth2KnownFields `form:"mappedFields" json:"mappedFields"`
Enabled bool `form:"enabled" json:"enabled"`
}
// GetProviderConfig returns the first OAuth2ProviderConfig that matches the specified name.
//
// Returns false and zero config if no such provider is available in c.Providers.
func (c OAuth2Config) GetProviderConfig(name string) (config OAuth2ProviderConfig, exists bool) {
for _, p := range c.Providers {
if p.Name == name {
return p, true
}
}
return
}
// Validate makes OAuth2Config validatable by implementing [validation.Validatable] interface.
func (c OAuth2Config) Validate() error {
if !c.Enabled {
return nil // no need to validate
}
return validation.ValidateStruct(&c,
// note: don't require providers for now as they could be externally registered/removed
validation.Field(&c.Providers, validation.By(checkForDuplicatedProviders)),
)
}
func checkForDuplicatedProviders(value any) error {
configs, _ := value.([]OAuth2ProviderConfig)
existing := map[string]struct{}{}
for i, c := range configs {
if c.Name == "" {
continue // the name nonempty state is validated separately
}
if _, ok := existing[c.Name]; ok {
return validation.Errors{
strconv.Itoa(i): validation.Errors{
"name": validation.NewError("validation_duplicated_provider", "The provider "+c.Name+" is already registered.").
SetParams(map[string]any{"name": c.Name}),
},
}
}
existing[c.Name] = struct{}{}
}
return nil
}
type OAuth2ProviderConfig struct {
// PKCE overwrites the default provider PKCE config option.
//
// This usually shouldn't be needed but some OAuth2 vendors, like the LinkedIn OIDC,
// may require manual adjustment due to returning error if extra parameters are added to the request
// (https://github.com/pocketbase/pocketbase/discussions/3799#discussioncomment-7640312)
PKCE *bool `form:"pkce" json:"pkce"`
Name string `form:"name" json:"name"`
ClientId string `form:"clientId" json:"clientId"`
ClientSecret string `form:"clientSecret" json:"clientSecret,omitempty"`
AuthURL string `form:"authURL" json:"authURL"`
TokenURL string `form:"tokenURL" json:"tokenURL"`
UserInfoURL string `form:"userInfoURL" json:"userInfoURL"`
DisplayName string `form:"displayName" json:"displayName"`
}
// Validate makes OAuth2ProviderConfig validatable by implementing [validation.Validatable] interface.
func (c OAuth2ProviderConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Name, validation.Required, validation.By(checkProviderName)),
validation.Field(&c.ClientId, validation.Required),
validation.Field(&c.ClientSecret, validation.Required),
validation.Field(&c.AuthURL, is.URL),
validation.Field(&c.TokenURL, is.URL),
validation.Field(&c.UserInfoURL, is.URL),
)
}
func checkProviderName(value any) error {
name, _ := value.(string)
if name == "" {
return nil // nothing to check
}
if _, err := auth.NewProviderByName(name); err != nil {
return validation.NewError("validation_missing_provider", "Invalid or missing provider with name "+name+".").
SetParams(map[string]any{"name": name})
}
return nil
}
// InitProvider returns a new auth.Provider instance loaded with the current OAuth2ProviderConfig options.
func (c OAuth2ProviderConfig) InitProvider() (auth.Provider, error) {
provider, err := auth.NewProviderByName(c.Name)
if err != nil {
return nil, err
}
if c.ClientId != "" {
provider.SetClientId(c.ClientId)
}
if c.ClientSecret != "" {
provider.SetClientSecret(c.ClientSecret)
}
if c.AuthURL != "" {
provider.SetAuthURL(c.AuthURL)
}
if c.UserInfoURL != "" {
provider.SetUserInfoURL(c.UserInfoURL)
}
if c.TokenURL != "" {
provider.SetTokenURL(c.TokenURL)
}
if c.DisplayName != "" {
provider.SetDisplayName(c.DisplayName)
}
if c.PKCE != nil {
provider.SetPKCE(*c.PKCE)
}
return provider, nil
}

Some files were not shown because too many files have changed in this diff Show More