merge v0.23.0-rc changes
This commit is contained in:
parent
ad92992324
commit
844f18cac3
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
|
||||
|
||||
All reports will be promptly addressed, and you'll be credited accordingly.
|
||||
All reports will be promptly addressed and you'll be credited in the fix release notes.
|
||||
|
|
|
|||
|
|
@ -7,7 +7,14 @@ on:
|
|||
jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
# re-enable auto-snapshot from goreleaser-action@v3
|
||||
# (https://github.com/goreleaser/goreleaser-action-v4-auto-snapshot-example)
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
|
@ -16,12 +23,12 @@ jobs:
|
|||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20.11.0
|
||||
node-version: 20.17.0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '>=1.22.5'
|
||||
go-version: '>=1.23.0'
|
||||
|
||||
# This step usually is not needed because the /ui/dist is pregenerated locally
|
||||
# but its here to ensure that each release embeds the latest admin ui artifacts.
|
||||
|
|
@ -36,19 +43,14 @@ jobs:
|
|||
# - name: Generate jsvm types
|
||||
# run: go run ./plugins/jsvm/internal/types/types.go
|
||||
|
||||
# The prebuilt golangci-lint doesn't support go 1.18+ yet
|
||||
# https://github.com/golangci/golangci-lint/issues/2649
|
||||
# - name: Run linter
|
||||
# uses: golangci/golangci-lint-action@v3
|
||||
|
||||
- name: Run tests
|
||||
run: go test ./...
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v3
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
distribution: goreleaser
|
||||
version: latest
|
||||
args: release --clean
|
||||
version: '~> v2'
|
||||
args: release --clean ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
version: 2
|
||||
|
||||
project_name: pocketbase
|
||||
|
||||
dist: .builds
|
||||
|
|
@ -58,7 +60,7 @@ checksum:
|
|||
name_template: 'checksums.txt'
|
||||
|
||||
snapshot:
|
||||
name_template: '{{ incpatch .Version }}-next'
|
||||
version_template: '{{ incpatch .Version }}-next'
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
## v0.23.0-RC (WIP)
|
||||
|
||||
...
|
||||
|
||||
|
||||
## v0.22.21
|
||||
|
||||
- Lock the logs database during backup to prevent `database disk image is malformed` errors in case there is a log write running in the background ([#5541](https://github.com/pocketbase/pocketbase/discussions/5541)).
|
||||
|
|
|
|||
25
README.md
25
README.md
|
|
@ -10,7 +10,7 @@
|
|||
<a href="https://pkg.go.dev/github.com/pocketbase/pocketbase" target="_blank" rel="noopener"><img src="https://godoc.org/github.com/pocketbase/pocketbase?status.svg" alt="Go package documentation" /></a>
|
||||
</p>
|
||||
|
||||
[PocketBase](https://pocketbase.io) is an open source Go backend, consisting of:
|
||||
[PocketBase](https://pocketbase.io) is an open source Go backend that includes:
|
||||
|
||||
- embedded database (_SQLite_) with **realtime subscriptions**
|
||||
- built-in **files and users management**
|
||||
|
|
@ -46,7 +46,7 @@ your own custom app specific business logic and still have a single portable exe
|
|||
|
||||
Here is a minimal example:
|
||||
|
||||
0. [Install Go 1.21+](https://go.dev/doc/install) (_if you haven't already_)
|
||||
0. [Install Go 1.23+](https://go.dev/doc/install) (_if you haven't already_)
|
||||
|
||||
1. Create a new project directory with the following `main.go` file inside it:
|
||||
```go
|
||||
|
|
@ -56,29 +56,20 @@ Here is a minimal example:
|
|||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := pocketbase.New()
|
||||
|
||||
app.OnBeforeServe().Add(func(e *core.ServeEvent) error {
|
||||
// add new "GET /hello" route to the app router (echo)
|
||||
e.Router.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/hello",
|
||||
Handler: func(c echo.Context) error {
|
||||
return c.String(200, "Hello world!")
|
||||
},
|
||||
Middlewares: []echo.MiddlewareFunc{
|
||||
apis.ActivityLogger(app),
|
||||
},
|
||||
app.OnServe().BindFunc(func(se *core.ServeEvent) error {
|
||||
// registers new "GET /hello" route
|
||||
se.Router.Get("/hello", func(re *core.RequestEvent) error {
|
||||
return re.String(200, "Hello world!")
|
||||
})
|
||||
|
||||
return nil
|
||||
return se.Next()
|
||||
})
|
||||
|
||||
if err := app.Start(); err != nil {
|
||||
|
|
@ -145,7 +136,7 @@ Check also the [Testing guide](http://pocketbase.io/docs/testing) to learn how t
|
|||
|
||||
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
|
||||
|
||||
All reports will be promptly addressed, and you'll be credited accordingly.
|
||||
All reports will be promptly addressed and you'll be credited in the fix release notes.
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
353
apis/admin.go
353
apis/admin.go
|
|
@ -1,353 +0,0 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tokens"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindAdminApi registers the admin api endpoints and the corresponding handlers.
|
||||
func bindAdminApi(app core.App, rg *echo.Group) {
|
||||
api := adminApi{app: app}
|
||||
|
||||
subGroup := rg.Group("/admins", ActivityLogger(app))
|
||||
subGroup.POST("/auth-with-password", api.authWithPassword)
|
||||
subGroup.POST("/request-password-reset", api.requestPasswordReset)
|
||||
subGroup.POST("/confirm-password-reset", api.confirmPasswordReset)
|
||||
subGroup.POST("/auth-refresh", api.authRefresh, RequireAdminAuth())
|
||||
subGroup.GET("", api.list, RequireAdminAuth())
|
||||
subGroup.POST("", api.create, RequireAdminAuthOnlyIfAny(app))
|
||||
subGroup.GET("/:id", api.view, RequireAdminAuth())
|
||||
subGroup.PATCH("/:id", api.update, RequireAdminAuth())
|
||||
subGroup.DELETE("/:id", api.delete, RequireAdminAuth())
|
||||
}
|
||||
|
||||
type adminApi struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (api *adminApi) authResponse(c echo.Context, admin *models.Admin, finalizers ...func(token string) error) error {
|
||||
token, tokenErr := tokens.NewAdminAuthToken(api.app, admin)
|
||||
if tokenErr != nil {
|
||||
return NewBadRequestError("Failed to create auth token.", tokenErr)
|
||||
}
|
||||
|
||||
for _, f := range finalizers {
|
||||
if err := f(token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
event := new(core.AdminAuthEvent)
|
||||
event.HttpContext = c
|
||||
event.Admin = admin
|
||||
event.Token = token
|
||||
|
||||
return api.app.OnAdminAuthRequest().Trigger(event, func(e *core.AdminAuthEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(200, map[string]any{
|
||||
"token": e.Token,
|
||||
"admin": e.Admin,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (api *adminApi) authRefresh(c echo.Context) error {
|
||||
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin == nil {
|
||||
return NewNotFoundError("Missing auth admin context.", nil)
|
||||
}
|
||||
|
||||
event := new(core.AdminAuthRefreshEvent)
|
||||
event.HttpContext = c
|
||||
event.Admin = admin
|
||||
|
||||
return api.app.OnAdminBeforeAuthRefreshRequest().Trigger(event, func(e *core.AdminAuthRefreshEvent) error {
|
||||
return api.app.OnAdminAfterAuthRefreshRequest().Trigger(event, func(e *core.AdminAuthRefreshEvent) error {
|
||||
return api.authResponse(e.HttpContext, e.Admin)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (api *adminApi) authWithPassword(c echo.Context) error {
|
||||
form := forms.NewAdminLogin(api.app)
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
event := new(core.AdminAuthWithPasswordEvent)
|
||||
event.HttpContext = c
|
||||
event.Password = form.Password
|
||||
event.Identity = form.Identity
|
||||
|
||||
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
|
||||
return func(admin *models.Admin) error {
|
||||
event.Admin = admin
|
||||
|
||||
return api.app.OnAdminBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.AdminAuthWithPasswordEvent) error {
|
||||
if err := next(e.Admin); err != nil {
|
||||
return NewBadRequestError("Failed to authenticate.", err)
|
||||
}
|
||||
|
||||
return api.app.OnAdminAfterAuthWithPasswordRequest().Trigger(event, func(e *core.AdminAuthWithPasswordEvent) error {
|
||||
return api.authResponse(e.HttpContext, e.Admin)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *adminApi) requestPasswordReset(c echo.Context) error {
|
||||
form := forms.NewAdminPasswordResetRequest(api.app)
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
if err := form.Validate(); err != nil {
|
||||
return NewBadRequestError("An error occurred while validating the form.", err)
|
||||
}
|
||||
|
||||
event := new(core.AdminRequestPasswordResetEvent)
|
||||
event.HttpContext = c
|
||||
|
||||
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
|
||||
return func(Admin *models.Admin) error {
|
||||
event.Admin = Admin
|
||||
|
||||
return api.app.OnAdminBeforeRequestPasswordResetRequest().Trigger(event, func(e *core.AdminRequestPasswordResetEvent) error {
|
||||
// run in background because we don't need to show the result to the client
|
||||
routine.FireAndForget(func() {
|
||||
if err := next(e.Admin); err != nil {
|
||||
api.app.Logger().Error("Failed to send admin password reset request.", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
return api.app.OnAdminAfterRequestPasswordResetRequest().Trigger(event, func(e *core.AdminRequestPasswordResetEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// eagerly write 204 response and skip submit errors
|
||||
// as a measure against admins enumeration
|
||||
if !c.Response().Committed {
|
||||
c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *adminApi) confirmPasswordReset(c echo.Context) error {
|
||||
form := forms.NewAdminPasswordResetConfirm(api.app)
|
||||
if readErr := c.Bind(form); readErr != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
|
||||
}
|
||||
|
||||
event := new(core.AdminConfirmPasswordResetEvent)
|
||||
event.HttpContext = c
|
||||
|
||||
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
|
||||
return func(admin *models.Admin) error {
|
||||
event.Admin = admin
|
||||
|
||||
return api.app.OnAdminBeforeConfirmPasswordResetRequest().Trigger(event, func(e *core.AdminConfirmPasswordResetEvent) error {
|
||||
if err := next(e.Admin); err != nil {
|
||||
return NewBadRequestError("Failed to set new password.", err)
|
||||
}
|
||||
|
||||
return api.app.OnAdminAfterConfirmPasswordResetRequest().Trigger(event, func(e *core.AdminConfirmPasswordResetEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *adminApi) list(c echo.Context) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(
|
||||
"id", "created", "updated", "name", "email",
|
||||
)
|
||||
|
||||
admins := []*models.Admin{}
|
||||
|
||||
result, err := search.NewProvider(fieldResolver).
|
||||
Query(api.app.Dao().AdminQuery()).
|
||||
ParseAndExec(c.QueryParams().Encode(), &admins)
|
||||
|
||||
if err != nil {
|
||||
return NewBadRequestError("", err)
|
||||
}
|
||||
|
||||
event := new(core.AdminsListEvent)
|
||||
event.HttpContext = c
|
||||
event.Admins = admins
|
||||
event.Result = result
|
||||
|
||||
return api.app.OnAdminsListRequest().Trigger(event, func(e *core.AdminsListEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Result)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *adminApi) view(c echo.Context) error {
|
||||
id := c.PathParam("id")
|
||||
if id == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
admin, err := api.app.Dao().FindAdminById(id)
|
||||
if err != nil || admin == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.AdminViewEvent)
|
||||
event.HttpContext = c
|
||||
event.Admin = admin
|
||||
|
||||
return api.app.OnAdminViewRequest().Trigger(event, func(e *core.AdminViewEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Admin)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *adminApi) create(c echo.Context) error {
|
||||
admin := &models.Admin{}
|
||||
|
||||
form := forms.NewAdminUpsert(api.app, admin)
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.AdminCreateEvent)
|
||||
event.HttpContext = c
|
||||
event.Admin = admin
|
||||
|
||||
// create the admin
|
||||
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
|
||||
return func(m *models.Admin) error {
|
||||
event.Admin = m
|
||||
|
||||
return api.app.OnAdminBeforeCreateRequest().Trigger(event, func(e *core.AdminCreateEvent) error {
|
||||
if err := next(e.Admin); err != nil {
|
||||
return NewBadRequestError("Failed to create admin.", err)
|
||||
}
|
||||
|
||||
return api.app.OnAdminAfterCreateRequest().Trigger(event, func(e *core.AdminCreateEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Admin)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *adminApi) update(c echo.Context) error {
|
||||
id := c.PathParam("id")
|
||||
if id == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
admin, err := api.app.Dao().FindAdminById(id)
|
||||
if err != nil || admin == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
|
||||
form := forms.NewAdminUpsert(api.app, admin)
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.AdminUpdateEvent)
|
||||
event.HttpContext = c
|
||||
event.Admin = admin
|
||||
|
||||
// update the admin
|
||||
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
|
||||
return func(m *models.Admin) error {
|
||||
event.Admin = m
|
||||
|
||||
return api.app.OnAdminBeforeUpdateRequest().Trigger(event, func(e *core.AdminUpdateEvent) error {
|
||||
if err := next(e.Admin); err != nil {
|
||||
return NewBadRequestError("Failed to update admin.", err)
|
||||
}
|
||||
|
||||
return api.app.OnAdminAfterUpdateRequest().Trigger(event, func(e *core.AdminUpdateEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Admin)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *adminApi) delete(c echo.Context) error {
|
||||
id := c.PathParam("id")
|
||||
if id == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
admin, err := api.app.Dao().FindAdminById(id)
|
||||
if err != nil || admin == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.AdminDeleteEvent)
|
||||
event.HttpContext = c
|
||||
event.Admin = admin
|
||||
|
||||
return api.app.OnAdminBeforeDeleteRequest().Trigger(event, func(e *core.AdminDeleteEvent) error {
|
||||
if err := api.app.Dao().DeleteAdmin(e.Admin); err != nil {
|
||||
return NewBadRequestError("Failed to delete admin.", err)
|
||||
}
|
||||
|
||||
return api.app.OnAdminAfterDeleteRequest().Trigger(event, func(e *core.AdminDeleteEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
@ -1,925 +0,0 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestAdminAuthWithPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-with-password",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"identity":{"code":"validation_required","message":"Cannot be blank."},"password":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-with-password",
|
||||
Body: strings.NewReader(`{`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "wrong email",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"missing@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminBeforeAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "wrong password",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"invalid"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminBeforeAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid email/password (guest)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"admin":{"id":"sywbhecnh46rhm0"`,
|
||||
`"token":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminBeforeAuthWithPasswordRequest": 1,
|
||||
"OnAdminAfterAuthWithPasswordRequest": 1,
|
||||
"OnAdminAuthRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid email/password (already authorized)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4MTYwMH0.han3_sG65zLddpcX2ic78qgy7FKecuPfOpFa8Dvi5Bg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"admin":{"id":"sywbhecnh46rhm0"`,
|
||||
`"token":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminBeforeAuthWithPasswordRequest": 1,
|
||||
"OnAdminAfterAuthWithPasswordRequest": 1,
|
||||
"OnAdminAuthRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnAdminAfterAuthWithPasswordRequest error response",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4MTYwMH0.han3_sG65zLddpcX2ic78qgy7FKecuPfOpFa8Dvi5Bg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnAdminAfterAuthWithPasswordRequest().Add(func(e *core.AdminAuthWithPasswordEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminBeforeAuthWithPasswordRequest": 1,
|
||||
"OnAdminAfterAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminRequestPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/request-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/request-password-reset",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "missing admin",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
},
|
||||
{
|
||||
Name: "existing admin",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnMailerBeforeAdminResetPasswordSend": 1,
|
||||
"OnMailerAfterAdminResetPasswordSend": 1,
|
||||
"OnAdminBeforeRequestPasswordResetRequest": 1,
|
||||
"OnAdminAfterRequestPasswordResetRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing admin (after already sent)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
// simulate recent password request
|
||||
admin, err := app.Dao().FindAdminByEmail("test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
admin.LastResetSentAt = types.NowDateTime()
|
||||
dao := daos.New(app.Dao().DB()) // new dao to ignore hooks
|
||||
if err := dao.Save(admin); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminConfirmPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/confirm-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"password":{"code":"validation_required","message":"Cannot be blank."},"passwordConfirm":{"code":"validation_required","message":"Cannot be blank."},"token":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/confirm-password-reset",
|
||||
Body: strings.NewReader(`{"password`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "expired token",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MTY0MDk5MTY2MX0.GLwCOsgWTTEKXTK-AyGW838de1OeZGIjfHH0FoRLqZg",
|
||||
"password":"1234567890",
|
||||
"passwordConfirm":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"token":{"code":"validation_invalid_token","message":"Invalid or expired token."}}}`},
|
||||
},
|
||||
{
|
||||
Name: "valid token + invalid password",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
|
||||
"password":"123456",
|
||||
"passwordConfirm":"123456"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"password":{"code":"validation_length_out_of_range"`},
|
||||
},
|
||||
{
|
||||
Name: "valid token + valid password",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
|
||||
"password":"1234567891",
|
||||
"passwordConfirm":"1234567891"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnAdminBeforeConfirmPasswordResetRequest": 1,
|
||||
"OnAdminAfterConfirmPasswordResetRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnAdminAfterConfirmPasswordResetRequest error response",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
|
||||
"password":"1234567891",
|
||||
"passwordConfirm":"1234567891"
|
||||
}`),
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnAdminAfterConfirmPasswordResetRequest().Add(func(e *core.AdminConfirmPasswordResetEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnAdminBeforeConfirmPasswordResetRequest": 1,
|
||||
"OnAdminAfterConfirmPasswordResetRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-refresh",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-refresh",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (expired token)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-refresh",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MTY0MDk5MTY2MX0.I7w8iktkleQvC7_UIRpD7rNzcU4OnF7i7SFIUu6lD_4",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (valid token)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-refresh",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"admin":{"id":"sywbhecnh46rhm0"`,
|
||||
`"token":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminAuthRequest": 1,
|
||||
"OnAdminBeforeAuthRefreshRequest": 1,
|
||||
"OnAdminAfterAuthRefreshRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnAdminAfterAuthRefreshRequest error response",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins/auth-refresh",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnAdminAfterAuthRefreshRequest().Add(func(e *core.AdminAuthRefreshEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminBeforeAuthRefreshRequest": 1,
|
||||
"OnAdminAfterAuthRefreshRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as user",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":3`,
|
||||
`"items":[{`,
|
||||
`"id":"sywbhecnh46rhm0"`,
|
||||
`"id":"sbmbsdb40jyxf7h"`,
|
||||
`"id":"9q2trqumvlyr3bd"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + paging and sorting",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins?page=2&perPage=1&sort=-created",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":2`,
|
||||
`"perPage":1`,
|
||||
`"totalItems":3`,
|
||||
`"items":[{`,
|
||||
`"id":"sbmbsdb40jyxf7h"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"tokenKey"`,
|
||||
`"passwordHash"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + invalid filter",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins?filter=invalidfield~'test2'",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + valid filter",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins?filter=email~'test3'",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"items":[{`,
|
||||
`"id":"9q2trqumvlyr3bd"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"tokenKey"`,
|
||||
`"passwordHash"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as user",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + nonexisting admin id",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins/nonexisting",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + existing admin id",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"sbmbsdb40jyxf7h"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"tokenKey"`,
|
||||
`"passwordHash"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminViewRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as user",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + missing admin id",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/admins/missing",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + existing admin id",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeDelete": 1,
|
||||
"OnModelAfterDelete": 1,
|
||||
"OnAdminBeforeDeleteRequest": 1,
|
||||
"OnAdminAfterDeleteRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin - try to delete the only remaining admin",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/admins/sywbhecnh46rhm0",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
// delete all admins except the authorized one
|
||||
adminModel := &models.Admin{}
|
||||
_, err := app.Dao().DB().Delete(adminModel.TableName(), dbx.Not(dbx.HashExp{
|
||||
"id": "sywbhecnh46rhm0",
|
||||
})).Execute()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnAdminBeforeDeleteRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnAdminAfterDeleteRequest error response",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnAdminAfterDeleteRequest().Add(func(e *core.AdminDeleteEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeDelete": 1,
|
||||
"OnModelAfterDelete": 1,
|
||||
"OnAdminBeforeDeleteRequest": 1,
|
||||
"OnAdminAfterDeleteRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized (while having at least 1 existing admin)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "unauthorized (while having 0 existing admins)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
Body: strings.NewReader(`{"email":"testnew@example.com","password":"1234567890","passwordConfirm":"1234567890","avatar":3}`),
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
// delete all admins
|
||||
_, err := app.Dao().DB().NewQuery("DELETE FROM {{_admins}}").Execute()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":`,
|
||||
`"email":"testnew@example.com"`,
|
||||
`"avatar":3`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeCreate": 1,
|
||||
"OnModelAfterCreate": 1,
|
||||
"OnAdminBeforeCreateRequest": 1,
|
||||
"OnAdminAfterCreateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + empty data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
Body: strings.NewReader(``),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."},"password":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + invalid data format",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
Body: strings.NewReader(`{`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + invalid data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
Body: strings.NewReader(`{
|
||||
"email":"test@example.com",
|
||||
"password":"1234",
|
||||
"passwordConfirm":"4321",
|
||||
"avatar":99
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"avatar":{"code":"validation_max_less_equal_than_required"`,
|
||||
`"email":{"code":"validation_admin_email_exists"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
`"passwordConfirm":{"code":"validation_values_mismatch"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + valid data",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
Body: strings.NewReader(`{
|
||||
"email":"testnew@example.com",
|
||||
"password":"1234567890",
|
||||
"passwordConfirm":"1234567890",
|
||||
"avatar":3
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":`,
|
||||
`"email":"testnew@example.com"`,
|
||||
`"avatar":3`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"password"`,
|
||||
`"passwordConfirm"`,
|
||||
`"tokenKey"`,
|
||||
`"passwordHash"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeCreate": 1,
|
||||
"OnModelAfterCreate": 1,
|
||||
"OnAdminBeforeCreateRequest": 1,
|
||||
"OnAdminAfterCreateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnAdminAfterCreateRequest error response",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/admins",
|
||||
Body: strings.NewReader(`{
|
||||
"email":"testnew@example.com",
|
||||
"password":"1234567890",
|
||||
"passwordConfirm":"1234567890",
|
||||
"avatar":3
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnAdminAfterCreateRequest().Add(func(e *core.AdminCreateEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeCreate": 1,
|
||||
"OnModelAfterCreate": 1,
|
||||
"OnAdminBeforeCreateRequest": 1,
|
||||
"OnAdminAfterCreateRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as user",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + missing admin",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/missing",
|
||||
Body: strings.NewReader(``),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + empty data",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
Body: strings.NewReader(``),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"sbmbsdb40jyxf7h"`,
|
||||
`"email":"test2@example.com"`,
|
||||
`"avatar":2`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnAdminBeforeUpdateRequest": 1,
|
||||
"OnAdminAfterUpdateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + invalid formatted data",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
Body: strings.NewReader(`{`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + invalid data",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
Body: strings.NewReader(`{
|
||||
"email":"test@example.com",
|
||||
"password":"1234",
|
||||
"passwordConfirm":"4321",
|
||||
"avatar":99
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"avatar":{"code":"validation_max_less_equal_than_required"`,
|
||||
`"email":{"code":"validation_admin_email_exists"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
`"passwordConfirm":{"code":"validation_values_mismatch"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + valid data",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
Body: strings.NewReader(`{
|
||||
"email":"testnew@example.com",
|
||||
"password":"1234567891",
|
||||
"passwordConfirm":"1234567891",
|
||||
"avatar":5
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"sbmbsdb40jyxf7h"`,
|
||||
`"email":"testnew@example.com"`,
|
||||
`"avatar":5`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"password"`,
|
||||
`"passwordConfirm"`,
|
||||
`"tokenKey"`,
|
||||
`"passwordHash"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnAdminBeforeUpdateRequest": 1,
|
||||
"OnAdminAfterUpdateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnAdminAfterUpdateRequest error response",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/admins/sbmbsdb40jyxf7h",
|
||||
Body: strings.NewReader(`{
|
||||
"email":"testnew@example.com",
|
||||
"password":"1234567891",
|
||||
"passwordConfirm":"1234567891",
|
||||
"avatar":5
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnAdminAfterUpdateRequest().Add(func(e *core.AdminUpdateEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnAdminBeforeUpdateRequest": 1,
|
||||
"OnAdminAfterUpdateRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
)
|
||||
|
||||
// ApiError defines the struct for a basic api error response.
|
||||
type ApiError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]any `json:"data"`
|
||||
|
||||
// stores unformatted error data (could be an internal error, text, etc.)
|
||||
rawData any
|
||||
}
|
||||
|
||||
// Error makes it compatible with the `error` interface.
|
||||
func (e *ApiError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// RawData returns the unformatted error data (could be an internal error, text, etc.)
|
||||
func (e *ApiError) RawData() any {
|
||||
return e.rawData
|
||||
}
|
||||
|
||||
// NewNotFoundError creates and returns 404 `ApiError`.
|
||||
func NewNotFoundError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "The requested resource wasn't found."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusNotFound, message, data)
|
||||
}
|
||||
|
||||
// NewBadRequestError creates and returns 400 `ApiError`.
|
||||
func NewBadRequestError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Something went wrong while processing your request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusBadRequest, message, data)
|
||||
}
|
||||
|
||||
// NewForbiddenError creates and returns 403 `ApiError`.
|
||||
func NewForbiddenError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "You are not allowed to perform this request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusForbidden, message, data)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError creates and returns 401 `ApiError`.
|
||||
func NewUnauthorizedError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Missing or invalid authentication token."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusUnauthorized, message, data)
|
||||
}
|
||||
|
||||
// NewApiError creates and returns new normalized `ApiError` instance.
|
||||
func NewApiError(status int, message string, data any) *ApiError {
|
||||
return &ApiError{
|
||||
rawData: data,
|
||||
Data: safeErrorsData(data),
|
||||
Code: status,
|
||||
Message: strings.TrimSpace(inflector.Sentenize(message)),
|
||||
}
|
||||
}
|
||||
|
||||
func safeErrorsData(data any) map[string]any {
|
||||
switch v := data.(type) {
|
||||
case validation.Errors:
|
||||
return resolveSafeErrorsData[error](v)
|
||||
case map[string]validation.Error:
|
||||
return resolveSafeErrorsData[validation.Error](v)
|
||||
case map[string]error:
|
||||
return resolveSafeErrorsData[error](v)
|
||||
case map[string]any:
|
||||
return resolveSafeErrorsData[any](v)
|
||||
default:
|
||||
return map[string]any{} // not nil to ensure that is json serialized as object
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSafeErrorsData[T any](data map[string]T) map[string]any {
|
||||
result := map[string]any{}
|
||||
|
||||
for name, err := range data {
|
||||
if isNestedError(err) {
|
||||
result[name] = safeErrorsData(err)
|
||||
continue
|
||||
}
|
||||
result[name] = resolveSafeErrorItem(err)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isNestedError(err any) bool {
|
||||
switch err.(type) {
|
||||
case validation.Errors, map[string]validation.Error, map[string]error, map[string]any:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveSafeErrorItem extracts from each validation error its
|
||||
// public safe error code and message.
|
||||
func resolveSafeErrorItem(err any) map[string]string {
|
||||
// default public safe error values
|
||||
code := "validation_invalid_value"
|
||||
msg := "Invalid value."
|
||||
|
||||
// only validation errors are public safe
|
||||
if obj, ok := err.(validation.Error); ok {
|
||||
code = obj.Code()
|
||||
msg = inflector.Sentenize(obj.Error())
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
"code": code,
|
||||
"message": msg,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
package apis
|
||||
|
||||
import "github.com/pocketbase/pocketbase/tools/router"
|
||||
|
||||
// ApiError aliases to minimize the breaking changes with earlier versions
|
||||
// and for consistency with the JSVM binds.
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// NewApiError is an alias for [router.NewApiError].
|
||||
func NewApiError(status int, message string, errData any) *router.ApiError {
|
||||
return router.NewApiError(status, message, errData)
|
||||
}
|
||||
|
||||
// NewBadRequestError is an alias for [router.NewBadRequestError].
|
||||
func NewBadRequestError(message string, errData any) *router.ApiError {
|
||||
return router.NewBadRequestError(message, errData)
|
||||
}
|
||||
|
||||
// NewNotFoundError is an alias for [router.NewNotFoundError].
|
||||
func NewNotFoundError(message string, errData any) *router.ApiError {
|
||||
return router.NewNotFoundError(message, errData)
|
||||
}
|
||||
|
||||
// NewForbiddenError is an alias for [router.NewForbiddenError].
|
||||
func NewForbiddenError(message string, errData any) *router.ApiError {
|
||||
return router.NewForbiddenError(message, errData)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError is an alias for [router.NewUnauthorizedError].
|
||||
func NewUnauthorizedError(message string, errData any) *router.ApiError {
|
||||
return router.NewUnauthorizedError(message, errData)
|
||||
}
|
||||
|
||||
// NewTooManyRequestsError is an alias for [router.NewTooManyRequestsError].
|
||||
func NewTooManyRequestsError(message string, errData any) *router.ApiError {
|
||||
return router.NewTooManyRequestsError(message, errData)
|
||||
}
|
||||
|
||||
// NewInternalServerError is an alias for [router.NewInternalServerError].
|
||||
func NewInternalServerError(message string, errData any) *router.ApiError {
|
||||
return router.NewInternalServerError(message, errData)
|
||||
}
|
||||
|
|
@ -1,162 +0,0 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
)
|
||||
|
||||
func TestNewApiErrorWithRawData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := apis.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
"rawData_test",
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"code":300,"message":"Message_test.","data":{}}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() != "rawData_test" {
|
||||
t.Errorf("Expected rawData %v, got %v", "rawData_test", e.RawData())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewApiErrorWithValidationData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := apis.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
validation.Errors{
|
||||
"err1": errors.New("test error"), // should be normalized
|
||||
"err2": validation.ErrRequired,
|
||||
"err3": validation.Errors{
|
||||
"sub1": errors.New("test error"), // should be normalized
|
||||
"sub2": validation.ErrRequired,
|
||||
"sub3": validation.Errors{
|
||||
"sub11": validation.ErrRequired,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"code":300,"message":"Message_test.","data":{"err1":{"code":"validation_invalid_value","message":"Invalid value."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"sub1":{"code":"validation_invalid_value","message":"Invalid value."},"sub2":{"code":"validation_required","message":"Cannot be blank."},"sub3":{"sub11":{"code":"validation_required","message":"Cannot be blank."}}}}}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected \n%v, \ngot \n%v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() == nil {
|
||||
t.Error("Expected non-nil rawData")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":404,"message":"The requested resource wasn't found.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":404,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":404,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := apis.NewNotFoundError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadRequestError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":400,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":400,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := apis.NewBadRequestError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForbiddenError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":403,"message":"You are not allowed to perform this request.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":403,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":403,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := apis.NewForbiddenError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUnauthorizedError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":401,"message":"Missing or invalid authentication token.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":401,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":401,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := apis.NewUnauthorizedError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
205
apis/backup.go
205
apis/backup.go
|
|
@ -6,42 +6,37 @@ import (
|
|||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// bindBackupApi registers the file api endpoints and the corresponding handlers.
|
||||
//
|
||||
// @todo add hooks once the app hooks api restructuring is finalized
|
||||
func bindBackupApi(app core.App, rg *echo.Group) {
|
||||
api := backupApi{app: app}
|
||||
|
||||
subGroup := rg.Group("/backups", ActivityLogger(app))
|
||||
subGroup.GET("", api.list, RequireAdminAuth())
|
||||
subGroup.POST("", api.create, RequireAdminAuth())
|
||||
subGroup.POST("/upload", api.upload, RequireAdminAuth())
|
||||
subGroup.GET("/:key", api.download)
|
||||
subGroup.DELETE("/:key", api.delete, RequireAdminAuth())
|
||||
subGroup.POST("/:key/restore", api.restore, RequireAdminAuth())
|
||||
func bindBackupApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/backups")
|
||||
sub.GET("", backupsList).Bind(RequireSuperuserAuth())
|
||||
sub.POST("", backupCreate).Bind(RequireSuperuserAuth())
|
||||
sub.POST("/upload", backupUpload).Bind(RequireSuperuserAuthOnlyIfAny())
|
||||
sub.GET("/{key}", backupDownload) // relies on superuser file token
|
||||
sub.DELETE("/{key}", backupDelete).Bind(RequireSuperuserAuth())
|
||||
sub.POST("/{key}/restore", backupRestore).Bind(RequireSuperuserAuthOnlyIfAny())
|
||||
}
|
||||
|
||||
type backupApi struct {
|
||||
app core.App
|
||||
type backupFileInfo struct {
|
||||
Modified types.DateTime `json:"modified"`
|
||||
Key string `json:"key"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func (api *backupApi) list(c echo.Context) error {
|
||||
func backupsList(e *core.RequestEvent) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := api.app.NewBackupsFilesystem()
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to load backups filesystem.", err)
|
||||
return e.BadRequestError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
|
|
@ -49,166 +44,112 @@ func (api *backupApi) list(c echo.Context) error {
|
|||
|
||||
backups, err := fsys.List("")
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
|
||||
return e.BadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
result := make([]models.BackupFileInfo, len(backups))
|
||||
result := make([]backupFileInfo, len(backups))
|
||||
|
||||
for i, obj := range backups {
|
||||
modified, _ := types.ParseDateTime(obj.ModTime)
|
||||
|
||||
result[i] = models.BackupFileInfo{
|
||||
result[i] = backupFileInfo{
|
||||
Key: obj.Key,
|
||||
Size: obj.Size,
|
||||
Modified: modified,
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (api *backupApi) create(c echo.Context) error {
|
||||
if api.app.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return NewBadRequestError("Try again later - another backup/restore process has already been started", nil)
|
||||
}
|
||||
func backupDownload(e *core.RequestEvent) error {
|
||||
fileToken := e.Request.URL.Query().Get("token")
|
||||
|
||||
form := forms.NewBackupCreate(api.app)
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[string]) forms.InterceptorNextFunc[string] {
|
||||
return func(name string) error {
|
||||
if err := next(name); err != nil {
|
||||
return NewBadRequestError("Failed to create backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (api *backupApi) upload(c echo.Context) error {
|
||||
files, err := rest.FindUploadedFiles(c.Request(), "file")
|
||||
if err != nil {
|
||||
return NewBadRequestError("Missing or invalid uploaded file.", err)
|
||||
}
|
||||
|
||||
form := forms.NewBackupUpload(api.app)
|
||||
form.File = files[0]
|
||||
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[*filesystem.File]) forms.InterceptorNextFunc[*filesystem.File] {
|
||||
return func(file *filesystem.File) error {
|
||||
if err := next(file); err != nil {
|
||||
return NewBadRequestError("Failed to upload backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (api *backupApi) download(c echo.Context) error {
|
||||
fileToken := c.QueryParam("token")
|
||||
|
||||
_, err := api.app.Dao().FindAdminByToken(
|
||||
fileToken,
|
||||
api.app.Settings().AdminFileToken.Secret,
|
||||
)
|
||||
if err != nil {
|
||||
return NewForbiddenError("Insufficient permissions to access the resource.", err)
|
||||
authRecord, err := e.App.FindAuthRecordByToken(fileToken, core.TokenTypeFile)
|
||||
if err != nil || !authRecord.IsSuperuser() {
|
||||
return e.ForbiddenError("Insufficient permissions to access the resource.", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := api.app.NewBackupsFilesystem()
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to load backups filesystem.", err)
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
key := c.PathParam("key")
|
||||
|
||||
br, err := fsys.GetFile(key)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to retrieve backup item. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
defer br.Close()
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
return fsys.Serve(
|
||||
c.Response(),
|
||||
c.Request(),
|
||||
e.Response,
|
||||
e.Request,
|
||||
key,
|
||||
filepath.Base(key), // without the path prefix (if any)
|
||||
)
|
||||
}
|
||||
|
||||
func (api *backupApi) restore(c echo.Context) error {
|
||||
if api.app.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return NewBadRequestError("Try again later - another backup/restore process has already been started.", nil)
|
||||
func backupDelete(e *core.RequestEvent) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
if key != "" && cast.ToString(e.App.Store().Get(core.StoreKeyActiveBackup)) == key {
|
||||
return e.BadRequestError("The backup is currently being used and cannot be deleted.", nil)
|
||||
}
|
||||
|
||||
key := c.PathParam("key")
|
||||
if err := fsys.Delete(key); err != nil {
|
||||
return e.BadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func backupRestore(e *core.RequestEvent) error {
|
||||
if e.App.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return e.BadRequestError("Try again later - another backup/restore process has already been started.", nil)
|
||||
}
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
existsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := api.app.NewBackupsFilesystem()
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to load backups filesystem.", err)
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(existsCtx)
|
||||
|
||||
if exists, err := fsys.Exists(key); !exists {
|
||||
return NewBadRequestError("Missing or invalid backup file.", err)
|
||||
return e.BadRequestError("Missing or invalid backup file.", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
// wait max 15 minutes to fetch the backup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// give some optimistic time to write the response
|
||||
routine.FireAndForget(func() {
|
||||
// give some optimistic time to write the response before restarting the app
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if err := api.app.RestoreBackup(ctx, key); err != nil {
|
||||
api.app.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
|
||||
// wait max 10 minutes to fetch the backup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := e.App.RestoreBackup(ctx, key); err != nil {
|
||||
e.App.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *backupApi) delete(c echo.Context) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := api.app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
key := c.PathParam("key")
|
||||
|
||||
if key != "" && cast.ToString(api.app.Store().Get(core.StoreKeyActiveBackup)) == key {
|
||||
return NewBadRequestError("The backup is currently being used and cannot be deleted.", nil)
|
||||
}
|
||||
|
||||
if err := fsys.Delete(key); err != nil {
|
||||
return NewBadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func backupCreate(e *core.RequestEvent) error {
|
||||
if e.App.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return e.BadRequestError("Try again later - another backup/restore process has already been started", nil)
|
||||
}
|
||||
|
||||
form := new(backupCreateForm)
|
||||
form.app = e.App
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
err = e.App.CreateBackup(context.Background(), form.Name)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to create backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var backupNameRegex = regexp.MustCompile(`^[a-z0-9_-]+\.zip$`)
|
||||
|
||||
type backupCreateForm struct {
|
||||
app core.App
|
||||
|
||||
Name string `form:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (form *backupCreateForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(
|
||||
&form.Name,
|
||||
validation.Length(1, 150),
|
||||
validation.Match(backupNameRegex),
|
||||
validation.By(form.checkUniqueName),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *backupCreateForm) checkUniqueName(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
fsys, err := form.app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
if exists, err := fsys.Exists(v); err != nil || exists {
|
||||
return validation.NewError("validation_backup_name_exists", "The backup file name is invalid or already exists.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"gocloud.dev/blob"
|
||||
|
|
@ -23,50 +22,51 @@ func TestBackupsList(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (empty list)",
|
||||
Name: "authorized as superuser (empty list)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`[]`,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`[]`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin",
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -77,6 +77,7 @@ func TestBackupsList(t *testing.T) {
|
|||
`"test2.zip"`,
|
||||
`"test3.zip"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -92,50 +93,53 @@ func TestBackupsCreate(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups",
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
URL: "/api/backups",
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (pending backup)",
|
||||
Name: "authorized as superuser (pending backup)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (autogenerated name)",
|
||||
Name: "authorized as superuser (autogenerated name)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -151,16 +155,20 @@ func TestBackupsCreate(t *testing.T) {
|
|||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBackupCreate": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (invalid name)",
|
||||
Name: "authorized as superuser (invalid name)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups",
|
||||
URL: "/api/backups",
|
||||
Body: strings.NewReader(`{"name":"!test.zip"}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
|
|
@ -168,16 +176,17 @@ func TestBackupsCreate(t *testing.T) {
|
|||
`"data":{`,
|
||||
`"name":{"code":"validation_match_invalid"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (valid name)",
|
||||
Name: "authorized as superuser (valid name)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups",
|
||||
URL: "/api/backups",
|
||||
Body: strings.NewReader(`{"name":"test.zip"}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -193,6 +202,10 @@ func TestBackupsCreate(t *testing.T) {
|
|||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBackupCreate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -201,7 +214,7 @@ func TestBackupsCreate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBackupsUpload(t *testing.T) {
|
||||
func TestBackupUpload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// create dummy form data bodies
|
||||
|
|
@ -243,55 +256,58 @@ func TestBackupsUpload(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/upload",
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[0].buffer,
|
||||
RequestHeaders: map[string]string{
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[0].contentType,
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/upload",
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[1].buffer,
|
||||
RequestHeaders: map[string]string{
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[1].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (missing file)",
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/upload",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/upload",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (existing backup name)",
|
||||
Name: "authorized as superuser (existing backup name)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/upload",
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[3].buffer,
|
||||
RequestHeaders: map[string]string{
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[3].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -302,7 +318,7 @@ func TestBackupsUpload(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, _ := getBackupFiles(app)
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected %d backup file, got %d", 1, total)
|
||||
|
|
@ -310,23 +326,49 @@ func TestBackupsUpload(t *testing.T) {
|
|||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"file":{`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (valid file)",
|
||||
Name: "authorized as superuser (valid file)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/upload",
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[4].buffer,
|
||||
RequestHeaders: map[string]string{
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[4].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, _ := getBackupFiles(app)
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected %d backup file, got %d", 1, total)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unauthorized with 0 superusers (valid file)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[5].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[5].contentType,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// delete all superusers
|
||||
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, _ := getBackupFiles(app)
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected %d backup file, got %d", 1, total)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -342,148 +384,159 @@ func TestBackupsDownload(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with record auth header",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with admin auth header",
|
||||
Name: "with superuser auth header",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with empty or invalid token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip?token=",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip?token=",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid record auth token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid record file token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid admin auth token",
|
||||
Name: "with valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with expired admin file token",
|
||||
Name: "with expired superuser file token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImFkbWluIn0.g7Q_3UX6H--JWJ7yt1Hoe-1ugTX1KpbKzdt0zjGSe-E",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.hTNDzikwJdcoWrLnRnp7xbaifZ2vuYZ0oOYRHtJfnk4",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid admin file token but missing backup name",
|
||||
Name: "with valid superuser file token but missing backup name",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid admin file token",
|
||||
Name: "with valid superuser file token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`storage/`,
|
||||
`data.db`,
|
||||
`logs.db`,
|
||||
"storage/",
|
||||
"data.db",
|
||||
"aux.db",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid admin file token and backup name with escaped char",
|
||||
Name: "with valid superuser file token and backup name with escaped char",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`storage/`,
|
||||
`data.db`,
|
||||
`logs.db`,
|
||||
"storage/",
|
||||
"data.db",
|
||||
"aux.db",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -495,7 +548,7 @@ func TestBackupsDownload(t *testing.T) {
|
|||
func TestBackupsDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
noTestBackupFilesChanges := func(t *testing.T, app *tests.TestApp) {
|
||||
noTestBackupFilesChanges := func(t testing.TB, app *tests.TestApp) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -511,62 +564,65 @@ func TestBackupsDelete(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/backups/test1.zip",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (missing file)",
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/backups/missing.zip",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/missing.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (existing file with matching active backup)",
|
||||
Name: "authorized as superuser (existing file with matching active backup)",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/backups/test1.zip",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -574,20 +630,21 @@ func TestBackupsDelete(t *testing.T) {
|
|||
// mock active backup with the same name to delete
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "test1.zip")
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (existing file and no matching active backup)",
|
||||
Name: "authorized as superuser (existing file and no matching active backup)",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/backups/test1.zip",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -595,7 +652,7 @@ func TestBackupsDelete(t *testing.T) {
|
|||
// mock active backup with different name
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "new.zip")
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -614,20 +671,21 @@ func TestBackupsDelete(t *testing.T) {
|
|||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (backup with escaped character)",
|
||||
Name: "authorized as superuser (backup with escaped character)",
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/backups/%40test4.zip",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/%40test4.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -646,6 +704,7 @@ func TestBackupsDelete(t *testing.T) {
|
|||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -661,53 +720,56 @@ func TestBackupsRestore(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/test1.zip/restore",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/test1.zip/restore",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (missing file)",
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/missing.zip/restore",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/missing.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (active backup process)",
|
||||
Name: "authorized as superuser (active backup process)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/backups/test1.zip/restore",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -716,6 +778,26 @@ func TestBackupsRestore(t *testing.T) {
|
|||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unauthorized with no superusers (checks only access)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/missing.zip/restore",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// delete all superusers
|
||||
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -758,7 +840,7 @@ func getBackupFiles(app core.App) ([]*blob.ListObject, error) {
|
|||
return fsys.List("")
|
||||
}
|
||||
|
||||
func ensureNoBackups(t *testing.T, app *tests.TestApp) {
|
||||
func ensureNoBackups(t testing.TB, app *tests.TestApp) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
)
|
||||
|
||||
func backupUpload(e *core.RequestEvent) error {
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
form := new(backupUploadForm)
|
||||
form.fsys = fsys
|
||||
files, _ := FindUploadedFiles(e.Request, "file")
|
||||
if len(files) > 0 {
|
||||
form.File = files[0]
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
err = fsys.UploadFile(form.File, form.File.OriginalName)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to upload backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type backupUploadForm struct {
|
||||
fsys *filesystem.System
|
||||
|
||||
File *filesystem.File `json:"file"`
|
||||
}
|
||||
|
||||
func (form *backupUploadForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(
|
||||
&form.File,
|
||||
validation.Required,
|
||||
validation.By(validators.UploadedFileMimeType([]string{"application/zip"})),
|
||||
validation.By(form.checkUniqueName),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *backupUploadForm) checkUniqueName(value any) error {
|
||||
v, _ := value.(*filesystem.File)
|
||||
if v == nil {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
// note: we use the original name because that is what we upload
|
||||
if exists, err := form.fsys.Exists(v.OriginalName); err != nil || exists {
|
||||
return validation.NewError("validation_backup_name_exists", "Backup file with the specified name already exists.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
376
apis/base.go
376
apis/base.go
|
|
@ -1,266 +1,202 @@
|
|||
// Package apis implements the default PocketBase api services and middlewares.
|
||||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
"github.com/pocketbase/pocketbase/ui"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
const trailedAdminPath = "/_/"
|
||||
// StaticWildcardParam is the name of Static handler wildcard parameter.
|
||||
const StaticWildcardParam = "path"
|
||||
|
||||
// InitApi creates a configured echo instance with registered
|
||||
// system and app specific routes and middlewares.
|
||||
func InitApi(app core.App) (*echo.Echo, error) {
|
||||
e := echo.New()
|
||||
e.Debug = false
|
||||
e.Binder = &rest.MultiBinder{}
|
||||
e.JSONSerializer = &rest.Serializer{
|
||||
FieldsParam: fieldsQueryParam,
|
||||
}
|
||||
// NewRouter returns a new router instance loaded with the default app middlewares and api routes.
|
||||
func NewRouter(app core.App) (*router.Router[*core.RequestEvent], error) {
|
||||
pbRouter := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*core.RequestEvent, router.EventCleanupFunc) {
|
||||
event := new(core.RequestEvent)
|
||||
event.Response = w
|
||||
event.Request = r
|
||||
event.App = app
|
||||
|
||||
// configure a custom router
|
||||
e.ResetRouterCreator(func(ec *echo.Echo) echo.Router {
|
||||
return echo.NewRouter(echo.RouterConfig{
|
||||
UnescapePathParamValues: true,
|
||||
AllowOverwritingRoute: true,
|
||||
})
|
||||
return event, nil
|
||||
})
|
||||
|
||||
// default middlewares
|
||||
e.Pre(middleware.RemoveTrailingSlashWithConfig(middleware.RemoveTrailingSlashConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
// enable by default only for the API routes
|
||||
return !strings.HasPrefix(c.Request().URL.Path, "/api/")
|
||||
},
|
||||
}))
|
||||
e.Pre(LoadAuthContext(app))
|
||||
e.Use(middleware.Recover())
|
||||
e.Use(middleware.Secure())
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set(ContextExecStartKey, time.Now())
|
||||
// register default middlewares
|
||||
pbRouter.Bind(activityLogger())
|
||||
pbRouter.Bind(loadAuthToken())
|
||||
pbRouter.Bind(securityHeaders())
|
||||
pbRouter.Bind(rateLimit())
|
||||
pbRouter.Bind(BodyLimit(DefaultMaxBodySize))
|
||||
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
apiGroup := pbRouter.Group("/api")
|
||||
bindSettingsApi(app, apiGroup)
|
||||
bindCollectionApi(app, apiGroup)
|
||||
bindRecordCrudApi(app, apiGroup)
|
||||
bindRecordAuthApi(app, apiGroup)
|
||||
bindLogsApi(app, apiGroup)
|
||||
bindBackupApi(app, apiGroup)
|
||||
bindFileApi(app, apiGroup)
|
||||
bindBatchApi(app, apiGroup)
|
||||
bindRealtimeApi(app, apiGroup)
|
||||
bindHealthApi(app, apiGroup)
|
||||
|
||||
// custom error handler
|
||||
e.HTTPErrorHandler = func(c echo.Context, err error) {
|
||||
if err == nil {
|
||||
return // no error
|
||||
}
|
||||
|
||||
var apiErr *ApiError
|
||||
|
||||
if errors.As(err, &apiErr) {
|
||||
// already an api error...
|
||||
} else if v := new(echo.HTTPError); errors.As(err, &v) {
|
||||
msg := fmt.Sprintf("%v", v.Message)
|
||||
apiErr = NewApiError(v.Code, msg, v)
|
||||
} else {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
apiErr = NewNotFoundError("", err)
|
||||
} else {
|
||||
apiErr = NewBadRequestError("", err)
|
||||
}
|
||||
}
|
||||
|
||||
logRequest(app, c, apiErr)
|
||||
|
||||
if c.Response().Committed {
|
||||
return // already committed
|
||||
}
|
||||
|
||||
event := new(core.ApiErrorEvent)
|
||||
event.HttpContext = c
|
||||
event.Error = apiErr
|
||||
|
||||
// send error response
|
||||
hookErr := app.OnBeforeApiError().Trigger(event, func(e *core.ApiErrorEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// @see https://github.com/labstack/echo/issues/608
|
||||
if e.HttpContext.Request().Method == http.MethodHead {
|
||||
return e.HttpContext.NoContent(apiErr.Code)
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(apiErr.Code, apiErr)
|
||||
})
|
||||
|
||||
if hookErr == nil {
|
||||
if err := app.OnAfterApiError().Trigger(event); err != nil {
|
||||
app.Logger().Debug("OnAfterApiError failure", slog.String("error", err.Error()))
|
||||
}
|
||||
} else {
|
||||
app.Logger().Debug("OnBeforeApiError error (truly rare case, eg. client already disconnected)", slog.String("error", hookErr.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// admin ui routes
|
||||
bindStaticAdminUI(app, e)
|
||||
|
||||
// default routes
|
||||
api := e.Group("/api", eagerRequestInfoCache(app))
|
||||
bindSettingsApi(app, api)
|
||||
bindAdminApi(app, api)
|
||||
bindCollectionApi(app, api)
|
||||
bindRecordCrudApi(app, api)
|
||||
bindRecordAuthApi(app, api)
|
||||
bindFileApi(app, api)
|
||||
bindRealtimeApi(app, api)
|
||||
bindLogsApi(app, api)
|
||||
bindHealthApi(app, api)
|
||||
bindBackupApi(app, api)
|
||||
|
||||
// catch all any route
|
||||
api.Any("/*", func(c echo.Context) error {
|
||||
return echo.ErrNotFound
|
||||
}, ActivityLogger(app))
|
||||
|
||||
return e, nil
|
||||
return pbRouter, nil
|
||||
}
|
||||
|
||||
// StaticDirectoryHandler is similar to `echo.StaticDirectoryHandler`
|
||||
// but without the directory redirect which conflicts with RemoveTrailingSlash middleware.
|
||||
// WrapStdHandler wraps Go [http.Handler] into a PocketBase handler func.
|
||||
func WrapStdHandler(h http.Handler) hook.HandlerFunc[*core.RequestEvent] {
|
||||
return func(e *core.RequestEvent) error {
|
||||
h.ServeHTTP(e.Response, e.Request)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WrapStdMiddleware wraps Go [func(http.Handler) http.Handle] into a PocketBase middleware func.
|
||||
func WrapStdMiddleware(m func(http.Handler) http.Handler) hook.HandlerFunc[*core.RequestEvent] {
|
||||
return func(e *core.RequestEvent) (err error) {
|
||||
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
e.Response = w
|
||||
e.Request = r
|
||||
err = e.Next()
|
||||
})).ServeHTTP(e.Response, e.Request)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// MustSubFS returns an [fs.FS] corresponding to the subtree rooted at fsys's dir.
|
||||
//
|
||||
// This is similar to [fs.Sub] but panics on failure.
|
||||
func MustSubFS(fsys fs.FS, dir string) fs.FS {
|
||||
dir = filepath.ToSlash(filepath.Clean(dir)) // ToSlash in case of Windows path
|
||||
|
||||
sub, err := fs.Sub(fsys, dir)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to create sub FS: %w", err))
|
||||
}
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
// Static is a handler function to serve static directory content from fsys.
|
||||
//
|
||||
// If a file resource is missing and indexFallback is set, the request
|
||||
// will be forwarded to the base index.html (useful also for SPA).
|
||||
// will be forwarded to the base index.html (useful for SPA with pretty urls).
|
||||
//
|
||||
// @see https://github.com/labstack/echo/issues/2211
|
||||
func StaticDirectoryHandler(fileSystem fs.FS, indexFallback bool) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
p := c.PathParam("*")
|
||||
// NB! Expects the route to have a "{path...}" wildcard parameter.
|
||||
//
|
||||
// Special redirects:
|
||||
// - if "path" is a file that ends in index.html, it is redirected to its non-index.html version (eg. /test/index.html -> /test/)
|
||||
// - if "path" is a directory that has index.html, the index.html file is rendered,
|
||||
// otherwise if missing - returns 404 or fallback to the root index.html if indexFallback is set
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// fsys := os.DirFS("./pb_public")
|
||||
// router.GET("/files/{path...}", apis.Static(fsys, false))
|
||||
func Static(fsys fs.FS, indexFallback bool) hook.HandlerFunc[*core.RequestEvent] {
|
||||
if fsys == nil {
|
||||
panic("Static: the provided fs.FS argument is nil")
|
||||
}
|
||||
|
||||
// escape url path
|
||||
tmpPath, err := url.PathUnescape(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unescape path variable: %w", err)
|
||||
return func(e *core.RequestEvent) error {
|
||||
// disable the activity logger to avoid flooding with messages
|
||||
//
|
||||
// note: errors are still logged
|
||||
if e.Get(requestEventKeySkipSuccessActivityLog) == nil {
|
||||
e.Set(requestEventKeySkipSuccessActivityLog, true)
|
||||
}
|
||||
p = tmpPath
|
||||
|
||||
// fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
|
||||
name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/")))
|
||||
filename := e.Request.PathValue(StaticWildcardParam)
|
||||
filename = filepath.ToSlash(filepath.Clean(strings.TrimPrefix(filename, "/")))
|
||||
|
||||
fileErr := c.FileFS(name, fileSystem)
|
||||
// eagerly check for directory traversal
|
||||
//
|
||||
// note: this is just out of an abundance of caution because the fs.FS implementation could be non-std,
|
||||
// but usually shouldn't be necessary since os.DirFS.Open is expected to fail if the filename starts with dots
|
||||
if len(filename) > 2 && filename[0] == '.' && filename[1] == '.' && (filename[2] == '/' || filename[2] == '\\') {
|
||||
if indexFallback && filename != router.IndexPage {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
return router.ErrFileNotFound
|
||||
}
|
||||
|
||||
if fileErr != nil && indexFallback && errors.Is(fileErr, echo.ErrNotFound) {
|
||||
return c.FileFS("index.html", fileSystem)
|
||||
fi, err := fs.Stat(fsys, filename)
|
||||
if err != nil {
|
||||
if indexFallback && filename != router.IndexPage {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
return router.ErrFileNotFound
|
||||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
// redirect to a canonical dir url, aka. with trailing slash
|
||||
if !strings.HasSuffix(e.Request.URL.Path, "/") {
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(e.Request.URL.Path+"/"))
|
||||
}
|
||||
} else {
|
||||
urlPath := e.Request.URL.Path
|
||||
if strings.HasSuffix(urlPath, "/") {
|
||||
// redirect to a non-trailing slash file route
|
||||
urlPath = strings.TrimRight(urlPath, "/")
|
||||
if len(urlPath) > 0 {
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(urlPath))
|
||||
}
|
||||
} else if stripped, ok := strings.CutSuffix(urlPath, router.IndexPage); ok {
|
||||
// redirect without the index.html
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(stripped))
|
||||
}
|
||||
}
|
||||
|
||||
fileErr := e.FileFS(fsys, filename)
|
||||
|
||||
if fileErr != nil && indexFallback && filename != router.IndexPage && errors.Is(fileErr, router.ErrFileNotFound) {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
|
||||
return fileErr
|
||||
}
|
||||
}
|
||||
|
||||
// bindStaticAdminUI registers the endpoints that serves the static admin UI.
|
||||
func bindStaticAdminUI(app core.App, e *echo.Echo) error {
|
||||
// redirect to trailing slash to ensure that relative urls will still work properly
|
||||
e.GET(
|
||||
strings.TrimRight(trailedAdminPath, "/"),
|
||||
func(c echo.Context) error {
|
||||
return c.Redirect(http.StatusTemporaryRedirect, strings.TrimLeft(trailedAdminPath, "/"))
|
||||
},
|
||||
)
|
||||
|
||||
// serves static files from the /ui/dist directory
|
||||
// (similar to echo.StaticFS but with gzip middleware enabled)
|
||||
e.GET(
|
||||
trailedAdminPath+"*",
|
||||
echo.StaticDirectoryHandler(ui.DistDirFS, false),
|
||||
installerRedirect(app),
|
||||
uiCacheControl(),
|
||||
middleware.Gzip(),
|
||||
)
|
||||
|
||||
return nil
|
||||
// safeRedirectPath normalizes the path string by replacing all beginning slashes
|
||||
// (`\\`, `//`, `\/`) with a single forward slash to prevent open redirect attacks
|
||||
func safeRedirectPath(path string) string {
|
||||
if len(path) > 1 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
|
||||
path = "/" + strings.TrimLeft(path, `/\`)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func uiCacheControl() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// add default Cache-Control header for all Admin UI resources
|
||||
// (ignoring the root admin path)
|
||||
if c.Request().URL.Path != trailedAdminPath {
|
||||
c.Response().Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
|
||||
}
|
||||
|
||||
return next(c)
|
||||
// FindUploadedFiles extracts all form files of "key" from a http request
|
||||
// and returns a slice with filesystem.File instances (if any).
|
||||
func FindUploadedFiles(r *http.Request, key string) ([]*filesystem.File, error) {
|
||||
if r.MultipartForm == nil {
|
||||
err := r.ParseMultipartForm(router.DefaultMaxMemory)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const hasAdminsCacheKey = "@hasAdmins"
|
||||
|
||||
func updateHasAdminsCache(app core.App) error {
|
||||
total, err := app.Dao().TotalAdmins()
|
||||
if err != nil {
|
||||
return err
|
||||
if r.MultipartForm == nil || r.MultipartForm.File == nil || len(r.MultipartForm.File[key]) == 0 {
|
||||
return nil, http.ErrMissingFile
|
||||
}
|
||||
|
||||
app.Store().Set(hasAdminsCacheKey, total > 0)
|
||||
result := make([]*filesystem.File, 0, len(r.MultipartForm.File[key]))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// installerRedirect redirects the user to the installer admin UI page
|
||||
// when the application needs some preliminary configurations to be done.
|
||||
func installerRedirect(app core.App) echo.MiddlewareFunc {
|
||||
// keep hasAdminsCacheKey value up-to-date
|
||||
app.OnAdminAfterCreateRequest().Add(func(data *core.AdminCreateEvent) error {
|
||||
return updateHasAdminsCache(app)
|
||||
})
|
||||
|
||||
app.OnAdminAfterDeleteRequest().Add(func(data *core.AdminDeleteEvent) error {
|
||||
return updateHasAdminsCache(app)
|
||||
})
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// skip redirect checks for non-root level index.html requests
|
||||
path := c.Request().URL.Path
|
||||
if path != trailedAdminPath && path != trailedAdminPath+"index.html" {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
hasAdmins := cast.ToBool(app.Store().Get(hasAdminsCacheKey))
|
||||
|
||||
if !hasAdmins {
|
||||
// update the cache to make sure that the admin wasn't created by another process
|
||||
if err := updateHasAdminsCache(app); err != nil {
|
||||
return err
|
||||
}
|
||||
hasAdmins = cast.ToBool(app.Store().Get(hasAdminsCacheKey))
|
||||
}
|
||||
|
||||
_, hasInstallerParam := c.Request().URL.Query()["installer"]
|
||||
|
||||
if !hasAdmins && !hasInstallerParam {
|
||||
// redirect to the installer page
|
||||
return c.Redirect(http.StatusTemporaryRedirect, "?installer#")
|
||||
}
|
||||
|
||||
if hasAdmins && hasInstallerParam {
|
||||
// clear the installer param
|
||||
return c.Redirect(http.StatusTemporaryRedirect, "?")
|
||||
}
|
||||
|
||||
return next(c)
|
||||
for _, fh := range r.MultipartForm.File[key] {
|
||||
file, err := filesystem.NewFileFromMultipart(fh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, file)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,422 +1,386 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func Test404(t *testing.T) {
|
||||
func TestWrapStdHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/missing",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/missing",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/missing",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Method: http.MethodDelete,
|
||||
Url: "/api/missing",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Method: http.MethodHead,
|
||||
Url: "/api/missing",
|
||||
ExpectedStatus: 404,
|
||||
},
|
||||
}
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
func TestCustomRoutesAndErrorsHandling(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "custom route",
|
||||
Method: http.MethodGet,
|
||||
Url: "/custom",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/custom",
|
||||
Handler: func(c echo.Context) error {
|
||||
return c.String(200, "test123")
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "custom route with url encoded parameter",
|
||||
Method: http.MethodGet,
|
||||
Url: "/a%2Bb%2Bc",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/:param",
|
||||
Handler: func(c echo.Context) error {
|
||||
return c.String(200, c.PathParam("param"))
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"a+b+c"},
|
||||
},
|
||||
{
|
||||
Name: "route with HTTPError",
|
||||
Method: http.MethodGet,
|
||||
Url: "/http-error",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/http-error",
|
||||
Handler: func(c echo.Context) error {
|
||||
return echo.ErrBadRequest
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`{"code":400,"message":"Bad Request.","data":{}}`},
|
||||
},
|
||||
{
|
||||
Name: "route with api error",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api-error",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/api-error",
|
||||
Handler: func(c echo.Context) error {
|
||||
return apis.NewApiError(500, "test message", errors.New("internal_test"))
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 500,
|
||||
ExpectedContent: []string{`{"code":500,"message":"Test message.","data":{}}`},
|
||||
},
|
||||
{
|
||||
Name: "route with plain error",
|
||||
Method: http.MethodGet,
|
||||
Url: "/plain-error",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/plain-error",
|
||||
Handler: func(c echo.Context) error {
|
||||
return errors.New("Test error")
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveTrailingSlashMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "non /api/* route (exact match)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/custom",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/custom",
|
||||
Handler: func(c echo.Context) error {
|
||||
return c.String(200, "test123")
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "non /api/* route (with trailing slash)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/custom/",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/custom",
|
||||
Handler: func(c echo.Context) error {
|
||||
return c.String(200, "test123")
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "/api/* route (exact match)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/custom",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/custom",
|
||||
Handler: func(c echo.Context) error {
|
||||
return c.String(200, "test123")
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "/api/* route (with trailing slash)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/custom/",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/custom",
|
||||
Handler: func(c echo.Context) error {
|
||||
return c.String(200, "test123")
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiBinder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rawJson := `{"name":"test123"}`
|
||||
|
||||
formData, mp, err := tests.MockMultipartData(map[string]string{
|
||||
rest.MultipartJsonKey: rawJson,
|
||||
})
|
||||
err := apis.WrapStdHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}))(e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "non-api group route",
|
||||
Method: "POST",
|
||||
Url: "/custom",
|
||||
Body: strings.NewReader(rawJson),
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: "POST",
|
||||
Path: "/custom",
|
||||
Handler: func(c echo.Context) error {
|
||||
data := &struct {
|
||||
Name string `json:"name"`
|
||||
}{}
|
||||
|
||||
if err := c.Bind(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// try to read the body again
|
||||
r := apis.RequestInfo(c)
|
||||
if v := cast.ToString(r.Data["name"]); v != "test123" {
|
||||
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
|
||||
}
|
||||
|
||||
return c.NoContent(200)
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
},
|
||||
{
|
||||
Name: "api group route",
|
||||
Method: "GET",
|
||||
Url: "/api/admins",
|
||||
Body: strings.NewReader(rawJson),
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// it is not important whether the route handler return an error since
|
||||
// we just need to ensure that the eagerRequestInfoCache was registered
|
||||
next(c)
|
||||
|
||||
// ensure that the body was read at least once
|
||||
data := &struct {
|
||||
Name string `json:"name"`
|
||||
}{}
|
||||
c.Bind(data)
|
||||
|
||||
// try to read the body again
|
||||
r := apis.RequestInfo(c)
|
||||
if v := cast.ToString(r.Data["name"]); v != "test123" {
|
||||
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
},
|
||||
{
|
||||
Name: "custom route with @jsonPayload as multipart body",
|
||||
Method: "POST",
|
||||
Url: "/custom",
|
||||
Body: formData,
|
||||
RequestHeaders: map[string]string{
|
||||
"Content-Type": mp.FormDataContentType(),
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.AddRoute(echo.Route{
|
||||
Method: "POST",
|
||||
Path: "/custom",
|
||||
Handler: func(c echo.Context) error {
|
||||
data := &struct {
|
||||
Name string `json:"name"`
|
||||
}{}
|
||||
|
||||
if err := c.Bind(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// try to read the body again
|
||||
r := apis.RequestInfo(c)
|
||||
if v := cast.ToString(r.Data["name"]); v != "test123" {
|
||||
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
|
||||
}
|
||||
|
||||
return c.NoContent(200)
|
||||
},
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
if body := rec.Body.String(); body != "test" {
|
||||
t.Fatalf("Expected body %q, got %q", "test", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler(t *testing.T) {
|
||||
func TestWrapStdMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.WrapStdMiddleware(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
})
|
||||
})(e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if body := rec.Body.String(); body != "test" {
|
||||
t.Fatalf("Expected body %q, got %q", "test", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fsys := os.DirFS(filepath.Join(dir, "sub"))
|
||||
|
||||
type staticScenario struct {
|
||||
path string
|
||||
indexFallback bool
|
||||
expectedStatus int
|
||||
expectBody string
|
||||
expectError bool
|
||||
}
|
||||
|
||||
scenarios := []staticScenario{
|
||||
{
|
||||
Name: "apis.ApiError",
|
||||
Method: http.MethodGet,
|
||||
Url: "/test",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return apis.NewApiError(418, "test", nil)
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 418,
|
||||
ExpectedContent: []string{`"message":"Test."`},
|
||||
path: "",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
Name: "wrapped apis.ApiError",
|
||||
Method: http.MethodGet,
|
||||
Url: "/test",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return fmt.Errorf("example 123: %w", apis.NewApiError(418, "test", nil))
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 418,
|
||||
ExpectedContent: []string{`"message":"Test."`},
|
||||
NotExpectedContent: []string{"example", "123"},
|
||||
path: "missing/a/b/c",
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
Name: "echo.HTTPError",
|
||||
Method: http.MethodGet,
|
||||
Url: "/test",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return echo.NewHTTPError(418, "test")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 418,
|
||||
ExpectedContent: []string{`"message":"Test."`},
|
||||
path: "missing/a/b/c",
|
||||
indexFallback: true,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
Name: "wrapped echo.HTTPError",
|
||||
Method: http.MethodGet,
|
||||
Url: "/test",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return fmt.Errorf("example 123: %w", echo.NewHTTPError(418, "test"))
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 418,
|
||||
ExpectedContent: []string{`"message":"Test."`},
|
||||
NotExpectedContent: []string{"example", "123"},
|
||||
path: "testroot", // parent directory file
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
Name: "wrapped sql.ErrNoRows",
|
||||
Method: http.MethodGet,
|
||||
Url: "/test",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return fmt.Errorf("example 123: %w", sql.ErrNoRows)
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
NotExpectedContent: []string{"example", "123"},
|
||||
path: "test",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub test",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
Name: "custom error",
|
||||
Method: http.MethodGet,
|
||||
Url: "/test",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return fmt.Errorf("example 123")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
NotExpectedContent: []string{"example", "123"},
|
||||
path: "sub2",
|
||||
indexFallback: false,
|
||||
expectedStatus: 301,
|
||||
expectBody: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub2 index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/test",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub2 test",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/test/",
|
||||
indexFallback: false,
|
||||
expectedStatus: 301,
|
||||
expectBody: "",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
// extra directory traversal checks
|
||||
dtp := []string{
|
||||
"/../",
|
||||
"\\../",
|
||||
"../",
|
||||
"../../",
|
||||
"..\\",
|
||||
"..\\..\\",
|
||||
"../..\\",
|
||||
"..\\..//",
|
||||
`%2e%2e%2f`,
|
||||
`%2e%2e%2f%2e%2e%2f`,
|
||||
`%2e%2e/`,
|
||||
`%2e%2e/%2e%2e/`,
|
||||
`..%2f`,
|
||||
`..%2f..%2f`,
|
||||
`%2e%2e%5c`,
|
||||
`%2e%2e%5c%2e%2e%5c`,
|
||||
`%2e%2e\`,
|
||||
`%2e%2e\%2e%2e\`,
|
||||
`..%5c`,
|
||||
`..%5c..%5c`,
|
||||
`%252e%252e%255c`,
|
||||
`%252e%252e%255c%252e%252e%255c`,
|
||||
`..%255c`,
|
||||
`..%255c..%255c`,
|
||||
}
|
||||
for _, p := range dtp {
|
||||
scenarios = append(scenarios,
|
||||
staticScenario{
|
||||
path: p + "testroot",
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
staticScenario{
|
||||
path: p + "testroot",
|
||||
indexFallback: true,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%v", i, s.path, s.indexFallback), func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+s.path, nil)
|
||||
req.SetPathValue(apis.StaticWildcardParam, s.path)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.Static(fsys, s.indexFallback)(e)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if body != s.expectBody {
|
||||
t.Fatalf("Expected body %q, got %q", s.expectBody, body)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
apiErr := router.ToApiError(err)
|
||||
if apiErr.Status != s.expectedStatus {
|
||||
t.Fatalf("Expected status code %d, got %d", s.expectedStatus, apiErr.Status)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUploadedFiles(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
filename string
|
||||
expectedPattern string
|
||||
}{
|
||||
{"ab.png", `^ab\w{10}_\w{10}\.png$`},
|
||||
{"test", `^test_\w{10}\.txt$`},
|
||||
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
|
||||
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.filename, func(t *testing.T) {
|
||||
// create multipart form file body
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
w, err := mp.CreateFormFile("test", s.filename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w.Write([]byte("test"))
|
||||
mp.Close()
|
||||
// ---
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
result, err := apis.FindUploadedFiles(req, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 file, got %d", len(result))
|
||||
}
|
||||
|
||||
if result[0].Size != 4 {
|
||||
t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Size)
|
||||
}
|
||||
|
||||
pattern, err := regexp.Compile(s.expectedPattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid filename pattern %q: %v", s.expectedPattern, err)
|
||||
}
|
||||
if !pattern.MatchString(result[0].Name) {
|
||||
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUploadedFilesMissing(t *testing.T) {
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
mp.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
result, err := apis.FindUploadedFiles(req, "test")
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected result to be nil, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustSubFS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// invalid path (no beginning and ending slashes)
|
||||
if !hasPanicked(func() {
|
||||
apis.MustSubFS(os.DirFS(dir), "/test/")
|
||||
}) {
|
||||
t.Fatalf("Expected to panic")
|
||||
}
|
||||
|
||||
// valid path
|
||||
if hasPanicked(func() {
|
||||
apis.MustSubFS(os.DirFS(dir), "./////a/b/c") // checks if ToSlash was called
|
||||
}) {
|
||||
t.Fatalf("Didn't expect to panic")
|
||||
}
|
||||
|
||||
// check sub content
|
||||
sub := apis.MustSubFS(os.DirFS(dir), "sub")
|
||||
|
||||
_, err := sub.Open("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Missing expected file sub/test")
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func hasPanicked(f func()) (didPanic bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
didPanic = true
|
||||
}
|
||||
}()
|
||||
f()
|
||||
return
|
||||
}
|
||||
|
||||
// note: make sure to call os.RemoveAll(dir) after you are done
|
||||
// working with the created test dir.
|
||||
func createTestDir(t *testing.T) string {
|
||||
dir, err := os.MkdirTemp(os.TempDir(), "test_dir")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("root index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "testroot"), []byte("root test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "sub"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/index.html"), []byte("sub index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/test"), []byte("sub test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "sub", "sub2"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/index.html"), []byte("sub2 index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/test"), []byte("sub2 test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return dir
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,542 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/batch")
|
||||
sub.POST("", batchTransaction).Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined
|
||||
}
|
||||
|
||||
type HandleFunc func(e *core.RequestEvent) error
|
||||
|
||||
type BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc
|
||||
|
||||
// ValidBatchActions defines a map with the supported batch InternalRequest actions.
|
||||
//
|
||||
// Note: when adding new routes make sure that their middlewares are inlined!
|
||||
var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
|
||||
// "upsert" handler
|
||||
regexp.MustCompile(`^PUT /api/collections/(?P<collection>[^\/\?]+)/records(?P<query>\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
|
||||
var id string
|
||||
if len(ir.Body) > 0 && ir.Body["id"] != "" {
|
||||
id = cast.ToString(ir.Body["id"])
|
||||
}
|
||||
if id != "" {
|
||||
_, err := app.FindRecordById(params["collection"], id)
|
||||
if err == nil {
|
||||
// update
|
||||
// ---
|
||||
params["id"] = id // required for the path value
|
||||
ir.Method = "PATCH"
|
||||
ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"]
|
||||
return recordUpdate(next)
|
||||
}
|
||||
}
|
||||
|
||||
// create
|
||||
// ---
|
||||
ir.Method = "POST"
|
||||
ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"]
|
||||
return recordCreate(next)
|
||||
},
|
||||
regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
|
||||
return recordCreate(next)
|
||||
},
|
||||
regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
|
||||
return recordUpdate(next)
|
||||
},
|
||||
regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
|
||||
return recordDelete(next)
|
||||
},
|
||||
}
|
||||
|
||||
type BatchRequestResult struct {
|
||||
Body any `json:"body"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
type batchRequestsForm struct {
|
||||
Requests []*core.InternalRequest `form:"requests" json:"requests"`
|
||||
|
||||
max int
|
||||
}
|
||||
|
||||
func (brs batchRequestsForm) validate() error {
|
||||
return validation.ValidateStruct(&brs,
|
||||
validation.Field(&brs.Requests, validation.Required, validation.Length(0, brs.max)),
|
||||
)
|
||||
}
|
||||
|
||||
// NB! When the request is submitted as multipart/form-data,
|
||||
// the regular fields data is expected to be submitted as serailized
|
||||
// json under the @jsonPayload field and file keys need to follow the
|
||||
// pattern "requests.N.fileField" or requests[N].fileField.
|
||||
func batchTransaction(e *core.RequestEvent) error {
|
||||
maxRequests := e.App.Settings().Batch.MaxRequests
|
||||
if !e.App.Settings().Batch.Enabled || maxRequests <= 0 {
|
||||
return e.ForbiddenError("Batch requests are not allowed.", nil)
|
||||
}
|
||||
|
||||
txTimeout := time.Duration(e.App.Settings().Batch.Timeout) * time.Second
|
||||
if txTimeout <= 0 {
|
||||
txTimeout = 3 * time.Second // for now always limit
|
||||
}
|
||||
|
||||
maxBodySize := e.App.Settings().Batch.MaxBodySize
|
||||
if maxBodySize <= 0 {
|
||||
maxBodySize = 128 << 20
|
||||
}
|
||||
|
||||
err := applyBodyLimit(e, maxBodySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
form := &batchRequestsForm{max: maxRequests}
|
||||
|
||||
// load base requests data
|
||||
err = e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to read the submitted batch data.", err)
|
||||
}
|
||||
|
||||
// load uploaded files into each request item
|
||||
// note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format
|
||||
// (the other regular fields must be put under `@jsonPayload` as serialized json)
|
||||
if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
for i, ir := range form.Requests {
|
||||
iStr := strconv.Itoa(i)
|
||||
|
||||
files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].")
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to read the submitted batch files data.", err)
|
||||
}
|
||||
|
||||
for key, files := range files {
|
||||
if ir.Body == nil {
|
||||
ir.Body = map[string]any{}
|
||||
}
|
||||
ir.Body[key] = files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate batch request form
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid batch request data.", err)
|
||||
}
|
||||
|
||||
event := new(core.BatchRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Batch = form.Requests
|
||||
|
||||
return e.App.OnBatchRequest().Trigger(event, func(e *core.BatchRequestEvent) error {
|
||||
bp := batchProcessor{
|
||||
app: e.App,
|
||||
baseEvent: e.RequestEvent,
|
||||
infoContext: core.RequestInfoContextBatch,
|
||||
}
|
||||
|
||||
if err := bp.Process(e.Batch, txTimeout); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Batch transaction failed.", err))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, bp.results)
|
||||
})
|
||||
}
|
||||
|
||||
type batchProcessor struct {
|
||||
app core.App
|
||||
baseEvent *core.RequestEvent
|
||||
infoContext string
|
||||
results []*BatchRequestResult
|
||||
failedIndex int
|
||||
errCh chan error
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error {
|
||||
p.results = make([]*BatchRequestResult, 0, len(batch))
|
||||
|
||||
if p.stopCh != nil {
|
||||
close(p.stopCh)
|
||||
}
|
||||
p.stopCh = make(chan struct{}, 1)
|
||||
|
||||
if p.errCh != nil {
|
||||
close(p.errCh)
|
||||
}
|
||||
p.errCh = make(chan error, 1)
|
||||
|
||||
return p.app.RunInTransaction(func(txApp core.App) error {
|
||||
// used to interupts the recursive processing calls in case of a timeout or connection close
|
||||
defer func() {
|
||||
p.stopCh <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := p.process(txApp, batch, 0)
|
||||
|
||||
if err != nil {
|
||||
err = validation.Errors{
|
||||
"requests": validation.Errors{
|
||||
strconv.Itoa(p.failedIndex): &BatchResponseError{
|
||||
code: "batch_request_failed",
|
||||
message: "Batch request failed.",
|
||||
err: router.ToApiError(err),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// note: to avoid copying and due to the process recursion the final results order is reversed
|
||||
if err == nil {
|
||||
slices.Reverse(p.results)
|
||||
}
|
||||
|
||||
p.errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case responseErr := <-p.errCh:
|
||||
return responseErr
|
||||
case <-time.After(timeout):
|
||||
// note: we don't return 408 Reques Timeout error because
|
||||
// some browsers perform automatic retry behind the scenes
|
||||
// which are hard to debug and unnecessary
|
||||
return errors.New("batch transaction timeout")
|
||||
case <-p.baseEvent.Request.Context().Done():
|
||||
return errors.New("batch request interrupted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return nil
|
||||
default:
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := processInternalRequest(
|
||||
activeApp,
|
||||
p.baseEvent,
|
||||
batch[0],
|
||||
p.infoContext,
|
||||
func() error {
|
||||
if len(batch) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.process(activeApp, batch[1:], i+1)
|
||||
|
||||
// update the failed batch index (if not already)
|
||||
if err != nil && p.failedIndex == 0 {
|
||||
p.failedIndex = i + 1
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.results = append(p.results, result)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func processInternalRequest(
|
||||
activeApp core.App,
|
||||
baseEvent *core.RequestEvent,
|
||||
ir *core.InternalRequest,
|
||||
infoContext string,
|
||||
optNext func() error,
|
||||
) (*BatchRequestResult, error) {
|
||||
handle, params, ok := prepareInternalAction(activeApp, ir, optNext)
|
||||
if !ok {
|
||||
return nil, errors.New("unknown batch request action")
|
||||
}
|
||||
|
||||
// construct a new http.Request
|
||||
// ---------------------------------------------------------------
|
||||
buf, mw, err := multipartDataFromInternalRequest(ir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// cleanup multipart temp files
|
||||
defer func() {
|
||||
if r.MultipartForm != nil {
|
||||
if err := r.MultipartForm.RemoveAll(); err != nil {
|
||||
activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// load batch request path params
|
||||
// ---
|
||||
for k, v := range params {
|
||||
r.SetPathValue(k, v)
|
||||
}
|
||||
|
||||
// clone original request
|
||||
// ---
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.Proto = baseEvent.Request.Proto
|
||||
r.ProtoMajor = baseEvent.Request.ProtoMajor
|
||||
r.ProtoMinor = baseEvent.Request.ProtoMinor
|
||||
r.Host = baseEvent.Request.Host
|
||||
r.RemoteAddr = baseEvent.Request.RemoteAddr
|
||||
r.TLS = baseEvent.Request.TLS
|
||||
|
||||
if s := baseEvent.Request.TransferEncoding; s != nil {
|
||||
s2 := make([]string, len(s))
|
||||
copy(s2, s)
|
||||
r.TransferEncoding = s2
|
||||
}
|
||||
|
||||
if baseEvent.Request.Trailer != nil {
|
||||
r.Trailer = baseEvent.Request.Trailer.Clone()
|
||||
}
|
||||
|
||||
if baseEvent.Request.Header != nil {
|
||||
r.Header = baseEvent.Request.Header.Clone()
|
||||
}
|
||||
|
||||
// apply batch request specific headers
|
||||
// ---
|
||||
for k, v := range ir.Headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
// construct a new RequestEvent
|
||||
// ---------------------------------------------------------------
|
||||
event := &core.RequestEvent{}
|
||||
event.App = activeApp
|
||||
event.Auth = baseEvent.Auth
|
||||
event.SetAll(baseEvent.GetAll())
|
||||
|
||||
// load RequestInfo context
|
||||
if infoContext == "" {
|
||||
infoContext = core.RequestInfoContextDefault
|
||||
}
|
||||
event.Set(core.RequestEventKeyInfoContext, infoContext)
|
||||
|
||||
// assign request
|
||||
event.Request = r
|
||||
event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads
|
||||
|
||||
// assign response
|
||||
rec := httptest.NewRecorder()
|
||||
event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking
|
||||
|
||||
// execute
|
||||
// ---------------------------------------------------------------
|
||||
if err := handle(event); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
body, _ := types.ParseJSONRaw(rec.Body.Bytes())
|
||||
|
||||
return &BatchRequestResult{
|
||||
Status: result.StatusCode,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
mw := multipart.NewWriter(buf)
|
||||
|
||||
regularFields := map[string]any{}
|
||||
fileFields := map[string][]*filesystem.File{}
|
||||
|
||||
// separate regular fields from files
|
||||
// ---
|
||||
for k, rawV := range ir.Body {
|
||||
switch v := rawV.(type) {
|
||||
case *filesystem.File:
|
||||
fileFields[k] = append(fileFields[k], v)
|
||||
case []*filesystem.File:
|
||||
fileFields[k] = append(fileFields[k], v...)
|
||||
default:
|
||||
regularFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// submit regularFields as @jsonPayload
|
||||
// ---
|
||||
rawBody, err := json.Marshal(regularFields)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
jsonPayload, err := mw.CreateFormField("@jsonPayload")
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
_, err = jsonPayload.Write(rawBody)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
// submit fileFields as multipart files
|
||||
// ---
|
||||
for key, files := range fileFields {
|
||||
for _, file := range files {
|
||||
part, err := mw.CreateFormFile(key, file.Name)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
fr, err := file.Reader.Open()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, fr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, fr.Close(), mw.Close())
|
||||
}
|
||||
|
||||
err = fr.Close()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, mw, mw.Close()
|
||||
}
|
||||
|
||||
func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) {
|
||||
if request.MultipartForm == nil {
|
||||
if err := request.ParseMultipartForm(router.DefaultMaxMemory); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
result := make(map[string][]*filesystem.File)
|
||||
|
||||
for k, fhs := range request.MultipartForm.File {
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(k, p) {
|
||||
resultKey := strings.TrimPrefix(k, p)
|
||||
|
||||
for _, fh := range fhs {
|
||||
file, err := filesystem.NewFileFromMultipart(fh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result[resultKey] = append(result[resultKey], file)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func() error) (HandleFunc, map[string]string, bool) {
|
||||
full := strings.ToUpper(ir.Method) + " " + ir.URL
|
||||
|
||||
for re, actionFactory := range ValidBatchActions {
|
||||
params, ok := findNamedMatches(re, full)
|
||||
if ok {
|
||||
return actionFactory(activeApp, ir, params, optNext), params, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) {
|
||||
match := re.FindStringSubmatch(str)
|
||||
if match == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
result := map[string]string{}
|
||||
|
||||
names := re.SubexpNames()
|
||||
|
||||
for i, m := range match {
|
||||
if names[i] != "" {
|
||||
result[names[i]] = m
|
||||
}
|
||||
}
|
||||
|
||||
return result, true
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
_ router.SafeErrorItem = (*BatchResponseError)(nil)
|
||||
_ router.SafeErrorResolver = (*BatchResponseError)(nil)
|
||||
)
|
||||
|
||||
type BatchResponseError struct {
|
||||
err *router.ApiError
|
||||
code string
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Code() string {
|
||||
return e.code
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Resolve(errData map[string]any) any {
|
||||
errData["response"] = e.err
|
||||
return errData
|
||||
}
|
||||
|
||||
func (e BatchResponseError) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(map[string]any{
|
||||
"message": e.message,
|
||||
"code": e.code,
|
||||
"response": e.err,
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,691 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestBatchRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
formData, mp, err := tests.MockMultipartData(
|
||||
map[string]string{
|
||||
router.JSONPayloadKey: `{
|
||||
"requests":[
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch2"}},
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch3"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/lcl9d87w22ml6jy", "body": {"files-": "test_FLurQTgrY8.txt"}}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
"requests.0.files",
|
||||
"requests.0.files",
|
||||
"requests.0.files",
|
||||
"requests[2].files",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "disabled batch requets",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = false
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "max request limits reached",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"GET", "url":"/test1"},
|
||||
{"method":"GET", "url":"/test2"},
|
||||
{"method":"GET", "url":"/test3"}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = true
|
||||
app.Settings().Batch.MaxRequests = 2
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "trigger requests validations",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{},
|
||||
{"method":"GET", "url":"/valid"},
|
||||
{"method":"invalid", "url":"/valid"},
|
||||
{"method":"POST", "url":"` + strings.Repeat("a", 2001) + `"}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = true
|
||||
app.Settings().Batch.MaxRequests = 100
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`"0":{"method":{"code":"validation_required"`,
|
||||
`"2":{"method":{"code":"validation_in_invalid"`,
|
||||
`"3":{"url":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unknown batch request action",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"GET", "url":"/api/health"}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`0":{"code":"batch_request_failed"`,
|
||||
`"response":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "base 2 successful and 1 failed (public collection)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": ""}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"response":{`,
|
||||
`"2":{"code":"batch_request_failed"`,
|
||||
`"response":{"data":{"title":{"code":"validation_required"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"0":`,
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 2,
|
||||
"OnModelAfterCreateError": 3,
|
||||
"OnModelValidate": 3,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 2,
|
||||
"OnRecordAfterCreateError": 3,
|
||||
"OnRecordValidate": 3,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("Expected no batch records to be persisted, got %d", len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "base 4 successful (public collection)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
|
||||
{"method":"PUT", "url":"/api/collections/demo2/records", "body": {"title": "batch3"}},
|
||||
{"method":"PUT", "url":"/api/collections/demo2/records?fields=*,id:excerpt(4,true)", "body": {"id":"achvryl401bhse3","title": "batch4"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch1"`,
|
||||
`"title":"batch2"`,
|
||||
`"title":"batch3"`,
|
||||
`"title":"batch4"`,
|
||||
`"id":"achv..."`,
|
||||
`"active":false`,
|
||||
`"active":true`,
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnModelValidate": 4,
|
||||
"OnRecordValidate": 4,
|
||||
"OnRecordEnrich": 4,
|
||||
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 3,
|
||||
"OnRecordAfterCreateSuccess": 3,
|
||||
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 4 {
|
||||
t.Fatalf("Expected %d batch records to be persisted, got %d", 3, len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (rules failure)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`"2":{"code":"batch_request_failed"`,
|
||||
`"response":{`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// only demo3 requires authentication
|
||||
`"0":`,
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateError": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteError": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to not be created")
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to not be updated")
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err != nil {
|
||||
t.Fatal("Expected record to not be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (rules success)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, clients
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch_create"`,
|
||||
`"title":"batch_update"`,
|
||||
`"status":200`,
|
||||
`"status":204`,
|
||||
`"body":{`,
|
||||
`"body":null`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (superuser auth)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch_create"`,
|
||||
`"title":"batch_update"`,
|
||||
`"status":200`,
|
||||
`"status":204`,
|
||||
`"body":{`,
|
||||
`"body":null`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cascade delete/update",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"DELETE", "url":"/api/collections/demo3/records/1tmknxy2868d869"},
|
||||
{"method":"DELETE", "url":"/api/collections/demo3/records/mk5fmymtx4wsprk"}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"status":204`,
|
||||
`"body":null`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelDelete": 3, // 2 batch + 1 cascade delete
|
||||
"OnModelDeleteExecute": 3,
|
||||
"OnModelAfterDeleteSuccess": 3,
|
||||
"OnModelUpdate": 5, // 5 cascade update
|
||||
"OnModelUpdateExecute": 5,
|
||||
"OnModelAfterUpdateSuccess": 5,
|
||||
// ---
|
||||
"OnRecordDeleteRequest": 2,
|
||||
"OnRecordDelete": 3,
|
||||
"OnRecordDeleteExecute": 3,
|
||||
"OnRecordAfterDeleteSuccess": 3,
|
||||
"OnRecordUpdate": 5,
|
||||
"OnRecordUpdateExecute": 5,
|
||||
"OnRecordAfterUpdateSuccess": 5,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ids := []string{
|
||||
"1tmknxy2868d869",
|
||||
"mk5fmymtx4wsprk",
|
||||
"qzaqccwrmva4o1n",
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
_, err := app.FindRecordById("demo2", id)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected record %q to be deleted", id)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "transaction timeout",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
// test@example.com, superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Timeout = 1
|
||||
app.OnRecordCreateRequest("demo2").BindFunc(func(e *core.RecordRequestEvent) error {
|
||||
time.Sleep(600 * time.Millisecond) // < 1s so that the first request can succeed
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnRecordCreateRequest": 2,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateError": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("Expected %d batch records to be persisted, got %d", 0, len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "multipart/form-data + file upload",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: formData,
|
||||
Headers: map[string]string{
|
||||
// test@example.com, clients
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
"Content-Type": mp.FormDataContentType(),
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch1"`,
|
||||
`"title":"batch2"`,
|
||||
`"title":"batch3"`,
|
||||
`"id":"lcl9d87w22ml6jy"`,
|
||||
`"files":["300_UhLKX91HVb.png"]`,
|
||||
`"tmpfile_`,
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 4,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 3,
|
||||
"OnRecordAfterCreateSuccess": 3,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 4,
|
||||
"OnRecordEnrich": 4,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
batch1, err := app.FindFirstRecordByFilter("demo3", `title="batch1"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch1: %v", err)
|
||||
}
|
||||
batch1Files := batch1.GetStringSlice("files")
|
||||
if len(batch1Files) != 3 {
|
||||
t.Fatalf("Expected %d batch1 file(s), got %d", 3, len(batch1Files))
|
||||
}
|
||||
|
||||
batch2, err := app.FindFirstRecordByFilter("demo3", `title="batch2"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch2: %v", err)
|
||||
}
|
||||
batch2Files := batch2.GetStringSlice("files")
|
||||
if len(batch2Files) != 0 {
|
||||
t.Fatalf("Expected %d batch2 file(s), got %d", 0, len(batch2Files))
|
||||
}
|
||||
|
||||
batch3, err := app.FindFirstRecordByFilter("demo3", `title="batch3"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch3: %v", err)
|
||||
}
|
||||
batch3Files := batch3.GetStringSlice("files")
|
||||
if len(batch3Files) != 1 {
|
||||
t.Fatalf("Expected %d batch3 file(s), got %d", 1, len(batch3Files))
|
||||
}
|
||||
|
||||
batch4, err := app.FindRecordById("demo3", "lcl9d87w22ml6jy")
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch4: %v", err)
|
||||
}
|
||||
batch4Files := batch4.GetStringSlice("files")
|
||||
if len(batch4Files) != 1 {
|
||||
t.Fatalf("Expected %d batch4 file(s), got %d", 1, len(batch4Files))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "create/update with expand query params",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"body":{`,
|
||||
`"id":"qjeql998mtp1azp"`,
|
||||
`"id":"qzaqccwrmva4o1n"`,
|
||||
`"id":"i9naidtvr6qsgb4"`,
|
||||
`"expand":{"rel_one"`,
|
||||
`"expand":{"rel_many"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "check body limit middleware",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.MaxBodySize = 10
|
||||
},
|
||||
ExpectedStatus: 413,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,210 +1,186 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindCollectionApi registers the collection api endpoints and the corresponding handlers.
|
||||
func bindCollectionApi(app core.App, rg *echo.Group) {
|
||||
api := collectionApi{app: app}
|
||||
|
||||
subGroup := rg.Group("/collections", ActivityLogger(app), RequireAdminAuth())
|
||||
subGroup.GET("", api.list)
|
||||
subGroup.POST("", api.create)
|
||||
subGroup.GET("/:collection", api.view)
|
||||
subGroup.PATCH("/:collection", api.update)
|
||||
subGroup.DELETE("/:collection", api.delete)
|
||||
subGroup.PUT("/import", api.bulkImport)
|
||||
func bindCollectionApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/collections").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", collectionsList)
|
||||
subGroup.POST("", collectionCreate)
|
||||
subGroup.GET("/{collection}", collectionView)
|
||||
subGroup.PATCH("/{collection}", collectionUpdate)
|
||||
subGroup.DELETE("/{collection}", collectionDelete)
|
||||
subGroup.DELETE("/{collection}/truncate", collectionTruncate)
|
||||
subGroup.PUT("/import", collectionsImport)
|
||||
subGroup.GET("/meta/scaffolds", collectionScaffolds)
|
||||
}
|
||||
|
||||
type collectionApi struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (api *collectionApi) list(c echo.Context) error {
|
||||
func collectionsList(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(
|
||||
"id", "created", "updated", "name", "system", "type",
|
||||
)
|
||||
|
||||
collections := []*models.Collection{}
|
||||
collections := []*core.Collection{}
|
||||
|
||||
result, err := search.NewProvider(fieldResolver).
|
||||
Query(api.app.Dao().CollectionQuery()).
|
||||
ParseAndExec(c.QueryParams().Encode(), &collections)
|
||||
Query(e.App.CollectionQuery()).
|
||||
ParseAndExec(e.Request.URL.Query().Encode(), &collections)
|
||||
|
||||
if err != nil {
|
||||
return NewBadRequestError("", err)
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionsListEvent)
|
||||
event.HttpContext = c
|
||||
event := new(core.CollectionsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collections = collections
|
||||
event.Result = result
|
||||
|
||||
return api.app.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Result)
|
||||
return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error {
|
||||
return e.JSON(http.StatusOK, e.Result)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *collectionApi) view(c echo.Context) error {
|
||||
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
|
||||
func collectionView(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return NewNotFoundError("", err)
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionViewEvent)
|
||||
event.HttpContext = c
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return api.app.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionViewEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Collection)
|
||||
return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *collectionApi) create(c echo.Context) error {
|
||||
collection := &models.Collection{}
|
||||
|
||||
form := forms.NewCollectionUpsert(api.app, collection)
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
func collectionCreate(e *core.RequestEvent) error {
|
||||
// populate the minimal required factory collection data (if any)
|
||||
factoryExtract := struct {
|
||||
Type string `form:"type" json:"type"`
|
||||
Name string `form:"name" json:"name"`
|
||||
}{}
|
||||
if err := e.BindBody(&factoryExtract); err != nil {
|
||||
return e.BadRequestError("Failed to load the collection type data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionCreateEvent)
|
||||
event.HttpContext = c
|
||||
// create scaffold
|
||||
collection := core.NewCollection(factoryExtract.Type, factoryExtract.Name)
|
||||
|
||||
// merge the scaffold with the submitted request data
|
||||
if err := e.BindBody(collection); err != nil {
|
||||
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
// create the collection
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] {
|
||||
return func(m *models.Collection) error {
|
||||
event.Collection = m
|
||||
|
||||
return api.app.OnCollectionBeforeCreateRequest().Trigger(event, func(e *core.CollectionCreateEvent) error {
|
||||
if err := next(e.Collection); err != nil {
|
||||
return NewBadRequestError("Failed to create the collection.", err)
|
||||
}
|
||||
|
||||
return api.app.OnCollectionAfterCreateRequest().Trigger(event, func(e *core.CollectionCreateEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (api *collectionApi) update(c echo.Context) error {
|
||||
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
|
||||
form := forms.NewCollectionUpsert(api.app, collection)
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionUpdateEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
|
||||
// update the collection
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] {
|
||||
return func(m *models.Collection) error {
|
||||
event.Collection = m
|
||||
|
||||
return api.app.OnCollectionBeforeUpdateRequest().Trigger(event, func(e *core.CollectionUpdateEvent) error {
|
||||
if err := next(e.Collection); err != nil {
|
||||
return NewBadRequestError("Failed to update the collection.", err)
|
||||
}
|
||||
|
||||
return api.app.OnCollectionAfterUpdateRequest().Trigger(event, func(e *core.CollectionUpdateEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (api *collectionApi) delete(c echo.Context) error {
|
||||
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionDeleteEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
|
||||
return api.app.OnCollectionBeforeDeleteRequest().Trigger(event, func(e *core.CollectionDeleteEvent) error {
|
||||
if err := api.app.Dao().DeleteCollection(e.Collection); err != nil {
|
||||
return NewBadRequestError("Failed to delete collection due to existing dependency.", err)
|
||||
}
|
||||
|
||||
return api.app.OnCollectionAfterDeleteRequest().Trigger(event, func(e *core.CollectionDeleteEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
return e.App.OnCollectionCreateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Save(e.Collection); err != nil {
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(err, &validationErrors) {
|
||||
return e.BadRequestError("Failed to create collection.", validationErrors)
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
// other generic db error
|
||||
return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *collectionApi) bulkImport(c echo.Context) error {
|
||||
form := forms.NewCollectionsImport(api.app)
|
||||
|
||||
// load request data
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
func collectionUpdate(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionsImportEvent)
|
||||
event.HttpContext = c
|
||||
event.Collections = form.Collections
|
||||
if err := e.BindBody(collection); err != nil {
|
||||
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
// import collections
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[[]*models.Collection]) forms.InterceptorNextFunc[[]*models.Collection] {
|
||||
return func(imports []*models.Collection) error {
|
||||
event.Collections = imports
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return api.app.OnCollectionsBeforeImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error {
|
||||
if err := next(e.Collections); err != nil {
|
||||
return NewBadRequestError("Failed to import the submitted collections.", err)
|
||||
}
|
||||
return event.App.OnCollectionUpdateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Save(e.Collection); err != nil {
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(err, &validationErrors) {
|
||||
return e.BadRequestError("Failed to update collection.", validationErrors)
|
||||
}
|
||||
|
||||
return api.app.OnCollectionsAfterImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
// other generic db error
|
||||
return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
}
|
||||
|
||||
func collectionDelete(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return e.App.OnCollectionDeleteRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Delete(e.Collection); err != nil {
|
||||
msg := "Failed to delete collection"
|
||||
|
||||
// check fo references
|
||||
refs, _ := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
|
||||
if len(refs) > 0 {
|
||||
names := make([]string, 0, len(refs))
|
||||
for ref := range refs {
|
||||
names = append(names, ref.Name)
|
||||
}
|
||||
msg += " probably due to existing reference in " + strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
return e.BadRequestError(msg, err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
func collectionTruncate(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
err = e.App.TruncateCollection(collection)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to truncate collection (most likely due to required cascade delete record references).", err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func collectionScaffolds(e *core.RequestEvent) error {
|
||||
return e.JSON(http.StatusOK, map[string]*core.Collection{
|
||||
core.CollectionTypeBase: core.NewBaseCollection(""),
|
||||
core.CollectionTypeAuth: core.NewAuthCollection(""),
|
||||
core.CollectionTypeView: core.NewViewCollection(""),
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func collectionsImport(e *core.RequestEvent) error {
|
||||
form := new(collectionsImportForm)
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
event := new(core.CollectionsImportRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.CollectionsData = form.Collections
|
||||
event.DeleteMissing = form.DeleteMissing
|
||||
|
||||
return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error {
|
||||
importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing)
|
||||
if importErr == nil {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(err, &validationErrors) {
|
||||
return e.BadRequestError("Failed to import collections.", validationErrors)
|
||||
}
|
||||
|
||||
// generic/db failure
|
||||
return e.BadRequestError("Failed to import collections.", validation.Errors{"collections": validation.NewError(
|
||||
"validation_collections_import_failure",
|
||||
"Failed to import the collections configuration. Raw error:\n"+importErr.Error(),
|
||||
)})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type collectionsImportForm struct {
|
||||
Collections []map[string]any `form:"collections" json:"collections"`
|
||||
DeleteMissing bool `form:"deleteMissing" json:"deleteMissing"`
|
||||
}
|
||||
|
||||
func (form *collectionsImportForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Collections, validation.Required),
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestCollectionsImport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
totalCollections := 16
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + empty collections",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{"collections":[]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + collections validator failure",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{"name": "import1"},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "expand",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_collections_import_failure"`,
|
||||
`import2`,
|
||||
`fields`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 2,
|
||||
"OnCollectionCreateExecute": 2,
|
||||
"OnCollectionAfterCreateError": 2,
|
||||
"OnModelCreate": 2,
|
||||
"OnModelCreateExecute": 2,
|
||||
"OnModelAfterCreateError": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + successful collections create",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"indexes": [
|
||||
"create index idx_test on import2 (test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "auth_without_fields",
|
||||
"type": "auth"
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 3,
|
||||
"OnCollectionCreateExecute": 3,
|
||||
"OnCollectionAfterCreateSuccess": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := totalCollections + 3
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
|
||||
indexes, err := app.TableIndexes("import2")
|
||||
if err != nil || indexes["idx_test"] == "" {
|
||||
t.Fatalf("Missing index %s (%v)", "idx_test", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + create/update/delete",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"deleteMissing": true,
|
||||
"collections":[
|
||||
{"name": "test123"},
|
||||
{
|
||||
"id":"wsmn24bux7wo113",
|
||||
"name":"demo1",
|
||||
"fields":[
|
||||
{
|
||||
"id":"_2hlxbmp",
|
||||
"name":"title",
|
||||
"type":"text",
|
||||
"required":true
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnCollectionCreate": 1,
|
||||
"OnCollectionCreateExecute": 1,
|
||||
"OnCollectionAfterCreateSuccess": 1,
|
||||
// ---
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnCollectionUpdate": 1,
|
||||
"OnCollectionUpdateExecute": 1,
|
||||
"OnCollectionAfterUpdateSuccess": 1,
|
||||
// ---
|
||||
"OnModelDelete": 14,
|
||||
"OnModelAfterDeleteSuccess": 14,
|
||||
"OnModelDeleteExecute": 14,
|
||||
"OnCollectionDelete": 9,
|
||||
"OnCollectionDeleteExecute": 9,
|
||||
"OnCollectionAfterDeleteSuccess": 9,
|
||||
"OnRecordAfterDeleteSuccess": 5,
|
||||
"OnRecordDelete": 5,
|
||||
"OnRecordDeleteExecute": 5,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
systemCollections := 0
|
||||
for _, c := range collections {
|
||||
if c.System {
|
||||
systemCollections++
|
||||
}
|
||||
}
|
||||
|
||||
expected := systemCollections + 2
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,138 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
const installerParam = "pbinstal"
|
||||
|
||||
var wildcardPlaceholderRegex = regexp.MustCompile(`/{.+\.\.\.}$`)
|
||||
|
||||
func stripWildcard(pattern string) string {
|
||||
return wildcardPlaceholderRegex.ReplaceAllString(pattern, "")
|
||||
}
|
||||
|
||||
// installerRedirect redirects the user to the installer dashboard UI page
|
||||
// when the application needs some preliminary configurations to be done.
|
||||
func installerRedirect(app core.App, cpPath string) hook.HandlerFunc[*core.RequestEvent] {
|
||||
// note: to avoid locks contention it is not concurrent safe but it
|
||||
// is expected to be updated only once during initialization
|
||||
var hasSuperuser bool
|
||||
|
||||
// strip named wildcard
|
||||
cpPath = stripWildcard(cpPath)
|
||||
|
||||
updateHasSuperuser := func(app core.App) error {
|
||||
total, err := app.CountRecords(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hasSuperuser = total > 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// load initial state on app init
|
||||
app.OnBootstrap().BindFunc(func(e *core.BootstrapEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = updateHasSuperuser(e.App)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for existing superuser: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// update on superuser create
|
||||
app.OnRecordCreateRequest(core.CollectionNameSuperusers).BindFunc(func(e *core.RecordRequestEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !hasSuperuser {
|
||||
hasSuperuser = true
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return func(e *core.RequestEvent) error {
|
||||
if hasSuperuser {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
isAPI := strings.HasPrefix(e.Request.URL.Path, "/api/")
|
||||
isControlPanel := strings.HasPrefix(e.Request.URL.Path, cpPath)
|
||||
wildcard := e.Request.PathValue(StaticWildcardParam)
|
||||
|
||||
// skip redirect checks for API and non-root level dashboard index.html requests (css, images, etc.)
|
||||
if isAPI || (isControlPanel && wildcard != "" && wildcard != router.IndexPage) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// check again in case the superuser was created by some other process
|
||||
if err := updateHasSuperuser(e.App); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hasSuperuser {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
_, hasInstallerParam := e.Request.URL.Query()[installerParam]
|
||||
|
||||
// redirect to the installer page
|
||||
if !hasInstallerParam {
|
||||
return e.Redirect(http.StatusTemporaryRedirect, cpPath+"?"+installerParam+"#")
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// dashboardRemoveInstallerParam redirects to a non-installer
|
||||
// query param in case there is already a superuser created.
|
||||
//
|
||||
// Note: intended to be registered only for the dashboard route
|
||||
// to prevent excessive checks for every other route in installerRedirect.
|
||||
func dashboardRemoveInstallerParam() hook.HandlerFunc[*core.RequestEvent] {
|
||||
return func(e *core.RequestEvent) error {
|
||||
_, hasInstallerParam := e.Request.URL.Query()[installerParam]
|
||||
if !hasInstallerParam {
|
||||
return e.Next() // nothing to remove
|
||||
}
|
||||
|
||||
// clear installer param
|
||||
total, _ := e.App.CountRecords(core.CollectionNameSuperusers)
|
||||
if total > 0 {
|
||||
return e.Redirect(http.StatusTemporaryRedirect, "?")
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// dashboardCacheControl adds default Cache-Control header for all
|
||||
// dashboard UI resources (ignoring the root index.html path)
|
||||
func dashboardCacheControl() hook.HandlerFunc[*core.RequestEvent] {
|
||||
return func(e *core.RequestEvent) error {
|
||||
if e.Request.PathValue(StaticWildcardParam) != "" {
|
||||
e.Response.Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
190
apis/file.go
190
apis/file.go
|
|
@ -7,18 +7,12 @@ import (
|
|||
"log/slog"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/tokens"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
|
@ -27,23 +21,19 @@ var imageContentTypes = []string{"image/png", "image/jpg", "image/jpeg", "image/
|
|||
var defaultThumbSizes = []string{"100x100"}
|
||||
|
||||
// bindFileApi registers the file api endpoints and the corresponding handlers.
|
||||
func bindFileApi(app core.App, rg *echo.Group) {
|
||||
func bindFileApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
api := fileApi{
|
||||
app: app,
|
||||
thumbGenSem: semaphore.NewWeighted(int64(runtime.NumCPU() + 2)), // the value is arbitrary chosen and may change in the future
|
||||
thumbGenPending: new(singleflight.Group),
|
||||
thumbGenMaxWait: 60 * time.Second,
|
||||
}
|
||||
|
||||
subGroup := rg.Group("/files", ActivityLogger(app))
|
||||
subGroup.POST("/token", api.fileToken)
|
||||
subGroup.HEAD("/:collection/:recordId/:filename", api.download, LoadCollectionContext(api.app))
|
||||
subGroup.GET("/:collection/:recordId/:filename", api.download, LoadCollectionContext(api.app))
|
||||
sub := rg.Group("/files")
|
||||
sub.POST("/token", api.fileToken).Bind(RequireAuth())
|
||||
sub.GET("/{collection}/{recordId}/{filename}", api.download).Bind(collectionPathRateLimit("", "file"))
|
||||
}
|
||||
|
||||
type fileApi struct {
|
||||
app core.App
|
||||
|
||||
// thumbGenSem is a semaphore to prevent too much concurrent
|
||||
// requests generating new thumbs at the same time.
|
||||
thumbGenSem *semaphore.Weighted
|
||||
|
|
@ -57,84 +47,67 @@ type fileApi struct {
|
|||
thumbGenMaxWait time.Duration
|
||||
}
|
||||
|
||||
func (api *fileApi) fileToken(c echo.Context) error {
|
||||
event := new(core.FileTokenEvent)
|
||||
event.HttpContext = c
|
||||
|
||||
if admin, _ := c.Get(ContextAdminKey).(*models.Admin); admin != nil {
|
||||
event.Model = admin
|
||||
event.Token, _ = tokens.NewAdminFileToken(api.app, admin)
|
||||
} else if record, _ := c.Get(ContextAuthRecordKey).(*models.Record); record != nil {
|
||||
event.Model = record
|
||||
event.Token, _ = tokens.NewRecordFileToken(api.app, record)
|
||||
func (api *fileApi) fileToken(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("Missing auth context.", nil)
|
||||
}
|
||||
|
||||
return api.app.OnFileBeforeTokenRequest().Trigger(event, func(e *core.FileTokenEvent) error {
|
||||
if e.Model == nil || e.Token == "" {
|
||||
return NewBadRequestError("Failed to generate file token.", nil)
|
||||
}
|
||||
token, err := e.Auth.NewFileToken()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to generate file token", err)
|
||||
}
|
||||
|
||||
return api.app.OnFileAfterTokenRequest().Trigger(event, func(e *core.FileTokenEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
event := new(core.FileTokenRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Token = token
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, map[string]string{
|
||||
"token": e.Token,
|
||||
})
|
||||
return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error {
|
||||
return e.JSON(http.StatusOK, map[string]string{
|
||||
"token": e.Token,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (api *fileApi) download(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
recordId := c.PathParam("recordId")
|
||||
if recordId == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
record, err := api.app.Dao().FindRecordById(collection.Id, recordId)
|
||||
func (api *fileApi) download(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil {
|
||||
return NewNotFoundError("", err)
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
filename := c.PathParam("filename")
|
||||
recordId := e.Request.PathValue("recordId")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, recordId)
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
filename := e.Request.PathValue("filename")
|
||||
|
||||
fileField := record.FindFileFieldByFile(filename)
|
||||
if fileField == nil {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
options, ok := fileField.Options.(*schema.FileOptions)
|
||||
if !ok {
|
||||
return NewBadRequestError("", errors.New("failed to load file options"))
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
// check whether the request is authorized to view the protected file
|
||||
if options.Protected {
|
||||
token := c.QueryParam("token")
|
||||
|
||||
adminOrAuthRecord, _ := api.findAdminOrAuthRecordByFileToken(token)
|
||||
|
||||
// create a copy of the cached request data and adjust it for the current auth model
|
||||
requestInfo := *RequestInfo(c)
|
||||
requestInfo.Context = models.RequestInfoContextProtectedFile
|
||||
requestInfo.Admin = nil
|
||||
requestInfo.AuthRecord = nil
|
||||
if adminOrAuthRecord != nil {
|
||||
if admin, _ := adminOrAuthRecord.(*models.Admin); admin != nil {
|
||||
requestInfo.Admin = admin
|
||||
} else if record, _ := adminOrAuthRecord.(*models.Record); record != nil {
|
||||
requestInfo.AuthRecord = record
|
||||
}
|
||||
if fileField.Protected {
|
||||
originalRequestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load request info", err)
|
||||
}
|
||||
|
||||
if ok, _ := api.app.Dao().CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
|
||||
return NewForbiddenError("Insufficient permissions to access the file resource.", nil)
|
||||
token := e.Request.URL.Query().Get("token")
|
||||
authRecord, _ := e.App.FindAuthRecordByToken(token, core.TokenTypeFile)
|
||||
|
||||
// create a shallow copy of the cached request data and adjust it to the current auth record (if any)
|
||||
requestInfo := *originalRequestInfo
|
||||
requestInfo.Context = core.RequestInfoContextProtectedFile
|
||||
requestInfo.Auth = authRecord
|
||||
|
||||
if ok, _ := e.App.CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
|
||||
return e.NotFoundError("", errors.New("insufficient permissions to access the file resource"))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -142,16 +115,16 @@ func (api *fileApi) download(c echo.Context) error {
|
|||
|
||||
// fetch the original view file field related record
|
||||
if collection.IsView() {
|
||||
fileRecord, err := api.app.Dao().FindRecordByViewFile(collection.Id, fileField.Name, filename)
|
||||
fileRecord, err := e.App.FindRecordByViewFile(collection.Id, fileField.Name, filename)
|
||||
if err != nil {
|
||||
return NewNotFoundError("", fmt.Errorf("Failed to fetch view file field record: %w", err))
|
||||
return e.NotFoundError("", fmt.Errorf("failed to fetch view file field record: %w", err))
|
||||
}
|
||||
baseFilesPath = fileRecord.BaseFilesPath()
|
||||
}
|
||||
|
||||
fsys, err := api.app.NewFilesystem()
|
||||
fsys, err := e.App.NewFilesystem()
|
||||
if err != nil {
|
||||
return NewBadRequestError("Filesystem initialization failure.", err)
|
||||
return e.InternalServerError("Filesystem initialization failure.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
|
|
@ -160,12 +133,12 @@ func (api *fileApi) download(c echo.Context) error {
|
|||
servedName := filename
|
||||
|
||||
// check for valid thumb size param
|
||||
thumbSize := c.QueryParam("thumb")
|
||||
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, options.Thumbs)) {
|
||||
thumbSize := e.Request.URL.Query().Get("thumb")
|
||||
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, fileField.Thumbs)) {
|
||||
// extract the original file meta attributes and check it existence
|
||||
oAttrs, oAttrsErr := fsys.Attributes(originalPath)
|
||||
if oAttrsErr != nil {
|
||||
return NewNotFoundError("", err)
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
// check if it is an image
|
||||
|
|
@ -176,8 +149,8 @@ func (api *fileApi) download(c echo.Context) error {
|
|||
|
||||
// create a new thumb if it doesn't exist
|
||||
if exists, _ := fsys.Exists(servedPath); !exists {
|
||||
if err := api.createThumb(c, fsys, originalPath, servedPath, thumbSize); err != nil {
|
||||
api.app.Logger().Warn(
|
||||
if err := api.createThumb(e, fsys, originalPath, servedPath, thumbSize); err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"Fallback to original - failed to create thumb "+servedName,
|
||||
slog.Any("error", err),
|
||||
slog.String("original", originalPath),
|
||||
|
|
@ -192,8 +165,8 @@ func (api *fileApi) download(c echo.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
event := new(core.FileDownloadEvent)
|
||||
event.HttpContext = c
|
||||
event := new(core.FileDownloadRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.FileField = fileField
|
||||
|
|
@ -203,61 +176,26 @@ func (api *fileApi) download(c echo.Context) error {
|
|||
// clickjacking shouldn't be a concern when serving uploaded files,
|
||||
// so it safe to unset the global X-Frame-Options to allow files embedding
|
||||
// (note: it is out of the hook to allow users to customize the behavior)
|
||||
c.Response().Header().Del("X-Frame-Options")
|
||||
e.Response.Header().Del("X-Frame-Options")
|
||||
|
||||
return api.app.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := fsys.Serve(e.HttpContext.Response(), e.HttpContext.Request(), e.ServedPath, e.ServedName); err != nil {
|
||||
return NewNotFoundError("", err)
|
||||
return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error {
|
||||
if err := fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName); err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (api *fileApi) findAdminOrAuthRecordByFileToken(fileToken string) (models.Model, error) {
|
||||
fileToken = strings.TrimSpace(fileToken)
|
||||
if fileToken == "" {
|
||||
return nil, errors.New("missing file token")
|
||||
}
|
||||
|
||||
claims, _ := security.ParseUnverifiedJWT(strings.TrimSpace(fileToken))
|
||||
tokenType := cast.ToString(claims["type"])
|
||||
|
||||
switch tokenType {
|
||||
case tokens.TypeAdmin:
|
||||
admin, err := api.app.Dao().FindAdminByToken(
|
||||
fileToken,
|
||||
api.app.Settings().AdminFileToken.Secret,
|
||||
)
|
||||
if err == nil && admin != nil {
|
||||
return admin, nil
|
||||
}
|
||||
case tokens.TypeAuthRecord:
|
||||
record, err := api.app.Dao().FindAuthRecordByToken(
|
||||
fileToken,
|
||||
api.app.Settings().RecordFileToken.Secret,
|
||||
)
|
||||
if err == nil && record != nil {
|
||||
return record, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("missing or invalid file token")
|
||||
}
|
||||
|
||||
func (api *fileApi) createThumb(
|
||||
c echo.Context,
|
||||
e *core.RequestEvent,
|
||||
fsys *filesystem.System,
|
||||
originalPath string,
|
||||
thumbPath string,
|
||||
thumbSize string,
|
||||
) error {
|
||||
ch := api.thumbGenPending.DoChan(thumbPath, func() (any, error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request().Context(), api.thumbGenMaxWait)
|
||||
ctx, cancel := context.WithTimeout(e.Request.Context(), api.thumbGenMaxWait)
|
||||
defer cancel()
|
||||
|
||||
if err := api.thumbGenSem.Acquire(ctx, 1); err != nil {
|
||||
|
|
|
|||
|
|
@ -10,11 +10,8 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
|
@ -26,23 +23,54 @@ func TestFileToken(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/files/token",
|
||||
ExpectedStatus: 400,
|
||||
URL: "/api/files/token",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnFileBeforeTokenRequest": 1,
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "unauthorized with model and token via hook",
|
||||
Name: "superuser",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/files/token",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnFileBeforeTokenRequest().Add(func(e *core.FileTokenEvent) error {
|
||||
record, _ := app.Dao().FindAuthRecordByEmail("users", "test@example.com")
|
||||
e.Model = record
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hook token overwrite",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnFileTokenRequest().BindFunc(func(e *core.FileTokenRequestEvent) error {
|
||||
e.Token = "test"
|
||||
return nil
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
|
|
@ -50,40 +78,8 @@ func TestFileToken(t *testing.T) {
|
|||
`"token":"test"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnFileBeforeTokenRequest": 1,
|
||||
"OnFileAfterTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "auth record",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/files/token",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnFileBeforeTokenRequest": 1,
|
||||
"OnFileAfterTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "admin",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/files/token",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnFileBeforeTokenRequest": 1,
|
||||
"OnFileAfterTokenRequest": 1,
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -152,233 +148,271 @@ func TestFileDownload(t *testing.T) {
|
|||
{
|
||||
Name: "missing collection",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
URL: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing record",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
|
||||
URL: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing file",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing image",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testImg)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - missing thumb (should fallback to the original)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testImg)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop center)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropCenter)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop top)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropTop)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop bottom)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropBottom)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (fit)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbFit)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (zero width)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbZeroWidth)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (zero height)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbZeroHeight)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing non image file - thumb parameter should be ignored",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
|
||||
URL: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testFile)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// protected file access checks
|
||||
{
|
||||
Name: "protected file - expired token",
|
||||
Name: "protected file - superuser with expired file token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testFile)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file - admin with expired file token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImFkbWluIn0.g7Q_3UX6H--JWJ7yt1Hoe-1ugTX1KpbKzdt0zjGSe-E",
|
||||
ExpectedStatus: 403,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.hTNDzikwJdcoWrLnRnp7xbaifZ2vuYZ0oOYRHtJfnk4",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - admin with valid file token",
|
||||
Name: "protected file - superuser with valid file token",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file - guest without view access",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
ExpectedStatus: 403,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - guest with view access",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
dao := daos.New(app.Dao().DB())
|
||||
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock public view access
|
||||
c, err := dao.FindCollectionByNameOrId("demo1")
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("")
|
||||
if err := dao.SaveCollection(c); err != nil {
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file - auth record without view access",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
dao := daos.New(app.Dao().DB())
|
||||
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock restricted user view access
|
||||
c, err := dao.FindCollectionByNameOrId("demo1")
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("@request.auth.verified = true")
|
||||
if err := dao.SaveCollection(c); err != nil {
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - auth record with view access",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
dao := daos.New(app.Dao().DB())
|
||||
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock user view access
|
||||
c, err := dao.FindCollectionByNameOrId("demo1")
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("@request.auth.verified = false")
|
||||
if err := dao.SaveCollection(c); err != nil {
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file in view (view's View API rule failure)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
|
||||
ExpectedStatus: 403,
|
||||
URL: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file in view (view's View API rule success)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
|
||||
URL: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:file"},
|
||||
{MaxRequests: 0, Label: "users:file"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:file"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
|
|
@ -410,30 +444,23 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
|
|||
defer fsys.Close()
|
||||
|
||||
// create a dummy file field collection
|
||||
demo1, err := app.Dao().FindCollectionByNameOrId("demo1")
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fileField := demo1.Schema.GetFieldByName("file_one")
|
||||
fileField.Options = &schema.FileOptions{
|
||||
Protected: false,
|
||||
MaxSelect: 1,
|
||||
MaxSize: 999999,
|
||||
// new thumbs
|
||||
Thumbs: []string{"111x111", "111x222", "111x333"},
|
||||
}
|
||||
demo1.Schema.AddField(fileField)
|
||||
if err := app.Dao().SaveCollection(demo1); err != nil {
|
||||
fileField := demo1.Fields.GetByName("file_one").(*core.FileField)
|
||||
fileField.Protected = false
|
||||
fileField.MaxSelect = 1
|
||||
fileField.MaxSize = 999999
|
||||
// new thumbs
|
||||
fileField.Thumbs = []string{"111x111", "111x222", "111x333"}
|
||||
demo1.Fields.Add(fileField)
|
||||
if err = app.Save(demo1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fileKey := "wsmn24bux7wo113/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png"
|
||||
|
||||
e, err := apis.InitApi(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
urls := []string{
|
||||
"/api/files/" + fileKey + "?thumb=111x111",
|
||||
"/api/files/" + fileKey + "?thumb=111x111", // should still result in single thumb
|
||||
|
|
@ -446,7 +473,6 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
|
|||
wg.Add(len(urls))
|
||||
|
||||
for _, url := range urls {
|
||||
url := url
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
|
|
@ -454,7 +480,11 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
|
|||
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
|
||||
e.ServeHTTP(recorder, req)
|
||||
pbRouter, _ := apis.NewRouter(app)
|
||||
mux, _ := pbRouter.BuildMux()
|
||||
if mux != nil {
|
||||
mux.ServeHTTP(recorder, req)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,42 +2,52 @@ package apis
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindHealthApi registers the health api endpoint.
|
||||
func bindHealthApi(app core.App, rg *echo.Group) {
|
||||
api := healthApi{app: app}
|
||||
|
||||
func bindHealthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/health")
|
||||
subGroup.HEAD("", api.healthCheck)
|
||||
subGroup.GET("", api.healthCheck)
|
||||
}
|
||||
|
||||
type healthApi struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
type healthCheckResponse struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
CanBackup bool `json:"canBackup"`
|
||||
} `json:"data"`
|
||||
subGroup.GET("", healthCheck)
|
||||
}
|
||||
|
||||
// healthCheck returns a 200 OK response if the server is healthy.
|
||||
func (api *healthApi) healthCheck(c echo.Context) error {
|
||||
if c.Request().Method == http.MethodHead {
|
||||
return c.NoContent(http.StatusOK)
|
||||
func healthCheck(e *core.RequestEvent) error {
|
||||
resp := struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
Data map[string]any `json:"data"`
|
||||
}{
|
||||
Code: http.StatusOK,
|
||||
Message: "API is healthy.",
|
||||
}
|
||||
|
||||
resp := new(healthCheckResponse)
|
||||
resp.Code = http.StatusOK
|
||||
resp.Message = "API is healthy."
|
||||
resp.Data.CanBackup = !api.app.Store().Has(core.StoreKeyActiveBackup)
|
||||
if e.HasSuperuserAuth() {
|
||||
resp.Data = make(map[string]any, 3)
|
||||
resp.Data["canBackup"] = !e.App.Store().Has(core.StoreKeyActiveBackup)
|
||||
resp.Data["realIP"] = e.RealIP()
|
||||
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
// loosely check if behind a reverse proxy
|
||||
// (usually used in the dashboard to remind superusers in case deployed behind reverse-proxy)
|
||||
possibleProxyHeader := ""
|
||||
headersToCheck := append(
|
||||
slices.Clone(e.App.Settings().TrustedProxy.Headers),
|
||||
// common proxy headers
|
||||
"CF-Connecting-IP", "Fly-Client-IP", "X‑Forwarded-For",
|
||||
)
|
||||
for _, header := range headersToCheck {
|
||||
if e.Request.Header.Get(header) != "" {
|
||||
possibleProxyHeader = header
|
||||
break
|
||||
}
|
||||
}
|
||||
resp.Data["possibleProxyHeader"] = possibleProxyHeader
|
||||
} else {
|
||||
resp.Data = map[string]any{} // ensure that it is returned as object
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,21 +12,56 @@ func TestHealthAPI(t *testing.T) {
|
|||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "HEAD health status",
|
||||
Method: http.MethodHead,
|
||||
Url: "/api/health",
|
||||
Name: "GET health status (guest)",
|
||||
Method: http.MethodGet, // automatically matches also HEAD as a side-effect of the Go std mux
|
||||
URL: "/api/health",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{}`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"canBackup",
|
||||
"realIP",
|
||||
"possibleProxyHeader",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "GET health status",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/health",
|
||||
Name: "GET health status (regular user)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/health",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{}`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"canBackup",
|
||||
"realIP",
|
||||
"possibleProxyHeader",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "GET health status (superuser)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/health",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{`,
|
||||
`"canBackup":true`,
|
||||
`"realIP"`,
|
||||
`"possibleProxyHeader"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
56
apis/logs.go
56
apis/logs.go
|
|
@ -3,79 +3,71 @@ package apis
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindLogsApi registers the request logs api endpoints.
|
||||
func bindLogsApi(app core.App, rg *echo.Group) {
|
||||
api := logsApi{app: app}
|
||||
|
||||
subGroup := rg.Group("/logs", RequireAdminAuth())
|
||||
subGroup.GET("", api.list)
|
||||
subGroup.GET("/stats", api.stats)
|
||||
subGroup.GET("/:id", api.view)
|
||||
}
|
||||
|
||||
type logsApi struct {
|
||||
app core.App
|
||||
func bindLogsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/logs").Bind(RequireSuperuserAuth(), SkipSuccessActivityLog())
|
||||
sub.GET("", logsList)
|
||||
sub.GET("/stats", logsStats)
|
||||
sub.GET("/{id}", logsView)
|
||||
}
|
||||
|
||||
var logFilterFields = []string{
|
||||
"rowid", "id", "created", "updated",
|
||||
"level", "message", "data",
|
||||
"id", "created", "level", "message", "data",
|
||||
`^data\.[\w\.\:]*\w+$`,
|
||||
}
|
||||
|
||||
func (api *logsApi) list(c echo.Context) error {
|
||||
func logsList(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
|
||||
|
||||
result, err := search.NewProvider(fieldResolver).
|
||||
Query(api.app.LogsDao().LogQuery()).
|
||||
ParseAndExec(c.QueryParams().Encode(), &[]*models.Log{})
|
||||
Query(e.App.AuxModelQuery(&core.Log{})).
|
||||
ParseAndExec(e.Request.URL.Query().Encode(), &[]*core.Log{})
|
||||
|
||||
if err != nil {
|
||||
return NewBadRequestError("", err)
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (api *logsApi) stats(c echo.Context) error {
|
||||
func logsStats(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
|
||||
|
||||
filter := c.QueryParam(search.FilterQueryParam)
|
||||
filter := e.Request.URL.Query().Get(search.FilterQueryParam)
|
||||
|
||||
var expr dbx.Expression
|
||||
if filter != "" {
|
||||
var err error
|
||||
expr, err = search.FilterData(filter).BuildExpr(fieldResolver)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Invalid filter format.", err)
|
||||
return e.BadRequestError("Invalid filter format.", err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := api.app.LogsDao().LogsStats(expr)
|
||||
stats, err := e.App.LogsStats(expr)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to generate logs stats.", err)
|
||||
return e.BadRequestError("Failed to generate logs stats.", err)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, stats)
|
||||
return e.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
func (api *logsApi) view(c echo.Context) error {
|
||||
id := c.PathParam("id")
|
||||
func logsView(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
log, err := api.app.LogsDao().FindLogById(id)
|
||||
log, err := e.App.FindLogById(id)
|
||||
if err != nil || log == nil {
|
||||
return NewNotFoundError("", err)
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, log)
|
||||
return e.JSON(http.StatusOK, log)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
|
|
@ -15,29 +15,31 @@ func TestLogsList(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs",
|
||||
URL: "/api/logs",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/logs",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin",
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/logs",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
if err := tests.MockLogsData(app); err != nil {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
|
|
@ -50,16 +52,17 @@ func TestLogsList(t *testing.T) {
|
|||
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
|
||||
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + filter",
|
||||
Name: "authorized as superuser + filter",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs?filter=data.status>200",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/logs?filter=data.status>200",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
if err := tests.MockLogsData(app); err != nil {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
|
|
@ -71,6 +74,7 @@ func TestLogsList(t *testing.T) {
|
|||
`"items":[{`,
|
||||
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -86,44 +90,47 @@ func TestLogView(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (nonexisting request log)",
|
||||
Name: "authorized as superuser (nonexisting request log)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
if err := tests.MockLogsData(app); err != nil {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (existing request log)",
|
||||
Name: "authorized as superuser (existing request log)",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
if err := tests.MockLogsData(app); err != nil {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
|
|
@ -131,6 +138,7 @@ func TestLogView(t *testing.T) {
|
|||
ExpectedContent: []string{
|
||||
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -146,52 +154,54 @@ func TestLogsStats(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/stats",
|
||||
URL: "/api/logs/stats",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/stats",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/logs/stats",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin",
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/stats",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/logs/stats",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
if err := tests.MockLogsData(app); err != nil {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`[{"total":1,"date":"2022-05-01 10:00:00.000Z"},{"total":1,"date":"2022-05-02 10:00:00.000Z"}]`,
|
||||
`[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin + filter",
|
||||
Name: "authorized as superuser + filter",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/logs/stats?filter=data.status>200",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/logs/stats?filter=data.status>200",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
if err := tests.MockLogsData(app); err != nil {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`[{"total":1,"date":"2022-05-02 10:00:00.000Z"}]`,
|
||||
`[{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,303 +3,321 @@ package apis
|
|||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tokens"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// Common request context keys used by the middlewares and api handlers.
|
||||
// Common request event store keys used by the middlewares and api handlers.
|
||||
const (
|
||||
ContextAdminKey string = "admin"
|
||||
ContextAuthRecordKey string = "authRecord"
|
||||
ContextCollectionKey string = "collection"
|
||||
ContextExecStartKey string = "execStart"
|
||||
RequestEventKeyLogMeta = "pbLogMeta" // extra data to store with the request activity log
|
||||
|
||||
requestEventKeyExecStart = "__execStart" // the value must be time.Time
|
||||
requestEventKeySkipSuccessActivityLog = "__skipSuccessActivityLogger" // the value must be bool
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultWWWRedirectMiddlewarePriority = -99999
|
||||
DefaultWWWRedirectMiddlewareId = "pbWWWRedirect"
|
||||
|
||||
DefaultActivityLoggerMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 30
|
||||
DefaultActivityLoggerMiddlewareId = "pbActivityLogger"
|
||||
DefaultSkipSuccessActivityLogMiddlewareId = "pbSkipSuccessActivityLog"
|
||||
DefaultEnableAuthIdActivityLog = "pbEnableAuthIdActivityLog"
|
||||
|
||||
DefaultLoadAuthTokenMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 20
|
||||
DefaultLoadAuthTokenMiddlewareId = "pbLoadAuthToken"
|
||||
|
||||
DefaultSecurityHeadersMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 10
|
||||
DefaultSecurityHeadersMiddlewareId = "pbSecurityHeaders"
|
||||
|
||||
DefaultRequireGuestOnlyMiddlewareId = "pbRequireGuestOnly"
|
||||
DefaultRequireAuthMiddlewareId = "pbRequireAuth"
|
||||
DefaultRequireSuperuserAuthMiddlewareId = "pbRequireSuperuserAuth"
|
||||
DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId = "pbRequireSuperuserAuthOnlyIfAny"
|
||||
DefaultRequireSuperuserOrOwnerAuthMiddlewareId = "pbRequireSuperuserOrOwnerAuth"
|
||||
DefaultRequireSameCollectionContextAuthMiddlewareId = "pbRequireSameCollectionContextAuth"
|
||||
)
|
||||
|
||||
// RequireGuestOnly middleware requires a request to NOT have a valid
|
||||
// Authorization header.
|
||||
//
|
||||
// This middleware is the opposite of [apis.RequireAdminOrRecordAuth()].
|
||||
func RequireGuestOnly() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
err := NewBadRequestError("The request can be accessed only by guests.", nil)
|
||||
|
||||
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
if record != nil {
|
||||
return err
|
||||
// This middleware is the opposite of [apis.RequireAuth()].
|
||||
func RequireGuestOnly() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireGuestOnlyMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth != nil {
|
||||
return router.NewBadRequestError("The request can be accessed only by guests.", nil)
|
||||
}
|
||||
|
||||
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRecordAuth middleware requires a request to have
|
||||
// a valid record auth Authorization header.
|
||||
// RequireAuth middleware requires a request to have a valid record Authorization header.
|
||||
//
|
||||
// The auth record could be from any collection.
|
||||
//
|
||||
// You can further filter the allowed record auth collections by
|
||||
// specifying their names.
|
||||
// You can further filter the allowed record auth collections by specifying their names.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// apis.RequireRecordAuth()
|
||||
//
|
||||
// Or:
|
||||
//
|
||||
// apis.RequireRecordAuth("users", "supervisors")
|
||||
//
|
||||
// To restrict the auth record only to the loaded context collection,
|
||||
// use [apis.RequireSameContextRecordAuth()] instead.
|
||||
func RequireRecordAuth(optCollectionNames ...string) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
if record == nil {
|
||||
return NewUnauthorizedError("The request requires valid record authorization token to be set.", nil)
|
||||
}
|
||||
|
||||
// check record collection name
|
||||
if len(optCollectionNames) > 0 && !list.ExistInSlice(record.Collection().Name, optCollectionNames) {
|
||||
return NewForbiddenError("The authorized record model is not allowed to perform this action.", nil)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
// apis.RequireAuth() // any auth collection
|
||||
// apis.RequireAuth("_superusers", "users") // only the listed auth collections
|
||||
func RequireAuth(optCollectionNames ...string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireAuthMiddlewareId,
|
||||
Func: requireAuth(optCollectionNames...),
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSameContextRecordAuth middleware requires a request to have
|
||||
// a valid record Authorization header.
|
||||
//
|
||||
// The auth record must be from the same collection already loaded in the context.
|
||||
func RequireSameContextRecordAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
if record == nil {
|
||||
return NewUnauthorizedError("The request requires valid record authorization token to be set.", nil)
|
||||
}
|
||||
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil || record.Collection().Id != collection.Id {
|
||||
return NewForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", record.Collection().Name), nil)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
func requireAuth(optCollectionNames ...string) hook.HandlerFunc[*core.RequestEvent] {
|
||||
return func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
|
||||
}
|
||||
|
||||
// check record collection name
|
||||
if len(optCollectionNames) > 0 && !slices.Contains(optCollectionNames, e.Auth.Collection().Name) {
|
||||
return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAdminAuth middleware requires a request to have
|
||||
// a valid admin Authorization header.
|
||||
func RequireAdminAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin == nil {
|
||||
return NewUnauthorizedError("The request requires valid admin authorization token to be set.", nil)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
// RequireSuperuserAuth middleware requires a request to have
|
||||
// a valid superuser Authorization header.
|
||||
func RequireSuperuserAuth() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSuperuserAuthMiddlewareId,
|
||||
Func: requireAuth(core.CollectionNameSuperusers),
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAdminAuthOnlyIfAny middleware requires a request to have
|
||||
// a valid admin Authorization header ONLY if the application has
|
||||
// at least 1 existing Admin model.
|
||||
func RequireAdminAuthOnlyIfAny(app core.App) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin != nil {
|
||||
return next(c)
|
||||
// RequireSuperuserAuthOnlyIfAny middleware requires a request to have
|
||||
// a valid superuser Authorization header ONLY if the application has
|
||||
// at least 1 existing superuser.
|
||||
func RequireSuperuserAuthOnlyIfAny() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.HasSuperuserAuth() {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
totalAdmins, err := app.Dao().TotalAdmins()
|
||||
totalSuperusers, err := e.App.CountRecords(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to fetch admins info.", err)
|
||||
return e.InternalServerError("Failed to fetch superusers info.", err)
|
||||
}
|
||||
|
||||
if totalAdmins == 0 {
|
||||
return next(c)
|
||||
if totalSuperusers == 0 {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
return NewUnauthorizedError("The request requires valid admin authorization token to be set.", nil)
|
||||
}
|
||||
return requireAuth(core.CollectionNameSuperusers)(e)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAdminOrRecordAuth middleware requires a request to have
|
||||
// a valid admin or record Authorization header set.
|
||||
// RequireSuperuserOrOwnerAuth middleware requires a request to have
|
||||
// a valid superuser or regular record owner Authorization header set.
|
||||
//
|
||||
// You can further filter the allowed auth record collections by providing their names.
|
||||
//
|
||||
// This middleware is the opposite of [apis.RequireGuestOnly()].
|
||||
func RequireAdminOrRecordAuth(optCollectionNames ...string) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
|
||||
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
|
||||
if admin == nil && record == nil {
|
||||
return NewUnauthorizedError("The request requires admin or record authorization token to be set.", nil)
|
||||
}
|
||||
|
||||
if record != nil && len(optCollectionNames) > 0 && !list.ExistInSlice(record.Collection().Name, optCollectionNames) {
|
||||
return NewForbiddenError("The authorized record model is not allowed to perform this action.", nil)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAdminOrOwnerAuth middleware requires a request to have
|
||||
// a valid admin or auth record owner Authorization header set.
|
||||
//
|
||||
// This middleware is similar to [apis.RequireAdminOrRecordAuth()] but
|
||||
// This middleware is similar to [apis.RequireAuth()] but
|
||||
// for the auth record token expects to have the same id as the path
|
||||
// parameter ownerIdParam (default to "id" if empty).
|
||||
func RequireAdminOrOwnerAuth(ownerIdParam string) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin != nil {
|
||||
return next(c)
|
||||
// parameter ownerIdPathParam (default to "id" if empty).
|
||||
func RequireSuperuserOrOwnerAuth(ownerIdPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSuperuserOrOwnerAuthMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires superuser or record authorization token.", nil)
|
||||
}
|
||||
|
||||
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
if record == nil {
|
||||
return NewUnauthorizedError("The request requires admin or record authorization token to be set.", nil)
|
||||
if e.Auth.IsSuperuser() {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
if ownerIdParam == "" {
|
||||
ownerIdParam = "id"
|
||||
if ownerIdPathParam == "" {
|
||||
ownerIdPathParam = "id"
|
||||
}
|
||||
ownerId := c.PathParam(ownerIdParam)
|
||||
ownerId := e.Request.PathValue(ownerIdPathParam)
|
||||
|
||||
// note: it is "safe" to compare only the record id since the auth
|
||||
// record ids are treated as unique across all auth collections
|
||||
if record.Id != ownerId {
|
||||
return NewForbiddenError("You are not allowed to perform this request.", nil)
|
||||
// note: it is considered "safe" to compare only the record id
|
||||
// since the auth record ids are treated as unique across all auth collections
|
||||
if e.Auth.Id != ownerId {
|
||||
return e.ForbiddenError("You are not allowed to perform this request.", nil)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAuthContext middleware reads the Authorization request header
|
||||
// and loads the token related record or admin instance into the
|
||||
// request's context.
|
||||
// RequireSameCollectionContextAuth middleware requires a request to have
|
||||
// a valid record Authorization header and the auth record's collection to
|
||||
// match the one from the route path parameter (default to "collection" if collectionParam is empty).
|
||||
func RequireSameCollectionContextAuth(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSameCollectionContextAuthMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
|
||||
}
|
||||
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
collection, _ := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if collection == nil || e.Auth.Collection().Id != collection.Id {
|
||||
return e.ForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", e.Auth.Collection().Name), nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// loadAuthToken attempts to load the auth context based on the "Authorization: TOKEN" header value.
|
||||
//
|
||||
// This middleware is expected to be already registered by default for all routes.
|
||||
func LoadAuthContext(app core.App) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
token := c.Request().Header.Get("Authorization")
|
||||
// This middleware does nothing in case of missing, invalid or expired token.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
//
|
||||
// Note: We don't throw an error on invalid or expired token to allow
|
||||
// users to extend with their own custom handling in external middleware(s).
|
||||
func loadAuthToken() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultLoadAuthTokenMiddlewareId,
|
||||
Priority: DefaultLoadAuthTokenMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
token := getAuthTokenFromRequest(e)
|
||||
if token == "" {
|
||||
return next(c)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// the schema is not required and it is only for
|
||||
// compatibility with the defaults of some HTTP clients
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
|
||||
claims, _ := security.ParseUnverifiedJWT(token)
|
||||
tokenType := cast.ToString(claims["type"])
|
||||
|
||||
switch tokenType {
|
||||
case tokens.TypeAdmin:
|
||||
admin, err := app.Dao().FindAdminByToken(
|
||||
token,
|
||||
app.Settings().AdminAuthToken.Secret,
|
||||
)
|
||||
if err == nil && admin != nil {
|
||||
c.Set(ContextAdminKey, admin)
|
||||
}
|
||||
case tokens.TypeAuthRecord:
|
||||
record, err := app.Dao().FindAuthRecordByToken(
|
||||
token,
|
||||
app.Settings().RecordAuthToken.Secret,
|
||||
)
|
||||
if err == nil && record != nil {
|
||||
c.Set(ContextAuthRecordKey, record)
|
||||
}
|
||||
record, err := e.App.FindAuthRecordByToken(token, core.TokenTypeAuth)
|
||||
if err != nil {
|
||||
e.App.Logger().Debug("loadAuthToken failure", "error", err)
|
||||
} else if record != nil {
|
||||
e.Auth = record
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LoadCollectionContext middleware finds the collection with related
|
||||
// path identifier and loads it into the request context.
|
||||
func getAuthTokenFromRequest(e *core.RequestEvent) string {
|
||||
token := e.Request.Header.Get("Authorization")
|
||||
if token != "" {
|
||||
// the schema prefix is not required and it is only for
|
||||
// compatibility with the defaults of some HTTP clients
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// wwwRedirect performs www->non-www redirect(s) if the request host
|
||||
// matches with one of the values in redirectHosts.
|
||||
//
|
||||
// Set optCollectionTypes to further filter the found collection by its type.
|
||||
func LoadCollectionContext(app core.App, optCollectionTypes ...string) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if param := c.PathParam("collection"); param != "" {
|
||||
collection, err := core.FindCachedCollectionByNameOrId(app, param)
|
||||
if err != nil || collection == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
// This middleware is registered by default on Serve for all routes.
|
||||
func wwwRedirect(redirectHosts []string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultWWWRedirectMiddlewareId,
|
||||
Priority: DefaultWWWRedirectMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
host := e.Request.Host
|
||||
|
||||
if len(optCollectionTypes) > 0 && !list.ExistInSlice(collection.Type, optCollectionTypes) {
|
||||
return NewBadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
c.Set(ContextCollectionKey, collection)
|
||||
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, redirectHosts) {
|
||||
return e.Redirect(
|
||||
http.StatusTemporaryRedirect,
|
||||
(e.Request.URL.Scheme + "://" + host[4:] + e.Request.RequestURI),
|
||||
)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ActivityLogger middleware takes care to save the request information
|
||||
// securityHeaders middleware adds common security headers to the response.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
func securityHeaders() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultSecurityHeadersMiddlewareId,
|
||||
Priority: DefaultSecurityHeadersMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Response.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
e.Response.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
e.Response.Header().Set("X-Frame-Options", "SAMEORIGIN")
|
||||
|
||||
// @todo consider a default HSTS?
|
||||
// (see also https://webkit.org/blog/8146/protecting-against-hsts-abuse/)
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SkipSuccessActivityLog is a helper middleware that instructs the global
|
||||
// activity logger to log only requests that have failed/returned an error.
|
||||
func SkipSuccessActivityLog() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultSkipSuccessActivityLogMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Set(requestEventKeySkipSuccessActivityLog, true)
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// activityLogger middleware takes care to save the request information
|
||||
// into the logs database.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
//
|
||||
// The middleware does nothing if the app logs retention period is zero
|
||||
// (aka. app.Settings().Logs.MaxDays = 0).
|
||||
func ActivityLogger(app core.App) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if err := next(c); err != nil {
|
||||
return err
|
||||
}
|
||||
//
|
||||
// Users can attach the [apis.SkipSuccessActivityLog()] middleware if
|
||||
// you want to log only the failed requests.
|
||||
func activityLogger() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultActivityLoggerMiddlewareId,
|
||||
Priority: DefaultActivityLoggerMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Set(requestEventKeyExecStart, time.Now())
|
||||
|
||||
logRequest(app, c, nil)
|
||||
err := e.Next()
|
||||
|
||||
return nil
|
||||
}
|
||||
logRequest(e, err)
|
||||
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func logRequest(app core.App, c echo.Context, err *ApiError) {
|
||||
func logRequest(event *core.RequestEvent, err error) {
|
||||
// no logs retention
|
||||
if app.Settings().Logs.MaxDays == 0 {
|
||||
if event.App.Settings().Logs.MaxDays == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// the non-error route has explicitly disabled the activity logger
|
||||
if err == nil && event.Get(requestEventKeySkipSuccessActivityLog) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -307,32 +325,31 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
|
|||
|
||||
attrs = append(attrs, slog.String("type", "request"))
|
||||
|
||||
started := cast.ToTime(c.Get(ContextExecStartKey))
|
||||
started := cast.ToTime(event.Get(requestEventKeyExecStart))
|
||||
if !started.IsZero() {
|
||||
attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond)))
|
||||
}
|
||||
|
||||
httpRequest := c.Request()
|
||||
httpResponse := c.Response()
|
||||
method := strings.ToUpper(httpRequest.Method)
|
||||
status := httpResponse.Status
|
||||
requestUri := httpRequest.URL.RequestURI()
|
||||
if meta := event.Get(RequestEventKeyLogMeta); meta != nil {
|
||||
attrs = append(attrs, slog.Any("meta", meta))
|
||||
}
|
||||
|
||||
status := event.Status()
|
||||
method := cutStr(strings.ToUpper(event.Request.Method), 50)
|
||||
requestUri := cutStr(event.Request.URL.RequestURI(), 3000)
|
||||
|
||||
// parse the request error
|
||||
if err != nil {
|
||||
status = err.Code
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("error", err.Message),
|
||||
slog.Any("details", err.RawData()),
|
||||
)
|
||||
}
|
||||
|
||||
requestAuth := models.RequestAuthGuest
|
||||
if c.Get(ContextAuthRecordKey) != nil {
|
||||
requestAuth = models.RequestAuthRecord
|
||||
} else if c.Get(ContextAdminKey) != nil {
|
||||
requestAuth = models.RequestAuthAdmin
|
||||
if apiErr, ok := err.(*router.ApiError); ok {
|
||||
status = apiErr.Status
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("error", apiErr.Message),
|
||||
slog.Any("details", apiErr.RawData()),
|
||||
)
|
||||
} else {
|
||||
attrs = append(attrs, slog.String("error", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
attrs = append(
|
||||
|
|
@ -340,17 +357,33 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
|
|||
slog.String("url", requestUri),
|
||||
slog.String("method", method),
|
||||
slog.Int("status", status),
|
||||
slog.String("auth", requestAuth),
|
||||
slog.String("referer", httpRequest.Referer()),
|
||||
slog.String("userAgent", httpRequest.UserAgent()),
|
||||
slog.String("referer", cutStr(event.Request.Referer(), 2000)),
|
||||
slog.String("userAgent", cutStr(event.Request.UserAgent(), 2000)),
|
||||
)
|
||||
|
||||
if app.Settings().Logs.LogIp {
|
||||
ip, _, _ := net.SplitHostPort(httpRequest.RemoteAddr)
|
||||
if event.Auth != nil {
|
||||
attrs = append(attrs, slog.String("auth", event.Auth.Collection().Name))
|
||||
|
||||
if event.App.Settings().Logs.LogAuthId {
|
||||
attrs = append(attrs, slog.String("authId", event.Auth.Id))
|
||||
}
|
||||
} else {
|
||||
attrs = append(attrs, slog.String("auth", ""))
|
||||
}
|
||||
|
||||
if event.App.Settings().Logs.LogIP {
|
||||
var userIP string
|
||||
if len(event.App.Settings().TrustedProxy.Headers) > 0 {
|
||||
userIP = event.RealIP()
|
||||
} else {
|
||||
// fallback to the legacy behavior (it is "safe" since it is only for log purposes)
|
||||
userIP = cutStr(event.UnsafeRealIP(), 50)
|
||||
}
|
||||
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("userIp", realUserIp(httpRequest, ip)),
|
||||
slog.String("remoteIp", ip),
|
||||
slog.String("userIP", userIP),
|
||||
slog.String("remoteIP", event.RemoteIP()),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -358,64 +391,23 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
|
|||
routine.FireAndForget(func() {
|
||||
message := method + " "
|
||||
|
||||
if escaped, err := url.PathUnescape(requestUri); err == nil {
|
||||
if escaped, unescapeErr := url.PathUnescape(requestUri); unescapeErr == nil {
|
||||
message += escaped
|
||||
} else {
|
||||
message += requestUri
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
app.Logger().Error(message, attrs...)
|
||||
event.App.Logger().Error(message, attrs...)
|
||||
} else {
|
||||
app.Logger().Info(message, attrs...)
|
||||
event.App.Logger().Info(message, attrs...)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Returns the "real" user IP from common proxy headers (or fallbackIp if none is found).
|
||||
//
|
||||
// The returned IP value shouldn't be trusted if not behind a trusted reverse proxy!
|
||||
func realUserIp(r *http.Request, fallbackIp string) string {
|
||||
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("Fly-Client-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
if ipsList := r.Header.Get("X-Forwarded-For"); ipsList != "" {
|
||||
// extract the first non-empty leftmost-ish ip
|
||||
ips := strings.Split(ipsList, ",")
|
||||
for _, ip := range ips {
|
||||
ip = strings.TrimSpace(ip)
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fallbackIp
|
||||
}
|
||||
|
||||
// @todo consider removing as this may no longer be needed due to the custom rest.MultiBinder.
|
||||
//
|
||||
// eagerRequestInfoCache ensures that the request data is cached in the request
|
||||
// context to allow reading for example the json request body data more than once.
|
||||
func eagerRequestInfoCache(app core.App) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
switch c.Request().Method {
|
||||
// currently we are eagerly caching only the requests with body
|
||||
case "POST", "PUT", "PATCH", "DELETE":
|
||||
RequestInfo(c)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
func cutStr(str string, max int) string {
|
||||
if len(str) > max {
|
||||
return str[:max] + "..."
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,123 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
var ErrRequestEntityTooLarge = router.NewApiError(http.StatusRequestEntityTooLarge, "Request entity too large", nil)
|
||||
|
||||
const DefaultMaxBodySize int64 = 32 << 20
|
||||
|
||||
const (
|
||||
DefaultBodyLimitMiddlewareId = "pbBodyLimit"
|
||||
DefaultBodyLimitMiddlewarePriority = DefaultRateLimitMiddlewarePriority + 10
|
||||
)
|
||||
|
||||
// BodyLimit returns a middleware function that changes the default request body size limit.
|
||||
//
|
||||
// Note that in order to have effect this middleware should be registered
|
||||
// before other middlewares that reads the request body.
|
||||
//
|
||||
// If limitBytes <= 0, no limit is applied.
|
||||
//
|
||||
// Otherwise, if the request body size exceeds the configured limitBytes,
|
||||
// it sends 413 error response.
|
||||
func BodyLimit(limitBytes int64) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultBodyLimitMiddlewareId,
|
||||
Priority: DefaultBodyLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
err := applyBodyLimit(e, limitBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func dynamicCollectionBodyLimit(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultBodyLimitMiddlewareId,
|
||||
Priority: DefaultBodyLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid collection context.", err)
|
||||
}
|
||||
|
||||
limitBytes := DefaultMaxBodySize
|
||||
if !collection.IsView() {
|
||||
for _, f := range collection.Fields {
|
||||
if calc, ok := f.(core.MaxBodySizeCalculator); ok {
|
||||
limitBytes += calc.CalculateMaxBodySize()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = applyBodyLimit(e, limitBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyBodyLimit(e *core.RequestEvent, limitBytes int64) error {
|
||||
// no limit
|
||||
if limitBytes <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// optimistically check the submitted request content length
|
||||
if e.Request.ContentLength > limitBytes {
|
||||
return ErrRequestEntityTooLarge
|
||||
}
|
||||
|
||||
// replace the request body
|
||||
//
|
||||
// note: we don't use sync.Pool since the size of the elements could vary too much
|
||||
// and it might not be efficient (see https://github.com/golang/go/issues/23199)
|
||||
e.Request.Body = &limitedReader{ReadCloser: e.Request.Body, limit: limitBytes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type limitedReader struct {
|
||||
io.ReadCloser
|
||||
limit int64
|
||||
totalRead int64
|
||||
}
|
||||
|
||||
func (r *limitedReader) Read(b []byte) (int, error) {
|
||||
n, err := r.ReadCloser.Read(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
r.totalRead += int64(n)
|
||||
if r.totalRead > r.limit {
|
||||
return n, ErrRequestEntityTooLarge
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *limitedReader) Reread() {
|
||||
rr, ok := r.ReadCloser.(router.Rereader)
|
||||
if ok {
|
||||
rr.Reread()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestBodyLimitMiddleware(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
pbRouter, err := apis.NewRouter(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pbRouter.POST("/a", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "a")
|
||||
}) // default global BodyLimit check
|
||||
|
||||
pbRouter.POST("/b", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "b")
|
||||
}).Bind(apis.BodyLimit(20))
|
||||
|
||||
mux, err := pbRouter.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
url string
|
||||
size int64
|
||||
expectedStatus int
|
||||
}{
|
||||
{"/a", 21, 200},
|
||||
{"/a", apis.DefaultMaxBodySize + 1, 413},
|
||||
{"/b", 20, 200},
|
||||
{"/b", 21, 413},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%s_%d", s.url, s.size), func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", s.url, bytes.NewReader(make([]byte, s.size)))
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
if result.StatusCode != s.expectedStatus {
|
||||
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,307 @@
|
|||
package apis
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// This middleware is ported from echo/middleware to minimize the breaking
|
||||
// changes and differences in the API behavior from earlier PocketBase versions
|
||||
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/cors.go).
|
||||
//
|
||||
// I doubt that this would matter for most cases, but the only major difference
|
||||
// is that for non-supported routes this middleware doesn't return 405 and fallbacks
|
||||
// to the default catch-all PocketBase route (aka. returns 404) to avoid
|
||||
// the extra overhead of further hijacking and wrapping the Go default mux
|
||||
// (https://github.com/golang/go/issues/65648#issuecomment-1955328807).
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCorsMiddlewareId = "pbCors"
|
||||
DefaultCorsMiddlewarePriority = DefaultActivityLoggerMiddlewarePriority - 1 // before the activity logger and rate limit so that OPTIONS preflight requests are not counted
|
||||
)
|
||||
|
||||
// CORSConfig defines the config for CORS middleware.
|
||||
type CORSConfig struct {
|
||||
// AllowOrigins determines the value of the Access-Control-Allow-Origin
|
||||
// response header. This header defines a list of origins that may access the
|
||||
// resource. The wildcard characters '*' and '?' are supported and are
|
||||
// converted to regex fragments '.*' and '.' accordingly.
|
||||
//
|
||||
// Security: use extreme caution when handling the origin, and carefully
|
||||
// validate any logic. Remember that attackers may register hostile domain names.
|
||||
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// Optional. Default value []string{"*"}.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
|
||||
AllowOrigins []string
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||
// origin as an argument and returns true if allowed or false otherwise. If
|
||||
// an error is returned, it is returned by the handler. If this option is
|
||||
// set, AllowOrigins is ignored.
|
||||
//
|
||||
// Security: use extreme caution when handling the origin, and carefully
|
||||
// validate any logic. Remember that attackers may register hostile domain names.
|
||||
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// Optional.
|
||||
AllowOriginFunc func(origin string) (bool, error)
|
||||
|
||||
// AllowMethods determines the value of the Access-Control-Allow-Methods
|
||||
// response header. This header specified the list of methods allowed when
|
||||
// accessing the resource. This is used in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
|
||||
AllowMethods []string
|
||||
|
||||
// AllowHeaders determines the value of the Access-Control-Allow-Headers
|
||||
// response header. This header is used in response to a preflight request to
|
||||
// indicate which HTTP headers can be used when making the actual request.
|
||||
//
|
||||
// Optional. Default value []string{}.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
|
||||
AllowHeaders []string
|
||||
|
||||
// AllowCredentials determines the value of the
|
||||
// Access-Control-Allow-Credentials response header. This header indicates
|
||||
// whether or not the response to the request can be exposed when the
|
||||
// credentials mode (Request.credentials) is true. When used as part of a
|
||||
// response to a preflight request, this indicates whether or not the actual
|
||||
// request can be made using credentials. See also
|
||||
// [MDN: Access-Control-Allow-Credentials].
|
||||
//
|
||||
// Optional. Default value false, in which case the header is not set.
|
||||
//
|
||||
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
|
||||
// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
|
||||
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
|
||||
AllowCredentials bool
|
||||
|
||||
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
|
||||
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
|
||||
//
|
||||
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
|
||||
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
|
||||
//
|
||||
// Optional. Default value is false.
|
||||
UnsafeWildcardOriginWithAllowCredentials bool
|
||||
|
||||
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
|
||||
// defines a list of headers that clients are allowed to access.
|
||||
//
|
||||
// Optional. Default value []string{}, in which case the header is not set.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
|
||||
ExposeHeaders []string
|
||||
|
||||
// MaxAge determines the value of the Access-Control-Max-Age response header.
|
||||
// This header indicates how long (in seconds) the results of a preflight
|
||||
// request can be cached.
|
||||
// The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
|
||||
//
|
||||
// Optional. Default value 0 - meaning header is not sent.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// DefaultCORSConfig is the default CORS middleware config.
|
||||
var DefaultCORSConfig = CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}
|
||||
|
||||
// CORSWithConfig returns a CORS middleware with config.
|
||||
func CORSWithConfig(config CORSConfig) hook.HandlerFunc[*core.RequestEvent] {
|
||||
// Defaults
|
||||
if len(config.AllowOrigins) == 0 {
|
||||
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
|
||||
}
|
||||
if len(config.AllowMethods) == 0 {
|
||||
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
||||
}
|
||||
|
||||
allowOriginPatterns := []string{}
|
||||
for _, origin := range config.AllowOrigins {
|
||||
pattern := regexp.QuoteMeta(origin)
|
||||
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
|
||||
pattern = strings.ReplaceAll(pattern, "\\?", ".")
|
||||
pattern = "^" + pattern + "$"
|
||||
allowOriginPatterns = append(allowOriginPatterns, pattern)
|
||||
}
|
||||
|
||||
allowMethods := strings.Join(config.AllowMethods, ",")
|
||||
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
||||
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
||||
|
||||
maxAge := "0"
|
||||
if config.MaxAge > 0 {
|
||||
maxAge = strconv.Itoa(config.MaxAge)
|
||||
}
|
||||
|
||||
return func(e *core.RequestEvent) error {
|
||||
req := e.Request
|
||||
res := e.Response
|
||||
origin := req.Header.Get("Origin")
|
||||
allowOrigin := ""
|
||||
|
||||
res.Header().Add("Vary", "Origin")
|
||||
|
||||
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
|
||||
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
|
||||
// For simplicity we just consider method type and later `Origin` header.
|
||||
preflight := req.Method == http.MethodOptions
|
||||
|
||||
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
|
||||
if origin == "" {
|
||||
if !preflight {
|
||||
return e.Next()
|
||||
}
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
if config.AllowOriginFunc != nil {
|
||||
allowed, err := config.AllowOriginFunc(origin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if allowed {
|
||||
allowOrigin = origin
|
||||
}
|
||||
} else {
|
||||
// Check allowed origins
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
if o == "*" || o == origin {
|
||||
allowOrigin = o
|
||||
break
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
checkPatterns := false
|
||||
if allowOrigin == "" {
|
||||
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
|
||||
if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
|
||||
checkPatterns = true
|
||||
}
|
||||
}
|
||||
if checkPatterns {
|
||||
for _, re := range allowOriginPatterns {
|
||||
if match, _ := regexp.MatchString(re, origin); match {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Origin not allowed
|
||||
if allowOrigin == "" {
|
||||
if !preflight {
|
||||
return e.Next()
|
||||
}
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
res.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
||||
if config.AllowCredentials {
|
||||
res.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
// Simple request
|
||||
if !preflight {
|
||||
if exposeHeaders != "" {
|
||||
res.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Preflight request
|
||||
res.Header().Add("Vary", "Access-Control-Request-Method")
|
||||
res.Header().Add("Vary", "Access-Control-Request-Headers")
|
||||
res.Header().Set("Access-Control-Allow-Methods", allowMethods)
|
||||
|
||||
if allowHeaders != "" {
|
||||
res.Header().Set("Access-Control-Allow-Headers", allowHeaders)
|
||||
} else {
|
||||
h := req.Header.Get("Access-Control-Request-Headers")
|
||||
if h != "" {
|
||||
res.Header().Set("Access-Control-Allow-Headers", h)
|
||||
}
|
||||
}
|
||||
if config.MaxAge != 0 {
|
||||
res.Header().Set("Access-Control-Max-Age", maxAge)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
||||
}
|
||||
|
||||
// matchSubdomain compares authority with wildcard
|
||||
func matchSubdomain(domain, pattern string) bool {
|
||||
if !matchScheme(domain, pattern) {
|
||||
return false
|
||||
}
|
||||
didx := strings.Index(domain, "://")
|
||||
pidx := strings.Index(pattern, "://")
|
||||
if didx == -1 || pidx == -1 {
|
||||
return false
|
||||
}
|
||||
domAuth := domain[didx+3:]
|
||||
// to avoid long loop by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
return false
|
||||
}
|
||||
patAuth := pattern[pidx+3:]
|
||||
|
||||
domComp := strings.Split(domAuth, ".")
|
||||
patComp := strings.Split(patAuth, ".")
|
||||
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(domComp) - 1 - i
|
||||
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
||||
}
|
||||
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(patComp) - 1 - i
|
||||
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
||||
}
|
||||
|
||||
for i, v := range domComp {
|
||||
if len(patComp) <= i {
|
||||
return false
|
||||
}
|
||||
p := patComp[i]
|
||||
if p == "*" {
|
||||
return true
|
||||
}
|
||||
if p != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
package apis
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// This middleware is ported from echo/middleware to minimize the breaking
|
||||
// changes and differences in the API behavior from earlier PocketBase versions
|
||||
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/compress.go).
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
const (
|
||||
gzipScheme = "gzip"
|
||||
)
|
||||
|
||||
// GzipConfig defines the config for Gzip middleware.
|
||||
type GzipConfig struct {
|
||||
// Gzip compression level.
|
||||
// Optional. Default value -1.
|
||||
Level int
|
||||
|
||||
// Length threshold before gzip compression is applied.
|
||||
// Optional. Default value 0.
|
||||
//
|
||||
// Most of the time you will not need to change the default. Compressing
|
||||
// a short response might increase the transmitted data because of the
|
||||
// gzip format overhead. Compressing the response will also consume CPU
|
||||
// and time on the server and the client (for decompressing). Depending on
|
||||
// your use case such a threshold might be useful.
|
||||
//
|
||||
// See also:
|
||||
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
|
||||
MinLength int
|
||||
}
|
||||
|
||||
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||
func Gzip() hook.HandlerFunc[*core.RequestEvent] {
|
||||
return GzipWithConfig(GzipConfig{})
|
||||
}
|
||||
|
||||
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||
func GzipWithConfig(config GzipConfig) hook.HandlerFunc[*core.RequestEvent] {
|
||||
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
|
||||
panic(errors.New("invalid gzip level"))
|
||||
}
|
||||
if config.Level == 0 {
|
||||
config.Level = -1
|
||||
}
|
||||
if config.MinLength < 0 {
|
||||
config.MinLength = 0
|
||||
}
|
||||
|
||||
pool := sync.Pool{
|
||||
New: func() interface{} {
|
||||
w, err := gzip.NewWriterLevel(io.Discard, config.Level)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w
|
||||
},
|
||||
}
|
||||
|
||||
bpool := sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := &bytes.Buffer{}
|
||||
return b
|
||||
},
|
||||
}
|
||||
|
||||
return func(e *core.RequestEvent) error {
|
||||
e.Response.Header().Add("Vary", "Accept-Encoding")
|
||||
if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) {
|
||||
w, ok := pool.Get().(*gzip.Writer)
|
||||
if !ok {
|
||||
return e.InternalServerError("", errors.New("failed to get gzip.Writer"))
|
||||
}
|
||||
|
||||
rw := e.Response
|
||||
w.Reset(rw)
|
||||
|
||||
buf := bpool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
|
||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
|
||||
defer func() {
|
||||
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
|
||||
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
|
||||
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
|
||||
if !grw.wroteBody {
|
||||
if rw.Header().Get("Content-Encoding") == gzipScheme {
|
||||
rw.Header().Del("Content-Encoding")
|
||||
}
|
||||
if grw.wroteHeader {
|
||||
rw.WriteHeader(grw.code)
|
||||
}
|
||||
// We have to reset response to it's pristine state when
|
||||
// nothing is written to body or error is returned.
|
||||
// See issue echo#424, echo#407.
|
||||
e.Response = rw
|
||||
w.Reset(io.Discard)
|
||||
} else if !grw.minLengthExceeded {
|
||||
// Write uncompressed response
|
||||
e.Response = rw
|
||||
if grw.wroteHeader {
|
||||
rw.WriteHeader(grw.code)
|
||||
}
|
||||
grw.buffer.WriteTo(rw)
|
||||
w.Reset(io.Discard)
|
||||
}
|
||||
w.Close()
|
||||
bpool.Put(buf)
|
||||
pool.Put(w)
|
||||
}()
|
||||
e.Response = grw
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
io.Writer
|
||||
buffer *bytes.Buffer
|
||||
minLength int
|
||||
code int
|
||||
wroteHeader bool
|
||||
wroteBody bool
|
||||
minLengthExceeded bool
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) WriteHeader(code int) {
|
||||
w.Header().Del("Content-Length") // Issue echo#444
|
||||
|
||||
w.wroteHeader = true
|
||||
|
||||
// Delay writing of the header until we know if we'll actually compress the response
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", http.DetectContentType(b))
|
||||
}
|
||||
|
||||
w.wroteBody = true
|
||||
|
||||
if !w.minLengthExceeded {
|
||||
n, err := w.buffer.Write(b)
|
||||
|
||||
if w.buffer.Len() >= w.minLength {
|
||||
w.minLengthExceeded = true
|
||||
|
||||
// The minimum length is exceeded, add Content-Encoding header and write the header
|
||||
w.Header().Set("Content-Encoding", gzipScheme)
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
return w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return w.Writer.Write(b)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Flush() {
|
||||
if !w.minLengthExceeded {
|
||||
// Enforce compression because we will not know how much more data will come
|
||||
w.minLengthExceeded = true
|
||||
w.Header().Set("Content-Encoding", gzipScheme)
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
_, _ = w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
_ = w.Writer.(*gzip.Writer).Flush()
|
||||
|
||||
_ = http.NewResponseController(w.ResponseWriter).Flush()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return http.NewResponseController(w.ResponseWriter).Hijack()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||
rw := w.ResponseWriter
|
||||
for {
|
||||
switch p := rw.(type) {
|
||||
case http.Pusher:
|
||||
return p.Push(target, opts)
|
||||
case router.RWUnwrapper:
|
||||
rw = p.Unwrap()
|
||||
default:
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
rw := w.ResponseWriter
|
||||
for {
|
||||
switch rf := rw.(type) {
|
||||
case io.ReaderFrom:
|
||||
return rf.ReadFrom(r)
|
||||
case router.RWUnwrapper:
|
||||
rw = rf.Unwrap()
|
||||
default:
|
||||
return io.Copy(w.ResponseWriter, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
|
@ -0,0 +1,298 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultRateLimitMiddlewareId = "pbRateLimit"
|
||||
DefaultRateLimitMiddlewarePriority = -1000
|
||||
)
|
||||
|
||||
const (
|
||||
rateLimitersStoreKey = "__pbRateLimiters__"
|
||||
rateLimitersCronKey = "__pbRateLimitersCleanup__"
|
||||
rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__"
|
||||
)
|
||||
|
||||
// rateLimit defines the global rate limit middleware.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
func rateLimit() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRateLimitMiddlewareId,
|
||||
Priority: DefaultRateLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if skipRateLimit(e) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(defaultRateLimitLabels(e))
|
||||
if ok {
|
||||
err := checkRateLimit(e, e.Request.Pattern, rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// collectionPathRateLimit defines a rate limit middleware for the internal collection handlers.
|
||||
func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] {
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRateLimitMiddlewareId,
|
||||
Priority: DefaultRateLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid collection context.", err)
|
||||
}
|
||||
|
||||
if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// checkCollectionRateLimit checks whether the current request satisfy the
|
||||
// rate limit configuration for the specific collection.
|
||||
//
|
||||
// Each baseTags entry will be prefixed with the collection name and its wildcard variant.
|
||||
func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error {
|
||||
if skipRateLimit(e) {
|
||||
return nil
|
||||
}
|
||||
|
||||
labels := make([]string, 0, 2+len(baseTags)*2)
|
||||
|
||||
rtId := collection.Id + e.Request.Pattern
|
||||
|
||||
// add first the primary labels (aka. ["collectionName:action1", "collectionName:action2"])
|
||||
for _, baseTag := range baseTags {
|
||||
rtId += baseTag
|
||||
labels = append(labels, collection.Name+":"+baseTag)
|
||||
}
|
||||
|
||||
// add the wildcard labels (aka. [..., "*:action1","*:action2", "*"])
|
||||
for _, baseTag := range baseTags {
|
||||
labels = append(labels, "*:"+baseTag)
|
||||
}
|
||||
labels = append(labels, defaultRateLimitLabels(e)...)
|
||||
|
||||
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels)
|
||||
if ok {
|
||||
return checkRateLimit(e, rtId, rule)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// @todo consider exporting as RateLimit helper?
|
||||
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
|
||||
rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
|
||||
return initRateLimitersStore(e.App)
|
||||
}).(*store.Store[*rateLimiter])
|
||||
if rateLimiters == nil {
|
||||
e.App.Logger().Warn("Failed to retrieve app rate limiters store")
|
||||
return nil
|
||||
}
|
||||
|
||||
rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter {
|
||||
return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800)
|
||||
})
|
||||
if rt == nil {
|
||||
e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId)
|
||||
return nil
|
||||
}
|
||||
|
||||
key := e.RealIP()
|
||||
if key == "" {
|
||||
e.App.Logger().Warn("Empty rate limit client key")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !rt.isAllowed(key) {
|
||||
return e.TooManyRequestsError("", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func skipRateLimit(e *core.RequestEvent) bool {
|
||||
return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth()
|
||||
}
|
||||
|
||||
func defaultRateLimitLabels(e *core.RequestEvent) []string {
|
||||
return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path}
|
||||
}
|
||||
|
||||
func destroyRateLimitersStore(app core.App) {
|
||||
app.OnSettingsReload().Unbind(rateLimitersSettingsHookId)
|
||||
app.Cron().Remove(rateLimitersCronKey)
|
||||
app.Store().Remove(rateLimitersStoreKey)
|
||||
}
|
||||
|
||||
func initRateLimitersStore(app core.App) *store.Store[*rateLimiter] {
|
||||
app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at 00
|
||||
limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[*rateLimiter])
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
limiters := limitersStore.GetAll()
|
||||
for _, limiter := range limiters {
|
||||
limiter.clean()
|
||||
}
|
||||
})
|
||||
|
||||
app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{
|
||||
Id: rateLimitersSettingsHookId,
|
||||
Func: func(e *core.SettingsReloadEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// reset
|
||||
destroyRateLimitersStore(e.App)
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
return store.New[*rateLimiter](nil)
|
||||
}
|
||||
|
||||
func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter {
|
||||
return &rateLimiter{
|
||||
maxAllowed: maxAllowed,
|
||||
interval: intervalInSec,
|
||||
minDeleteInterval: minDeleteIntervalInSec,
|
||||
clients: map[string]*fixedWindow{},
|
||||
}
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
clients map[string]*fixedWindow
|
||||
|
||||
maxAllowed int
|
||||
interval int64
|
||||
minDeleteInterval int64
|
||||
totalDeleted int64
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (rt *rateLimiter) isAllowed(key string) bool {
|
||||
// lock only reads to minimize locks contention
|
||||
rt.RLock()
|
||||
client, ok := rt.clients[key]
|
||||
rt.RUnlock()
|
||||
|
||||
if !ok {
|
||||
rt.Lock()
|
||||
// check again in case the client was added by another request
|
||||
client, ok = rt.clients[key]
|
||||
if !ok {
|
||||
client = newFixedWindow(rt.maxAllowed, rt.interval)
|
||||
rt.clients[key] = client
|
||||
}
|
||||
rt.Unlock()
|
||||
}
|
||||
|
||||
return client.consume()
|
||||
}
|
||||
|
||||
func (rt *rateLimiter) clean() {
|
||||
rt.Lock()
|
||||
defer rt.Unlock()
|
||||
|
||||
nowUnix := time.Now().Unix()
|
||||
|
||||
for k, client := range rt.clients {
|
||||
if client.hasExpired(nowUnix, rt.minDeleteInterval) {
|
||||
delete(rt.clients, k)
|
||||
rt.totalDeleted++
|
||||
}
|
||||
}
|
||||
|
||||
// "shrink" the map if too may items were deleted
|
||||
//
|
||||
// @todo remove after https://github.com/golang/go/issues/20135
|
||||
if rt.totalDeleted >= 300 {
|
||||
shrunk := make(map[string]*fixedWindow, len(rt.clients))
|
||||
for k, v := range rt.clients {
|
||||
shrunk[k] = v
|
||||
}
|
||||
rt.clients = shrunk
|
||||
rt.totalDeleted = 0
|
||||
}
|
||||
}
|
||||
|
||||
func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow {
|
||||
return &fixedWindow{
|
||||
maxAllowed: maxAllowed,
|
||||
interval: intervalInSec,
|
||||
}
|
||||
}
|
||||
|
||||
type fixedWindow struct {
|
||||
// use plain Mutex instead of RWMutex since the operations are expected
|
||||
// to be mostly writes (e.g. consume()) and it should perform better
|
||||
sync.Mutex
|
||||
|
||||
maxAllowed int // the max allowed tokens per interval
|
||||
available int // the total available tokens
|
||||
interval int64 // in seconds
|
||||
lastConsume int64 // the time of the last consume
|
||||
}
|
||||
|
||||
// hasExpired checks whether it has been at least minElapsed seconds since the lastConsume time.
|
||||
// (usually used to perform periodic cleanup of staled instances).
|
||||
func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
return relativeNow-l.lastConsume > minElapsed
|
||||
}
|
||||
|
||||
// consume decrease the current window allowance with 1 (if not exhausted already).
|
||||
//
|
||||
// It returns false if the allowance has been already exhausted and the user
|
||||
// has to wait until it resets back to its maxAllowed value.
|
||||
func (l *fixedWindow) consume() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
nowUnix := time.Now().Unix()
|
||||
|
||||
// reset consumed counter
|
||||
if nowUnix-l.lastConsume >= l.interval {
|
||||
l.available = l.maxAllowed
|
||||
}
|
||||
|
||||
if l.available > 0 {
|
||||
l.available--
|
||||
l.lastConsume = nowUnix
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestDefaultRateLimitMiddleware(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{
|
||||
Label: "/rate/",
|
||||
MaxRequests: 2,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "/rate/b",
|
||||
MaxRequests: 3,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "POST /rate/b",
|
||||
MaxRequests: 1,
|
||||
Duration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
pbRouter, err := apis.NewRouter(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pbRouter.GET("/norate", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "norate")
|
||||
}).BindFunc(func(e *core.RequestEvent) error {
|
||||
return e.Next()
|
||||
})
|
||||
pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "a")
|
||||
})
|
||||
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "b")
|
||||
})
|
||||
|
||||
mux, err := pbRouter.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
url string
|
||||
wait float64
|
||||
expectedStatus int
|
||||
}{
|
||||
{"/norate", 0, 200},
|
||||
{"/norate", 0, 200},
|
||||
{"/norate", 0, 200},
|
||||
{"/norate", 0, 200},
|
||||
{"/norate", 0, 200},
|
||||
|
||||
{"/rate/a", 0, 200},
|
||||
{"/rate/a", 0, 200},
|
||||
{"/rate/a", 0, 429},
|
||||
{"/rate/a", 0, 429},
|
||||
{"/rate/a", 1.1, 200},
|
||||
{"/rate/a", 0, 200},
|
||||
{"/rate/a", 0, 429},
|
||||
|
||||
{"/rate/b", 0, 200},
|
||||
{"/rate/b", 0, 200},
|
||||
{"/rate/b", 0, 200},
|
||||
{"/rate/b", 0, 429},
|
||||
{"/rate/b", 1.1, 200},
|
||||
{"/rate/b", 0, 200},
|
||||
{"/rate/b", 0, 200},
|
||||
{"/rate/b", 0, 429},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.url, func(t *testing.T) {
|
||||
if s.wait > 0 {
|
||||
time.Sleep(time.Duration(s.wait) * time.Second)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", s.url, nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
|
||||
if result.StatusCode != s.expectedStatus {
|
||||
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
972
apis/realtime.go
972
apis/realtime.go
File diff suppressed because it is too large
Load Diff
|
|
@ -1,20 +1,17 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
|
|
@ -22,7 +19,7 @@ func TestRealtimeConnect(t *testing.T) {
|
|||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
|
|
@ -31,12 +28,11 @@ func TestRealtimeConnect(t *testing.T) {
|
|||
`data:{"clientId":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeBeforeMessageSend": 1,
|
||||
"OnRealtimeAfterMessageSend": 1,
|
||||
"OnRealtimeDisconnectRequest": 1,
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
|
|
@ -45,23 +41,23 @@ func TestRealtimeConnect(t *testing.T) {
|
|||
{
|
||||
Name: "PB_CONNECT interrupt",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeBeforeMessageSend": 1,
|
||||
"OnRealtimeDisconnectRequest": 1,
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
|
||||
if e.Message.Name == "PB_CONNECT" {
|
||||
return errors.New("PB_CONNECT error")
|
||||
}
|
||||
return nil
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
|
|
@ -70,20 +66,20 @@ func TestRealtimeConnect(t *testing.T) {
|
|||
{
|
||||
Name: "Skipping/ignoring messages",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeBeforeMessageSend": 1,
|
||||
"OnRealtimeDisconnectRequest": 1,
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error {
|
||||
return hook.StopPropagation
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
|
|
@ -101,34 +97,34 @@ func TestRealtimeSubscribe(t *testing.T) {
|
|||
|
||||
resetClient := func() {
|
||||
client.Unsubscribe()
|
||||
client.Set(apis.ContextAdminKey, nil)
|
||||
client.Set(apis.ContextAuthRecordKey, nil)
|
||||
client.Set(apis.RealtimeClientAuthKey, nil)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing client",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing client - empty subscriptions",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnRealtimeBeforeSubscribeRequest": 1,
|
||||
"OnRealtimeAfterSubscribeRequest": 1,
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(client.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected no subscriptions, got %v", client.Subscriptions())
|
||||
}
|
||||
|
|
@ -138,18 +134,18 @@ func TestRealtimeSubscribe(t *testing.T) {
|
|||
{
|
||||
Name: "existing client - 2 new subscriptions",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnRealtimeBeforeSubscribeRequest": 1,
|
||||
"OnRealtimeAfterSubscribeRequest": 1,
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
expectedSubs := []string{"test1", "test2"}
|
||||
if len(expectedSubs) != len(client.Subscriptions()) {
|
||||
t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions())
|
||||
|
|
@ -164,49 +160,49 @@ func TestRealtimeSubscribe(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - authorized admin",
|
||||
Name: "existing client - authorized superuser",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnRealtimeBeforeSubscribeRequest": 1,
|
||||
"OnRealtimeAfterSubscribeRequest": 1,
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
admin, _ := client.Get(apis.ContextAdminKey).(*models.Admin)
|
||||
if admin == nil {
|
||||
t.Errorf("Expected admin auth model, got nil")
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil || !authRecord.IsSuperuser() {
|
||||
t.Errorf("Expected superuser auth record, got %v", authRecord)
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - authorized record",
|
||||
Name: "existing client - authorized regular record",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnRealtimeBeforeSubscribeRequest": 1,
|
||||
"OnRealtimeAfterSubscribeRequest": 1,
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
t.Errorf("Expected regular user auth record, got %v", authRecord)
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
|
|
@ -214,22 +210,50 @@ func TestRealtimeSubscribe(t *testing.T) {
|
|||
{
|
||||
Name: "existing client - mismatched auth",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/realtime",
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
initialAuth := &models.Record{}
|
||||
initialAuth.RefreshId()
|
||||
client.Set(apis.ContextAuthRecordKey, initialAuth)
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - unauthorized client",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
|
|
@ -247,24 +271,29 @@ func TestRealtimeAuthRecordDeleteEvent(t *testing.T) {
|
|||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
apis.InitApi(testApp)
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
|
||||
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.ContextAuthRecordKey, authRecord)
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// mock delete event
|
||||
e := new(core.ModelEvent)
|
||||
e.Dao = testApp.Dao()
|
||||
e.App = testApp
|
||||
e.Type = core.ModelEventTypeDelete
|
||||
e.Context = context.Background()
|
||||
e.Model = authRecord
|
||||
testApp.OnModelAfterDelete().Trigger(e)
|
||||
|
||||
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
|
||||
testApp.OnModelAfterDeleteSuccess().Trigger(e)
|
||||
|
||||
if total := len(testApp.SubscriptionsBroker().Clients()); total != 0 {
|
||||
t.Fatalf("Expected no subscription clients, found %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -272,111 +301,58 @@ func TestRealtimeAuthRecordUpdateEvent(t *testing.T) {
|
|||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
apis.InitApi(testApp)
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord1, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
|
||||
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.ContextAuthRecordKey, authRecord1)
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord and change its email
|
||||
authRecord2, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
|
||||
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
authRecord2.SetEmail("new@example.com")
|
||||
|
||||
// mock update event
|
||||
e := new(core.ModelEvent)
|
||||
e.Dao = testApp.Dao()
|
||||
e.App = testApp
|
||||
e.Type = core.ModelEventTypeUpdate
|
||||
e.Context = context.Background()
|
||||
e.Model = authRecord2
|
||||
testApp.OnModelAfterUpdate().Trigger(e)
|
||||
|
||||
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
|
||||
testApp.OnModelAfterUpdateSuccess().Trigger(e)
|
||||
|
||||
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuthRecord.Email() != authRecord2.Email() {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeAdminDeleteEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
apis.InitApi(testApp)
|
||||
|
||||
admin, err := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.ContextAdminKey, admin)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
e := new(core.ModelEvent)
|
||||
e.Dao = testApp.Dao()
|
||||
e.Model = admin
|
||||
testApp.OnModelAfterDelete().Trigger(e)
|
||||
|
||||
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeAdminUpdateEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
apis.InitApi(testApp)
|
||||
|
||||
admin1, err := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.ContextAdminKey, admin1)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord and change its email
|
||||
admin2, err := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
admin2.Email = "new@example.com"
|
||||
|
||||
e := new(core.ModelEvent)
|
||||
e.Dao = testApp.Dao()
|
||||
e.Model = admin2
|
||||
testApp.OnModelAfterUpdate().Trigger(e)
|
||||
|
||||
clientAdmin, _ := client.Get(apis.ContextAdminKey).(*models.Admin)
|
||||
if clientAdmin.Email != admin2.Email {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email)
|
||||
}
|
||||
}
|
||||
|
||||
// Custom auth record model struct
|
||||
// -------------------------------------------------------------------
|
||||
var _ models.Model = (*CustomUser)(nil)
|
||||
var _ core.Model = (*CustomUser)(nil)
|
||||
|
||||
type CustomUser struct {
|
||||
models.BaseModel
|
||||
core.BaseModel
|
||||
|
||||
Email string `db:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (m *CustomUser) TableName() string {
|
||||
return "users" // the name of your collection
|
||||
return "users"
|
||||
}
|
||||
|
||||
func findCustomUserByEmail(dao *daos.Dao, email string) (*CustomUser, error) {
|
||||
func findCustomUserByEmail(app core.App, email string) (*CustomUser, error) {
|
||||
model := &CustomUser{}
|
||||
|
||||
err := dao.ModelQuery(model).
|
||||
err := app.ModelQuery(model).
|
||||
AndWhere(dbx.HashExp{"email": email}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
|
@ -392,30 +368,31 @@ func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
|
|||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
apis.InitApi(testApp)
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
|
||||
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.ContextAuthRecordKey, authRecord)
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord as CustomUser
|
||||
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
|
||||
customUser, err := findCustomUserByEmail(testApp, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// delete the custom user (should unset the client auth record)
|
||||
if err := testApp.Dao().Delete(customUser); err != nil {
|
||||
if err := testApp.Delete(customUser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
|
||||
if total := len(testApp.SubscriptionsBroker().Clients()); total != 0 {
|
||||
t.Fatalf("Expected no subscription clients, found %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -423,30 +400,31 @@ func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
|
|||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
apis.InitApi(testApp)
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
|
||||
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.ContextAuthRecordKey, authRecord)
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord as CustomUser
|
||||
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
|
||||
customUser, err := findCustomUserByEmail(testApp, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// change its email
|
||||
customUser.Email = "new@example.com"
|
||||
if err := testApp.Dao().Save(customUser); err != nil {
|
||||
if err := testApp.Save(customUser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
|
||||
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuthRecord.Email() != customUser.Email {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,765 +1,75 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/resolvers"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"golang.org/x/oauth2"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindRecordAuthApi registers the auth record api endpoints and
|
||||
// the corresponding handlers.
|
||||
func bindRecordAuthApi(app core.App, rg *echo.Group) {
|
||||
api := recordAuthApi{app: app}
|
||||
|
||||
func bindRecordAuthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
// global oauth2 subscription redirect handler
|
||||
rg.GET("/oauth2-redirect", api.oauth2SubscriptionRedirect)
|
||||
rg.POST("/oauth2-redirect", api.oauth2SubscriptionRedirect) // needed in case of response_mode=form_post
|
||||
rg.GET("/oauth2-redirect", oauth2SubscriptionRedirect)
|
||||
// add again as POST in case of response_mode=form_post
|
||||
rg.POST("/oauth2-redirect", oauth2SubscriptionRedirect)
|
||||
|
||||
// common collection record related routes
|
||||
subGroup := rg.Group(
|
||||
"/collections/:collection",
|
||||
ActivityLogger(app),
|
||||
LoadCollectionContext(app, models.CollectionTypeAuth),
|
||||
sub := rg.Group("/collections/{collection}")
|
||||
|
||||
sub.GET("/auth-methods", recordAuthMethods).Bind(
|
||||
collectionPathRateLimit("", "listAuthMethods"),
|
||||
)
|
||||
subGroup.GET("/auth-methods", api.authMethods)
|
||||
subGroup.POST("/auth-refresh", api.authRefresh, RequireSameContextRecordAuth())
|
||||
subGroup.POST("/auth-with-oauth2", api.authWithOAuth2)
|
||||
subGroup.POST("/auth-with-password", api.authWithPassword)
|
||||
subGroup.POST("/request-password-reset", api.requestPasswordReset)
|
||||
subGroup.POST("/confirm-password-reset", api.confirmPasswordReset)
|
||||
subGroup.POST("/request-verification", api.requestVerification)
|
||||
subGroup.POST("/confirm-verification", api.confirmVerification)
|
||||
subGroup.POST("/request-email-change", api.requestEmailChange, RequireSameContextRecordAuth())
|
||||
subGroup.POST("/confirm-email-change", api.confirmEmailChange)
|
||||
subGroup.GET("/records/:id/external-auths", api.listExternalAuths, RequireAdminOrOwnerAuth("id"))
|
||||
subGroup.DELETE("/records/:id/external-auths/:provider", api.unlinkExternalAuth, RequireAdminOrOwnerAuth("id"))
|
||||
|
||||
sub.POST("/auth-refresh", recordAuthRefresh).Bind(
|
||||
collectionPathRateLimit("", "authRefresh"),
|
||||
RequireSameCollectionContextAuth(""),
|
||||
)
|
||||
|
||||
sub.POST("/auth-with-password", recordAuthWithPassword).Bind(
|
||||
collectionPathRateLimit("", "authWithPassword", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/auth-with-oauth2", recordAuthWithOAuth2).Bind(
|
||||
collectionPathRateLimit("", "authWithOAuth2", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/request-otp", recordRequestOTP).Bind(
|
||||
collectionPathRateLimit("", "requestOTP"),
|
||||
)
|
||||
sub.POST("/auth-with-otp", recordAuthWithOTP).Bind(
|
||||
collectionPathRateLimit("", "authWithOTP", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/request-password-reset", recordRequestPasswordReset).Bind(
|
||||
collectionPathRateLimit("", "requestPasswordReset"),
|
||||
)
|
||||
sub.POST("/confirm-password-reset", recordConfirmPasswordReset).Bind(
|
||||
collectionPathRateLimit("", "confirmPasswordReset"),
|
||||
)
|
||||
|
||||
sub.POST("/request-verification", recordRequestVerification).Bind(
|
||||
collectionPathRateLimit("", "requestVerification"),
|
||||
)
|
||||
sub.POST("/confirm-verification", recordConfirmVerification).Bind(
|
||||
collectionPathRateLimit("", "confirmVerification"),
|
||||
)
|
||||
|
||||
sub.POST("/request-email-change", recordRequestEmailChange).Bind(
|
||||
collectionPathRateLimit("", "requestEmailChange"),
|
||||
RequireSameCollectionContextAuth(""),
|
||||
)
|
||||
sub.POST("/confirm-email-change", recordConfirmEmailChange).Bind(
|
||||
collectionPathRateLimit("", "confirmEmailChange"),
|
||||
)
|
||||
|
||||
sub.POST("/impersonate/{id}", recordAuthImpersonate).Bind(RequireSuperuserAuth())
|
||||
}
|
||||
|
||||
type recordAuthApi struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) authRefresh(c echo.Context) error {
|
||||
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
if record == nil {
|
||||
return NewNotFoundError("Missing auth record context.", nil)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthRefreshEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = record.Collection()
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshEvent) error {
|
||||
return api.app.OnRecordAfterAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshEvent) error {
|
||||
return RecordAuthResponse(api.app, e.HttpContext, e.Record, nil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type providerInfo struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
State string `json:"state"`
|
||||
AuthUrl string `json:"authUrl"`
|
||||
// technically could be omitted if the provider doesn't support PKCE,
|
||||
// but to avoid breaking existing typed clients we'll return them as empty string
|
||||
CodeVerifier string `json:"codeVerifier"`
|
||||
CodeChallenge string `json:"codeChallenge"`
|
||||
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) authMethods(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
authOptions := collection.AuthOptions()
|
||||
|
||||
result := struct {
|
||||
AuthProviders []providerInfo `json:"authProviders"`
|
||||
UsernamePassword bool `json:"usernamePassword"`
|
||||
EmailPassword bool `json:"emailPassword"`
|
||||
OnlyVerified bool `json:"onlyVerified"`
|
||||
}{
|
||||
UsernamePassword: authOptions.AllowUsernameAuth,
|
||||
EmailPassword: authOptions.AllowEmailAuth,
|
||||
OnlyVerified: authOptions.OnlyVerified,
|
||||
AuthProviders: []providerInfo{},
|
||||
}
|
||||
|
||||
if !authOptions.AllowOAuth2Auth {
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
nameConfigMap := api.app.Settings().NamedAuthProviderConfigs()
|
||||
for name, config := range nameConfigMap {
|
||||
if !config.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
provider, err := auth.NewProviderByName(name)
|
||||
if err != nil {
|
||||
api.app.Logger().Debug("Missing or invalid provider name", slog.String("name", name))
|
||||
continue // skip provider
|
||||
}
|
||||
|
||||
if err := config.SetupProvider(provider); err != nil {
|
||||
api.app.Logger().Debug(
|
||||
"Failed to setup provider",
|
||||
slog.String("name", name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
continue // skip provider
|
||||
}
|
||||
|
||||
info := providerInfo{
|
||||
Name: name,
|
||||
DisplayName: provider.DisplayName(),
|
||||
State: security.RandomString(30),
|
||||
}
|
||||
|
||||
if info.DisplayName == "" {
|
||||
info.DisplayName = name
|
||||
}
|
||||
|
||||
urlOpts := []oauth2.AuthCodeOption{}
|
||||
|
||||
// custom providers url options
|
||||
switch name {
|
||||
case auth.NameApple:
|
||||
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
|
||||
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
|
||||
}
|
||||
|
||||
if provider.PKCE() {
|
||||
info.CodeVerifier = security.RandomString(43)
|
||||
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
|
||||
info.CodeChallengeMethod = "S256"
|
||||
urlOpts = append(urlOpts,
|
||||
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
|
||||
)
|
||||
}
|
||||
|
||||
info.AuthUrl = provider.BuildAuthUrl(
|
||||
info.State,
|
||||
urlOpts...,
|
||||
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
|
||||
|
||||
result.AuthProviders = append(result.AuthProviders, info)
|
||||
}
|
||||
|
||||
// sort providers
|
||||
sort.SliceStable(result.AuthProviders, func(i, j int) bool {
|
||||
return result.AuthProviders[i].Name < result.AuthProviders[j].Name
|
||||
})
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) authWithOAuth2(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
if !collection.AuthOptions().AllowOAuth2Auth {
|
||||
return NewBadRequestError("The collection is not configured to allow OAuth2 authentication.", nil)
|
||||
}
|
||||
|
||||
var fallbackAuthRecord *models.Record
|
||||
|
||||
loggedAuthRecord, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
if loggedAuthRecord != nil && loggedAuthRecord.Collection().Id == collection.Id {
|
||||
fallbackAuthRecord = loggedAuthRecord
|
||||
}
|
||||
|
||||
form := forms.NewRecordOAuth2Login(api.app, collection, fallbackAuthRecord)
|
||||
if readErr := c.Bind(form); readErr != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthWithOAuth2Event)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.ProviderName = form.Provider
|
||||
|
||||
form.SetBeforeNewRecordCreateFunc(func(createForm *forms.RecordUpsert, authRecord *models.Record, authUser *auth.AuthUser) error {
|
||||
return createForm.DrySubmit(func(txDao *daos.Dao) error {
|
||||
event.IsNewRecord = true
|
||||
|
||||
// clone the current request data and assign the form create data as its body data
|
||||
requestInfo := *RequestInfo(c)
|
||||
requestInfo.Context = models.RequestInfoContextOAuth2
|
||||
requestInfo.Data = form.CreateData
|
||||
|
||||
createRuleFunc := func(q *dbx.SelectQuery) error {
|
||||
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin != nil {
|
||||
return nil // either admin or the rule is empty
|
||||
}
|
||||
|
||||
if collection.CreateRule == nil {
|
||||
return errors.New("Only admins can create new accounts with OAuth2")
|
||||
}
|
||||
|
||||
if *collection.CreateRule != "" {
|
||||
resolver := resolvers.NewRecordFieldResolver(txDao, collection, &requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.CreateRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := txDao.FindRecordById(collection.Id, createForm.Id, createRuleFunc); err != nil {
|
||||
return fmt.Errorf("Failed create rule constraint: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
_, _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData]) forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData] {
|
||||
return func(data *forms.RecordOAuth2LoginData) error {
|
||||
event.Record = data.Record
|
||||
event.OAuth2User = data.OAuth2User
|
||||
event.ProviderClient = data.ProviderClient
|
||||
event.IsNewRecord = data.Record == nil
|
||||
|
||||
return api.app.OnRecordBeforeAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2Event) error {
|
||||
data.Record = e.Record
|
||||
data.OAuth2User = e.OAuth2User
|
||||
|
||||
if err := next(data); err != nil {
|
||||
return NewBadRequestError("Failed to authenticate.", err)
|
||||
}
|
||||
|
||||
e.Record = data.Record
|
||||
e.OAuth2User = data.OAuth2User
|
||||
|
||||
meta := struct {
|
||||
*auth.AuthUser
|
||||
IsNew bool `json:"isNew"`
|
||||
}{
|
||||
AuthUser: e.OAuth2User,
|
||||
IsNew: event.IsNewRecord,
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2Event) error {
|
||||
// clear the lastLoginAlertSentAt field so that we can enforce password auth notifications
|
||||
if !e.Record.LastLoginAlertSentAt().IsZero() {
|
||||
e.Record.Set(schema.FieldNameLastLoginAlertSentAt, "")
|
||||
if err := api.app.Dao().SaveRecord(e.Record); err != nil {
|
||||
api.app.Logger().Warn("Failed to reset lastLoginAlertSentAt", "error", err, "recordId", e.Record.Id)
|
||||
}
|
||||
}
|
||||
|
||||
return RecordAuthResponse(api.app, e.HttpContext, e.Record, meta)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) authWithPassword(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
form := forms.NewRecordPasswordLogin(api.app, collection)
|
||||
if readErr := c.Bind(form); readErr != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthWithPasswordEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.Password = form.Password
|
||||
event.Identity = form.Identity
|
||||
|
||||
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(record *models.Record) error {
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordEvent) error {
|
||||
if err := next(e.Record); err != nil {
|
||||
return NewBadRequestError("Failed to authenticate.", err)
|
||||
}
|
||||
|
||||
// @todo remove after the refactoring
|
||||
if collection.AuthOptions().AllowOAuth2Auth && e.Record.Email() != "" {
|
||||
externalAuths, err := api.app.Dao().FindAllExternalAuthsByRecord(e.Record)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to authenticate.", err)
|
||||
}
|
||||
if len(externalAuths) > 0 {
|
||||
lastLoginAlert := e.Record.LastLoginAlertSentAt().Time()
|
||||
|
||||
// send an email alert if the password auth is after OAuth2 auth (lastLoginAlert will be empty)
|
||||
// or if it has been ~7 days since the last alert
|
||||
if lastLoginAlert.IsZero() || time.Now().UTC().Sub(lastLoginAlert).Hours() > 168 {
|
||||
providerNames := make([]string, len(externalAuths))
|
||||
for i, ea := range externalAuths {
|
||||
var name string
|
||||
if provider, err := auth.NewProviderByName(ea.Provider); err == nil {
|
||||
name = provider.DisplayName()
|
||||
}
|
||||
if name == "" {
|
||||
name = ea.Provider
|
||||
}
|
||||
providerNames[i] = name
|
||||
}
|
||||
|
||||
if err := mails.SendRecordPasswordLoginAlert(api.app, e.Record, providerNames...); err != nil {
|
||||
return NewBadRequestError("Failed to authenticate.", err)
|
||||
}
|
||||
|
||||
e.Record.SetLastLoginAlertSentAt(types.NowDateTime())
|
||||
if err := api.app.Dao().SaveRecord(e.Record); err != nil {
|
||||
api.app.Logger().Warn("Failed to update lastLoginAlertSentAt", "error", err, "recordId", e.Record.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordEvent) error {
|
||||
return RecordAuthResponse(api.app, e.HttpContext, e.Record, nil)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) requestPasswordReset(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
authOptions := collection.AuthOptions()
|
||||
if !authOptions.AllowUsernameAuth && !authOptions.AllowEmailAuth {
|
||||
return NewBadRequestError("The collection is not configured to allow password authentication.", nil)
|
||||
}
|
||||
|
||||
form := forms.NewRecordPasswordResetRequest(api.app, collection)
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
if err := form.Validate(); err != nil {
|
||||
return NewBadRequestError("An error occurred while validating the form.", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestPasswordResetEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
|
||||
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(record *models.Record) error {
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetEvent) error {
|
||||
// run in background because we don't need to show the result to the client
|
||||
routine.FireAndForget(func() {
|
||||
if err := next(e.Record); err != nil {
|
||||
api.app.Logger().Debug(
|
||||
"Failed to send password reset email",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
return api.app.OnRecordAfterRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// eagerly write 204 response and skip submit errors
|
||||
// as a measure against emails enumeration
|
||||
if !c.Response().Committed {
|
||||
c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) confirmPasswordReset(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
form := forms.NewRecordPasswordResetConfirm(api.app, collection)
|
||||
if readErr := c.Bind(form); readErr != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmPasswordResetEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
|
||||
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(record *models.Record) error {
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetEvent) error {
|
||||
if err := next(e.Record); err != nil {
|
||||
return NewBadRequestError("Failed to set new password.", err)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) requestVerification(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
form := forms.NewRecordVerificationRequest(api.app, collection)
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
if err := form.Validate(); err != nil {
|
||||
return NewBadRequestError("An error occurred while validating the form.", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestVerificationEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
|
||||
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(record *models.Record) error {
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationEvent) error {
|
||||
// run in background because we don't need to show the result to the client
|
||||
routine.FireAndForget(func() {
|
||||
if err := next(e.Record); err != nil {
|
||||
api.app.Logger().Debug(
|
||||
"Failed to send verification email",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
return api.app.OnRecordAfterRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// eagerly write 204 response and skip submit errors
|
||||
// as a measure against users enumeration
|
||||
if !c.Response().Committed {
|
||||
c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) confirmVerification(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
form := forms.NewRecordVerificationConfirm(api.app, collection)
|
||||
if readErr := c.Bind(form); readErr != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmVerificationEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
|
||||
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(record *models.Record) error {
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationEvent) error {
|
||||
if err := next(e.Record); err != nil {
|
||||
return NewBadRequestError("An error occurred while submitting the form.", err)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) requestEmailChange(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
if record == nil {
|
||||
return NewUnauthorizedError("The request requires valid auth record.", nil)
|
||||
}
|
||||
|
||||
form := forms.NewRecordEmailChangeRequest(api.app, record)
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestEmailChangeEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(record *models.Record) error {
|
||||
return api.app.OnRecordBeforeRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeEvent) error {
|
||||
if err := next(e.Record); err != nil {
|
||||
return NewBadRequestError("Failed to request email change.", err)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) confirmEmailChange(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
form := forms.NewRecordEmailChangeConfirm(api.app, collection)
|
||||
if readErr := c.Bind(form); readErr != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmEmailChangeEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
|
||||
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(record *models.Record) error {
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeEvent) error {
|
||||
if err := next(e.Record); err != nil {
|
||||
return NewBadRequestError("Failed to confirm email change.", err)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return submitErr
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) listExternalAuths(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
id := c.PathParam("id")
|
||||
if id == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
record, err := api.app.Dao().FindRecordById(collection.Id, id)
|
||||
if err != nil || record == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
|
||||
externalAuths, err := api.app.Dao().FindAllExternalAuthsByRecord(record)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to fetch the external auths for the specified auth record.", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordListExternalAuthsEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.ExternalAuths = externalAuths
|
||||
|
||||
return api.app.OnRecordListExternalAuthsRequest().Trigger(event, func(e *core.RecordListExternalAuthsEvent) error {
|
||||
return e.HttpContext.JSON(http.StatusOK, e.ExternalAuths)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) unlinkExternalAuth(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("Missing collection context.", nil)
|
||||
}
|
||||
|
||||
id := c.PathParam("id")
|
||||
provider := c.PathParam("provider")
|
||||
if id == "" || provider == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
record, err := api.app.Dao().FindRecordById(collection.Id, id)
|
||||
if err != nil || record == nil {
|
||||
return NewNotFoundError("", err)
|
||||
}
|
||||
|
||||
externalAuth, err := api.app.Dao().FindExternalAuthByRecordAndProvider(record, provider)
|
||||
if err != nil {
|
||||
return NewNotFoundError("Missing external auth provider relation.", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordUnlinkExternalAuthEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.ExternalAuth = externalAuth
|
||||
|
||||
return api.app.OnRecordBeforeUnlinkExternalAuthRequest().Trigger(event, func(e *core.RecordUnlinkExternalAuthEvent) error {
|
||||
if err := api.app.Dao().DeleteExternalAuth(externalAuth); err != nil {
|
||||
return NewBadRequestError("Cannot unlink the external auth provider.", err)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterUnlinkExternalAuthRequest().Trigger(event, func(e *core.RecordUnlinkExternalAuthEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const (
|
||||
oauth2SubscriptionTopic string = "@oauth2"
|
||||
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
|
||||
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
|
||||
)
|
||||
|
||||
type oauth2RedirectData struct {
|
||||
State string `form:"state" query:"state" json:"state"`
|
||||
Code string `form:"code" query:"code" json:"code"`
|
||||
Error string `form:"error" query:"error" json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
|
||||
redirectStatusCode := http.StatusTemporaryRedirect
|
||||
if c.Request().Method != http.MethodGet {
|
||||
redirectStatusCode = http.StatusSeeOther
|
||||
}
|
||||
|
||||
data := oauth2RedirectData{}
|
||||
if err := c.Bind(&data); err != nil {
|
||||
api.app.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
|
||||
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
if data.State == "" {
|
||||
api.app.Logger().Debug("Missing OAuth2 state parameter")
|
||||
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
client, err := api.app.SubscriptionsBroker().ClientById(data.State)
|
||||
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
|
||||
api.app.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
|
||||
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
defer client.Unsubscribe(oauth2SubscriptionTopic)
|
||||
|
||||
encodedData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
api.app.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
|
||||
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: oauth2SubscriptionTopic,
|
||||
Data: encodedData,
|
||||
}
|
||||
|
||||
client.Send(msg)
|
||||
|
||||
if data.Error != "" || data.Code == "" {
|
||||
api.app.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
|
||||
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
return c.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
|
||||
func findAuthCollection(e *core.RequestEvent) (*core.Collection, error) {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
|
||||
if err != nil || !collection.IsAuth() {
|
||||
return nil, e.NotFoundError("Missing or invalid auth collection context.", err)
|
||||
}
|
||||
|
||||
return collection, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,121 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func recordConfirmEmailChange(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers can change their emails directly.", nil)
|
||||
}
|
||||
|
||||
form := newEmailChangeConfirmForm(e.App, collection)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
authRecord, newEmail, err := form.parseToken()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Invalid or expired token.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmEmailChangeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = authRecord
|
||||
event.NewEmail = newEmail
|
||||
|
||||
return e.App.OnRecordConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeRequestEvent) error {
|
||||
authRecord.Set(core.FieldNameEmail, e.NewEmail)
|
||||
authRecord.Set(core.FieldNameVerified, true)
|
||||
authRecord.RefreshTokenKey() // invalidate old tokens
|
||||
|
||||
if err := e.App.Save(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err))
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func newEmailChangeConfirmForm(app core.App, collection *core.Collection) *EmailChangeConfirmForm {
|
||||
return &EmailChangeConfirmForm{
|
||||
app: app,
|
||||
collection: collection,
|
||||
}
|
||||
}
|
||||
|
||||
type EmailChangeConfirmForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
Password string `form:"password" json:"password"`
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 100), validation.By(form.checkPassword)),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) checkToken(value any) error {
|
||||
_, _, err := form.parseToken()
|
||||
return err
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) checkPassword(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
authRecord, _, _ := form.parseToken()
|
||||
if authRecord == nil || !authRecord.ValidatePassword(v) {
|
||||
return validation.NewError("validation_invalid_password", "Missing or invalid auth record password.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) parseToken() (*core.Record, string, error) {
|
||||
// check token payload
|
||||
claims, _ := security.ParseUnverifiedJWT(form.Token)
|
||||
newEmail, _ := claims[core.TokenClaimNewEmail].(string)
|
||||
if newEmail == "" {
|
||||
return nil, "", validation.NewError("validation_invalid_token_payload", "Invalid token payload - newEmail must be set.")
|
||||
}
|
||||
|
||||
// ensure that there aren't other users with the new email
|
||||
_, err := form.app.FindAuthRecordByEmail(form.collection, newEmail)
|
||||
if err == nil {
|
||||
return nil, "", validation.NewError("validation_existing_token_email", "The new email address is already registered: "+newEmail)
|
||||
}
|
||||
|
||||
// verify that the token is not expired and its signature is valid
|
||||
authRecord, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeEmailChange)
|
||||
if err != nil {
|
||||
return nil, "", validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if authRecord.Collection().Id != form.collection.Id {
|
||||
return nil, "", validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
return authRecord, newEmail, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,205 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmEmailChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-email-change",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{"token`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token and correct password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoxNjQwOTkxNjYxfQ.dff842MO0mgRTHY8dktp0dqG9-7LGQOgRuiAbQpYBls",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{`,
|
||||
`"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-email change token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{`,
|
||||
`"code":"validation_invalid_token_payload"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and incorrect password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567891"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"password":{`,
|
||||
`"code":"validation_invalid_password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and correct password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmEmailChangeRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByEmail("users", "change@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected to find user with email %q, got error: %v", "change@example.com", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token in different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAfterConfirmEmailChangeRequest error response",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmEmailChangeRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:confirmEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmEmailChange"},
|
||||
{MaxRequests: 0, Label: "users:confirmEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
)
|
||||
|
||||
func recordRequestEmailChange(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers can change their emails directly.", nil)
|
||||
}
|
||||
|
||||
record := e.Auth
|
||||
if record == nil {
|
||||
return e.UnauthorizedError("The request requires valid auth record.", nil)
|
||||
}
|
||||
|
||||
form := newEmailChangeRequestForm(e.App, record)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestEmailChangeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.NewEmail = form.NewEmail
|
||||
|
||||
return e.App.OnRecordRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeRequestEvent) error {
|
||||
if err := mails.SendRecordChangeEmail(e.App, e.Record, e.NewEmail); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to request email change.", err))
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func newEmailChangeRequestForm(app core.App, record *core.Record) *emailChangeRequestForm {
|
||||
return &emailChangeRequestForm{
|
||||
app: app,
|
||||
record: record,
|
||||
}
|
||||
}
|
||||
|
||||
type emailChangeRequestForm struct {
|
||||
app core.App
|
||||
record *core.Record
|
||||
|
||||
NewEmail string `form:"newEmail" json:"newEmail"`
|
||||
}
|
||||
|
||||
func (form *emailChangeRequestForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.NewEmail,
|
||||
validation.Required,
|
||||
validation.Length(1, 255),
|
||||
is.EmailFormat,
|
||||
validation.NotIn(form.record.Email()),
|
||||
validation.By(form.checkUniqueEmail),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *emailChangeRequestForm) checkUniqueEmail(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
found, _ := form.app.FindAuthRecordByEmail(form.record.Collection(), v)
|
||||
if found != nil && found.Id != form.record.Id {
|
||||
return validation.NewError("validation_invalid_new_email", "Invalid new email address.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestEmailChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "record authentication but from different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser authentication",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"newEmail":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid data (existing email)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"test2@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"newEmail":{"code":"validation_invalid_new_email"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid data (new email)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestEmailChangeRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordEmailChangeSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-email-change") {
|
||||
t.Fatalf("Expected email change email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestEmailChange"},
|
||||
{MaxRequests: 0, Label: "users:requestEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// note: for now allow superusers but it may change in the future to allow access
|
||||
// also to users with "Manage API" rule access depending on the use cases that will arise
|
||||
func recordAuthImpersonate(e *core.RequestEvent) error {
|
||||
if !e.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("", nil)
|
||||
}
|
||||
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, e.Request.PathValue("id"))
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
form := &impersonateForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second)
|
||||
if err != nil {
|
||||
e.InternalServerError("Failed to generate static auth token", err)
|
||||
}
|
||||
|
||||
return recordAuthResponse(e, record, token, "", nil)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type impersonateForm struct {
|
||||
// Duration is the optional custom token duration in seconds.
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
}
|
||||
|
||||
func (form *impersonateForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Duration, validation.Min(0)),
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthImpersonate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as different user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as the same user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"record":{`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields should remain hidden even though we are authenticated as superuser
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser with custom invalid duration",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: strings.NewReader(`{"duration":-1}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"duration":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser with custom valid duration",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: strings.NewReader(`{"duration":100}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"record":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type otpResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Duration int64 `json:"duration"` // in seconds
|
||||
}
|
||||
|
||||
type mfaResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Duration int64 `json:"duration"` // in seconds
|
||||
}
|
||||
|
||||
type passwordResponse struct {
|
||||
IdentityFields []string `json:"identityFields"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type oauth2Response struct {
|
||||
Providers []providerInfo `json:"providers"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type providerInfo struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
State string `json:"state"`
|
||||
AuthURL string `json:"authURL"`
|
||||
|
||||
// @todo
|
||||
// deprecated: use AuthURL instead
|
||||
// AuthUrl will be removed after dropping v0.22 support
|
||||
AuthUrl string `json:"authUrl"`
|
||||
|
||||
// technically could be omitted if the provider doesn't support PKCE,
|
||||
// but to avoid breaking existing typed clients we'll return them as empty string
|
||||
CodeVerifier string `json:"codeVerifier"`
|
||||
CodeChallenge string `json:"codeChallenge"`
|
||||
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
||||
}
|
||||
|
||||
type authMethodsResponse struct {
|
||||
Password passwordResponse `json:"password"`
|
||||
OAuth2 oauth2Response `json:"oauth2"`
|
||||
MFA mfaResponse `json:"mfa"`
|
||||
OTP otpResponse `json:"otp"`
|
||||
|
||||
// legacy fields
|
||||
// @todo remove after dropping v0.22 support
|
||||
AuthProviders []providerInfo `json:"authProviders"`
|
||||
UsernamePassword bool `json:"usernamePassword"`
|
||||
EmailPassword bool `json:"emailPassword"`
|
||||
}
|
||||
|
||||
func (amr *authMethodsResponse) fillLegacyFields() {
|
||||
amr.EmailPassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "email")
|
||||
|
||||
amr.UsernamePassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "username")
|
||||
|
||||
if amr.OAuth2.Enabled {
|
||||
amr.AuthProviders = amr.OAuth2.Providers
|
||||
}
|
||||
}
|
||||
|
||||
func recordAuthMethods(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := authMethodsResponse{
|
||||
Password: passwordResponse{
|
||||
IdentityFields: make([]string, 0, len(collection.PasswordAuth.IdentityFields)),
|
||||
},
|
||||
OAuth2: oauth2Response{
|
||||
Providers: make([]providerInfo, 0, len(collection.OAuth2.Providers)),
|
||||
},
|
||||
OTP: otpResponse{
|
||||
Enabled: collection.OTP.Enabled,
|
||||
},
|
||||
MFA: mfaResponse{
|
||||
Enabled: collection.MFA.Enabled,
|
||||
},
|
||||
}
|
||||
|
||||
if collection.PasswordAuth.Enabled {
|
||||
result.Password.Enabled = true
|
||||
result.Password.IdentityFields = collection.PasswordAuth.IdentityFields
|
||||
}
|
||||
|
||||
if collection.OTP.Enabled {
|
||||
result.OTP.Duration = collection.OTP.Duration
|
||||
}
|
||||
|
||||
if collection.MFA.Enabled {
|
||||
result.MFA.Duration = collection.MFA.Duration
|
||||
}
|
||||
|
||||
if !collection.OAuth2.Enabled {
|
||||
result.fillLegacyFields()
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
result.OAuth2.Enabled = true
|
||||
|
||||
for _, config := range collection.OAuth2.Providers {
|
||||
provider, err := config.InitProvider()
|
||||
if err != nil {
|
||||
e.App.Logger().Debug(
|
||||
"Failed to setup OAuth2 provider",
|
||||
slog.String("name", config.Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
continue // skip provider
|
||||
}
|
||||
|
||||
info := providerInfo{
|
||||
Name: config.Name,
|
||||
DisplayName: provider.DisplayName(),
|
||||
State: security.RandomString(30),
|
||||
}
|
||||
|
||||
if info.DisplayName == "" {
|
||||
info.DisplayName = config.Name
|
||||
}
|
||||
|
||||
urlOpts := []oauth2.AuthCodeOption{}
|
||||
|
||||
// custom providers url options
|
||||
switch config.Name {
|
||||
case auth.NameApple:
|
||||
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
|
||||
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
|
||||
}
|
||||
|
||||
if provider.PKCE() {
|
||||
info.CodeVerifier = security.RandomString(43)
|
||||
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
|
||||
info.CodeChallengeMethod = "S256"
|
||||
urlOpts = append(urlOpts,
|
||||
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
|
||||
)
|
||||
}
|
||||
|
||||
info.AuthURL = provider.BuildAuthURL(
|
||||
info.State,
|
||||
urlOpts...,
|
||||
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
|
||||
|
||||
info.AuthUrl = info.AuthURL
|
||||
|
||||
result.OAuth2.Providers = append(result.OAuth2.Providers, info)
|
||||
}
|
||||
|
||||
result.fillLegacyFields()
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthMethodsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/missing/auth-methods",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/demo1/auth-methods",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with none auth methods allowed",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"password":{"identityFields":[],"enabled":false}`,
|
||||
`"oauth2":{"providers":[],"enabled":false}`,
|
||||
`"mfa":{"enabled":false,"duration":0}`,
|
||||
`"otp":{"enabled":false,"duration":0}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with all auth methods allowed",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/users/auth-methods",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"password":{"identityFields":["email","username"],"enabled":true}`,
|
||||
`"mfa":{"enabled":true,"duration":1800}`,
|
||||
`"otp":{"enabled":true,"duration":300}`,
|
||||
`"oauth2":{`,
|
||||
`"providers":[{`,
|
||||
`"name":"google"`,
|
||||
`"name":"gitlab"`,
|
||||
`"state":`,
|
||||
`"displayName":`,
|
||||
`"codeVerifier":`,
|
||||
`"codeChallenge":`,
|
||||
`"codeChallengeMethod":`,
|
||||
`"authURL":`,
|
||||
`redirect_uri="`, // ensures that the redirect_uri is the last url param
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - nologin:listAuthMethods",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:listAuthMethods"},
|
||||
{MaxRequests: 0, Label: "nologin:listAuthMethods"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:listAuthMethods",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:listAuthMethods"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func recordRequestOTP(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OTP.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
|
||||
}
|
||||
|
||||
form := &createOTPForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
if err != nil {
|
||||
// eagerly write a dummy 200 response as a very rudimentary user emails enumeration protection
|
||||
e.JSON(http.StatusOK, map[string]string{
|
||||
"otpId": core.GenerateDefaultRandomId(),
|
||||
})
|
||||
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
|
||||
}
|
||||
|
||||
event := new(core.RecordCreateOTPRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Password = security.RandomStringWithAlphabet(collection.OTP.Length, "1234567890")
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordRequestOTPRequest().Trigger(event, func(e *core.RecordCreateOTPRequestEvent) error {
|
||||
var otp *core.OTP
|
||||
|
||||
// limit the new OTP creations for a single user
|
||||
if !e.App.IsDev() {
|
||||
otps, err := e.App.FindAllOTPsByRecord(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to fetch previous record OTPs.", err))
|
||||
}
|
||||
|
||||
totalRecent := 0
|
||||
for _, existingOTP := range otps {
|
||||
if !existingOTP.HasExpired(collection.OTP.DurationTime()) {
|
||||
totalRecent++
|
||||
}
|
||||
// use the last issued one
|
||||
if totalRecent > 9 {
|
||||
otp = otps[0] // otps are DESC sorted
|
||||
e.App.Logger().Warn(
|
||||
"Too many OTP requests - reusing the last issued",
|
||||
"email", form.Email,
|
||||
"recordId", e.Record.Id,
|
||||
"otpId", existingOTP.Id,
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if otp == nil {
|
||||
// create new OTP
|
||||
// ---
|
||||
otp = core.NewOTP(e.App)
|
||||
otp.SetCollectionRef(e.Record.Collection().Id)
|
||||
otp.SetRecordRef(e.Record.Id)
|
||||
otp.SetPassword(e.Password)
|
||||
err = e.App.Save(otp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send OTP email
|
||||
// (in the background as a very basic timing attacks and emails enumeration protection)
|
||||
// ---
|
||||
app := e.App
|
||||
routine.FireAndForget(func() {
|
||||
err = mails.SendRecordOTP(app, e.Record, otp.Id, e.Password)
|
||||
if err != nil {
|
||||
app.Logger().Error("Failed to send OTP email", "error", errors.Join(err, e.App.Delete(otp)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{
|
||||
"otpId": otp.Id,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type createOTPForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form createOTPForm) validate() error {
|
||||
return validation.ValidateStruct(&form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,231 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRecordRequestOTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with disabled otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
usersCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersCol.OTP.Enabled = false
|
||||
|
||||
if err := app.Save(usersCol); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid request data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"invalid"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"email":{"code":"validation_is_email`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`, // some fake random generated string
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (with < 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert 8 non-expired and 2 expired
|
||||
for i := 0; i < 10; i++ {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = "otp_" + strconv.Itoa(i)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if i >= 8 {
|
||||
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
|
||||
otp.SetRaw("created", expiredDate)
|
||||
otp.SetRaw("updated", expiredDate)
|
||||
}
|
||||
if err := app.SaveNoValidate(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"otpId":"otp_`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("Expected 1 email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (with > 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert 10 non-expired
|
||||
for i := 0; i < 10; i++ {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = "otp_" + strconv.Itoa(i)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.SaveNoValidate(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"otp_9"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected 0 sent emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestOTP"},
|
||||
{MaxRequests: 0, Label: "users:requestOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordConfirmPasswordReset(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
form := new(recordConfirmPasswordResetForm)
|
||||
form.app = e.App
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
authRecord, err := e.App.FindAuthRecordByToken(form.Token, core.TokenTypePasswordReset)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Invalid or expired password reset token.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmPasswordResetRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = authRecord
|
||||
|
||||
return e.App.OnRecordConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetRequestEvent) error {
|
||||
authRecord.SetPassword(form.Password)
|
||||
|
||||
if !authRecord.Verified() {
|
||||
payload, err := security.ParseUnverifiedJWT(form.Token)
|
||||
if err == nil && authRecord.Email() == cast.ToString(payload[core.TokenClaimEmail]) {
|
||||
// mark as verified if the email hasn't changed
|
||||
authRecord.SetVerified(true)
|
||||
}
|
||||
}
|
||||
|
||||
err = form.app.Save(authRecord)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to set new password.", err))
|
||||
}
|
||||
|
||||
form.app.Store().Remove(getPasswordResetResendKey(authRecord))
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordConfirmPasswordResetForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
Password string `form:"password" json:"password"`
|
||||
PasswordConfirm string `form:"passwordConfirm" json:"passwordConfirm"`
|
||||
}
|
||||
|
||||
func (form *recordConfirmPasswordResetForm) validate() error {
|
||||
min := 1
|
||||
passField, ok := form.collection.Fields.GetByName(core.FieldNamePassword).(*core.PasswordField)
|
||||
if ok && passField != nil && passField.Min > 0 {
|
||||
min = passField.Min
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(min, 255)), // the FieldPassword validator will check further the specicic length constraints
|
||||
validation.Field(&form.PasswordConfirm, validation.Required, validation.By(validators.Equal(form.Password))),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordConfirmPasswordResetForm) checkToken(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypePasswordReset)
|
||||
if err != nil || record == nil {
|
||||
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if record.Collection().Id != form.collection.Id {
|
||||
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,345 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
`"passwordConfirm":{"code":"validation_required"`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{"password`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token and invalid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.5Tm6_6amQqOlX3urAnXlEdmxwG5qQJfiTg6U0hHR1hk",
|
||||
"password":"1234567",
|
||||
"passwordConfirm":"7654321"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
`"passwordConfirm":{"code":"validation_values_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-password reset token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-password-reset?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-password-reset?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (unverified user)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to be unverified")
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to be marked as verified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (unverified user with different email from the one in the token)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to be unverified")
|
||||
}
|
||||
|
||||
// manually change the email to check whether the verified state will be updated
|
||||
user.SetEmail("test_update@example.com")
|
||||
if err := app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to update user test email: %v", err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test_update@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to remain unverified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (verified user)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
// ensure that the user is already verified
|
||||
user.SetVerified(true)
|
||||
if err := app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to update user verified state")
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to remain verified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAfterConfirmPasswordResetRequest error response",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:confirmPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmPasswordReset"},
|
||||
{MaxRequests: 0, Label: "users:confirmPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
func recordRequestPasswordReset(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.PasswordAuth.Enabled {
|
||||
return e.BadRequestError("The collection is not configured to allow password authentication.", nil)
|
||||
}
|
||||
|
||||
form := new(recordRequestPasswordResetForm)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
if err != nil {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
|
||||
}
|
||||
|
||||
resendKey := getPasswordResetResendKey(record)
|
||||
if e.App.Store().Has(resendKey) {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return errors.New("try again later - you've already requested a password reset email")
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestPasswordResetRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetRequestEvent) error {
|
||||
// run in background because we don't need to show the result to the client
|
||||
app := e.App
|
||||
routine.FireAndForget(func() {
|
||||
if err := mails.SendRecordPasswordReset(app, e.Record); err != nil {
|
||||
app.Logger().Error("Failed to send password reset email", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
time.AfterFunc(2*time.Minute, func() {
|
||||
app.Store().Remove(resendKey)
|
||||
})
|
||||
})
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordRequestPasswordResetForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form *recordRequestPasswordResetForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
||||
|
||||
func getPasswordResetResendKey(record *core.Record) string {
|
||||
return "@limitPasswordResetEmail_" + record.Collection().Id + record.Id
|
||||
}
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record in a collection with disabled password login",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestPasswordResetRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordPasswordResetSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-password-reset") {
|
||||
t.Fatalf("Expected password reset email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (after already sent)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// simulate recent verification sent
|
||||
authRecord, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resendKey := "@limitPasswordResetEmail_" + authRecord.Collection().Id + authRecord.Id
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestPasswordReset"},
|
||||
{MaxRequests: 0, Label: "users:requestPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordAuthRefresh(e *core.RequestEvent) error {
|
||||
record := e.Auth
|
||||
if record == nil {
|
||||
return e.NotFoundError("Missing auth record context.", nil)
|
||||
}
|
||||
|
||||
currentToken := getAuthTokenFromRequest(e)
|
||||
claims, _ := security.ParseUnverifiedJWT(currentToken)
|
||||
if v, ok := claims[core.TokenClaimRefreshable]; !ok || !cast.ToBool(v) {
|
||||
return e.ForbiddenError("The current auth token is not refreshable.", nil)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthRefreshRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = record.Collection()
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshRequestEvent) error {
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, "", nil)
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser trying to refresh the auth of another auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh?expand=rel,missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + same auth collection as the token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":`,
|
||||
`"record":`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"emailVisibility":false`,
|
||||
`"email":"test@example.com"`, // the owner can always view their email address
|
||||
`"expand":`,
|
||||
`"rel":`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"missing":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "auth record + same auth collection as the token but static/unrefreshable",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unverified auth record in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im8xeTBkZDBzcGQ3ODZtZCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.Zi0yXE-CNmnbTdVaQEzYZVuECqRdn3LgEM6pmB3XWBE",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "verified auth record in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":`,
|
||||
`"record":`,
|
||||
`"id":"gk390qegs4y47wn"`,
|
||||
`"verified":true`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAfterAuthRefreshRequest error response",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authRefresh",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authRefresh"},
|
||||
{MaxRequests: 0, Label: "users:authRefresh"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authRefresh",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:authRefresh"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,102 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordConfirmVerification(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers are verified by default.", nil)
|
||||
}
|
||||
|
||||
form := new(recordConfirmVerificationForm)
|
||||
form.app = e.App
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeVerification)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired verification token.", err)
|
||||
}
|
||||
|
||||
wasVerified := record.Verified()
|
||||
|
||||
event := new(core.RecordConfirmVerificationRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error {
|
||||
if wasVerified {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
e.Record.SetVerified(true)
|
||||
|
||||
if err := e.App.Save(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
|
||||
}
|
||||
|
||||
e.App.Store().Remove(getVerificationResendKey(e.Record))
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordConfirmVerificationForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
}
|
||||
|
||||
func (form *recordConfirmVerificationForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordConfirmVerificationForm) checkToken(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
claims, _ := security.ParseUnverifiedJWT(v)
|
||||
email := cast.ToString(claims["email"])
|
||||
if email == "" {
|
||||
return validation.NewError("validation_invalid_token_claims", "Missing email token claim.")
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypeVerification)
|
||||
if err != nil || record == nil {
|
||||
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if record.Collection().Id != form.collection.Id {
|
||||
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
if record.Email() != email {
|
||||
return validation.NewError("validation_token_email_mismatch", "The record email doesn't match with the requested token claims.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,210 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{"password`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.qqelNNL2Udl6K_TJ282sNHYCpASgA6SIuSVKGfBHMZU"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-verification token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-verification?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-verification?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token (already verified)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdDJAZXhhbXBsZS5jb20ifQ.QQmM3odNFVk6u4J4-5H8IBM3dfk9YCD7mPW-8PhBAI8"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid verification token from a collection without allowed login",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAfterConfirmVerificationRequest error response",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - nologin:confirmVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmVerification"},
|
||||
{MaxRequests: 0, Label: "nologin:confirmVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
func recordRequestVerification(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers are verified by default.", nil)
|
||||
}
|
||||
|
||||
form := new(recordRequestVerificationForm)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
if err != nil {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
|
||||
}
|
||||
|
||||
resendKey := getVerificationResendKey(record)
|
||||
if !record.Verified() && e.App.Store().Has(resendKey) {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return errors.New("try again later - you've already requested a verification email")
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestVerificationRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationRequestEvent) error {
|
||||
if e.Record.Verified() {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// run in background because we don't need to show the result to the client
|
||||
app := e.App
|
||||
routine.FireAndForget(func() {
|
||||
if err := mails.SendRecordVerification(app, e.Record); err != nil {
|
||||
app.Logger().Error("Failed to send verification email", "error", err)
|
||||
}
|
||||
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
time.AfterFunc(2*time.Minute, func() {
|
||||
app.Store().Remove(resendKey)
|
||||
})
|
||||
})
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordRequestVerificationForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form *recordRequestVerificationForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
||||
|
||||
func getVerificationResendKey(record *core.Record) string {
|
||||
return "@limitVerificationEmail_" + record.Collection().Id + record.Id
|
||||
}
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "already verified auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test2@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestVerificationRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestVerificationRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordVerificationSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-verification") {
|
||||
t.Fatalf("Expected verification email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (after already sent)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
// terminated before firing the event
|
||||
// "OnRecordRequestVerificationRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// simulate recent verification sent
|
||||
authRecord, err := app.FindFirstRecordByData("users", "email", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resendKey := "@limitVerificationEmail_" + authRecord.Collection().Id + authRecord.Id
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestVerification"},
|
||||
{MaxRequests: 0, Label: "users:requestVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,355 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func recordAuthWithOAuth2(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OAuth2.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OAuth2 authentication.", nil)
|
||||
}
|
||||
|
||||
var fallbackAuthRecord *core.Record
|
||||
if e.Auth != nil && e.Auth.Collection().Id == collection.Id {
|
||||
fallbackAuthRecord = e.Auth
|
||||
}
|
||||
|
||||
form := new(recordOAuth2LoginForm)
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
if form.RedirectUrl != "" && form.RedirectURL == "" {
|
||||
e.App.Logger().Warn("[recordAuthWithOAuth2] redirectUrl body param is deprecated and will be removed in the future. Please replace it with redirectURL.")
|
||||
form.RedirectURL = form.RedirectUrl
|
||||
}
|
||||
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
// exchange token for OAuth2 user info and locate existing ExternalAuth rel
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// load provider configuration
|
||||
providerConfig, ok := collection.OAuth2.GetProviderConfig(form.Provider)
|
||||
if !ok {
|
||||
return e.InternalServerError("Missing or invalid provider config.", nil)
|
||||
}
|
||||
|
||||
provider, err := providerConfig.InitProvider()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to init provider "+form.Provider, err))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(e.Request.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
provider.SetContext(ctx)
|
||||
provider.SetRedirectURL(form.RedirectURL)
|
||||
|
||||
var opts []oauth2.AuthCodeOption
|
||||
|
||||
if provider.PKCE() {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier))
|
||||
}
|
||||
|
||||
// fetch token
|
||||
token, err := provider.FetchToken(form.Code, opts...)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 token.", err))
|
||||
}
|
||||
|
||||
// fetch external auth user
|
||||
authUser, err := provider.FetchAuthUser(token)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 user.", err))
|
||||
}
|
||||
|
||||
var authRecord *core.Record
|
||||
|
||||
// check for existing relation with the auth record
|
||||
externalAuthRel, err := e.App.FindFirstExternalAuthByExpr(dbx.HashExp{
|
||||
"collectionRef": form.collection.Id,
|
||||
"provider": form.Provider,
|
||||
"providerId": authUser.Id,
|
||||
})
|
||||
switch {
|
||||
case err == nil && externalAuthRel != nil:
|
||||
authRecord, err = e.App.FindRecordById(form.collection, externalAuthRel.RecordRef())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case fallbackAuthRecord != nil && fallbackAuthRecord.Collection().Id == form.collection.Id:
|
||||
// fallback to the logged auth record (if any)
|
||||
authRecord = fallbackAuthRecord
|
||||
case authUser.Email != "":
|
||||
// look for an existing auth record by the external auth record's email
|
||||
authRecord, _ = e.App.FindAuthRecordByEmail(form.collection.Id, authUser.Email)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
event := new(core.RecordAuthWithOAuth2RequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.ProviderName = form.Provider
|
||||
event.ProviderClient = provider
|
||||
event.OAuth2User = authUser
|
||||
event.CreateData = form.CreateData
|
||||
event.Record = authRecord
|
||||
event.IsNewRecord = authRecord == nil
|
||||
|
||||
return e.App.OnRecordAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2RequestEvent) error {
|
||||
if err := oauth2Submit(e, externalAuthRel); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to authenticate.", err))
|
||||
}
|
||||
|
||||
meta := struct {
|
||||
*auth.AuthUser
|
||||
IsNew bool `json:"isNew"`
|
||||
}{
|
||||
AuthUser: e.OAuth2User,
|
||||
IsNew: e.IsNewRecord,
|
||||
}
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOAuth2, meta)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordOAuth2LoginForm struct {
|
||||
collection *core.Collection
|
||||
|
||||
// Additional data that will be used for creating a new auth record
|
||||
// if an existing OAuth2 account doesn't exist.
|
||||
CreateData map[string]any `form:"createData" json:"createData"`
|
||||
|
||||
// The name of the OAuth2 client provider (eg. "google")
|
||||
Provider string `form:"provider" json:"provider"`
|
||||
|
||||
// The authorization code returned from the initial request.
|
||||
Code string `form:"code" json:"code"`
|
||||
|
||||
// The optional PKCE code verifier as part of the code_challenge sent with the initial request.
|
||||
CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`
|
||||
|
||||
// The redirect url sent with the initial request.
|
||||
RedirectURL string `form:"redirectURL" json:"redirectURL"`
|
||||
|
||||
// @todo
|
||||
// deprecated: use RedirectURL instead
|
||||
// RedirectUrl will be removed after dropping v0.22 support
|
||||
RedirectUrl string `form:"redirectUrl" json:"redirectUrl"`
|
||||
}
|
||||
|
||||
func (form *recordOAuth2LoginForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Provider, validation.Required, validation.By(form.checkProviderName)),
|
||||
validation.Field(&form.Code, validation.Required),
|
||||
validation.Field(&form.RedirectURL, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordOAuth2LoginForm) checkProviderName(value any) error {
|
||||
name, _ := value.(string)
|
||||
|
||||
_, ok := form.collection.OAuth2.GetProviderConfig(name)
|
||||
if !ok {
|
||||
return validation.NewError("validation_invalid_provider", fmt.Sprintf("Provider with name %q is missing or is not enabled.", name)).
|
||||
SetParams(map[string]any{"name": name})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool {
|
||||
// ensure that username is unique
|
||||
checkUnique := dbutils.HasSingleColumnUniqueIndex(collection.OAuth2.MappedFields.Username, collection.Indexes)
|
||||
if checkUnique {
|
||||
if _, err := txApp.FindFirstRecordByData(collection, collection.OAuth2.MappedFields.Username, username); err == nil {
|
||||
return false // already exist
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that the value matches the pattern of the username field (if text)
|
||||
txtField, _ := collection.Fields.GetByName(collection.OAuth2.MappedFields.Username).(*core.TextField)
|
||||
|
||||
return txtField != nil && txtField.ValidatePlainValue(username) == nil
|
||||
}
|
||||
|
||||
func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *core.ExternalAuth) error {
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
if e.Record == nil {
|
||||
// extra check to prevent creating a superuser record via
|
||||
// OAuth2 in case the method is used by another action
|
||||
if e.Collection.Name == core.CollectionNameSuperusers {
|
||||
return errors.New("superusers are not allowed to sign-up with OAuth2")
|
||||
}
|
||||
|
||||
payload := maps.Clone(e.CreateData)
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
|
||||
payload[core.FieldNameEmail] = e.OAuth2User.Email
|
||||
|
||||
// set a random password if none is set
|
||||
if v, _ := payload[core.FieldNamePassword].(string); v == "" {
|
||||
payload[core.FieldNamePassword] = security.RandomString(30)
|
||||
payload[core.FieldNamePassword+"Confirm"] = payload[core.FieldNamePassword]
|
||||
}
|
||||
|
||||
// map known fields (unless the field was explicitly submitted as part of CreateData)
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Id]; !ok && e.Collection.OAuth2.MappedFields.Id != "" {
|
||||
payload[e.Collection.OAuth2.MappedFields.Id] = e.OAuth2User.Id
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Name]; !ok && e.Collection.OAuth2.MappedFields.Name != "" {
|
||||
payload[e.Collection.OAuth2.MappedFields.Name] = e.OAuth2User.Name
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Username]; !ok &&
|
||||
// no explicit username payload value and existing OAuth2 mapping
|
||||
e.Collection.OAuth2.MappedFields.Username != "" &&
|
||||
// extra checks for backward compatibility with earlier versions
|
||||
oldCanAssignUsername(txApp, e.Collection, e.OAuth2User.Username) {
|
||||
payload[e.Collection.OAuth2.MappedFields.Username] = e.OAuth2User.Username
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.AvatarURL]; !ok && e.Collection.OAuth2.MappedFields.AvatarURL != "" {
|
||||
mappedField := e.Collection.Fields.GetByName(e.Collection.OAuth2.MappedFields.AvatarURL)
|
||||
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
|
||||
// download the avatar if the mapped field is a file
|
||||
avatarFile, err := func() (*filesystem.File, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = avatarFile
|
||||
} else {
|
||||
// otherwise - assign the url string
|
||||
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = e.OAuth2User.AvatarURL
|
||||
}
|
||||
}
|
||||
|
||||
createdRecord, err := sendOAuth2RecordCreateRequest(txApp, e, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Record = createdRecord
|
||||
|
||||
if e.Record.Email() == e.OAuth2User.Email && !e.Record.Verified() {
|
||||
// mark as verified as long as it matches the OAuth2 data (even if the email is empty)
|
||||
e.Record.SetVerified(true)
|
||||
if err := txApp.Save(e.Record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var needUpdate bool
|
||||
|
||||
isLoggedAuthRecord := e.Auth != nil &&
|
||||
e.Auth.Id == e.Record.Id &&
|
||||
e.Auth.Collection().Id == e.Record.Collection().Id
|
||||
|
||||
// set random password for users with unverified email
|
||||
// (this is in case a malicious actor has registered previously with the user email)
|
||||
if !isLoggedAuthRecord && e.Record.Email() != "" && !e.Record.Verified() {
|
||||
e.Record.SetPassword(security.RandomString(30))
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
// update the existing auth record empty email if the data.OAuth2User has one
|
||||
// (this is in case previously the auth record was created
|
||||
// with an OAuth2 provider that didn't return an email address)
|
||||
if e.Record.Email() == "" && e.OAuth2User.Email != "" {
|
||||
e.Record.SetEmail(e.OAuth2User.Email)
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
// update the existing auth record verified state
|
||||
// (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User)
|
||||
if !e.Record.Verified() && (e.Record.Email() == "" || e.Record.Email() == e.OAuth2User.Email) {
|
||||
e.Record.SetVerified(true)
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if needUpdate {
|
||||
if err := txApp.Save(e.Record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create ExternalAuth relation if missing
|
||||
if optExternalAuth == nil {
|
||||
optExternalAuth = core.NewExternalAuth(txApp)
|
||||
optExternalAuth.SetCollectionRef(e.Record.Collection().Id)
|
||||
optExternalAuth.SetRecordRef(e.Record.Id)
|
||||
optExternalAuth.SetProvider(e.ProviderName)
|
||||
optExternalAuth.SetProviderId(e.OAuth2User.Id)
|
||||
|
||||
if err := txApp.Save(optExternalAuth); err != nil {
|
||||
return fmt.Errorf("failed to save linked rel: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2RequestEvent, payload map[string]any) (*core.Record, error) {
|
||||
ir := &core.InternalRequest{
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + e.Collection.Name + "/records",
|
||||
Body: payload,
|
||||
}
|
||||
|
||||
response, err := processInternalRequest(txApp, e.RequestEvent, ir, core.RequestInfoContextOAuth2, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if response.Status != http.StatusOK {
|
||||
return nil, errors.New("failed to create OAuth2 auth record")
|
||||
}
|
||||
|
||||
recordResponse := struct {
|
||||
Id string `json:"id"`
|
||||
}{}
|
||||
|
||||
raw, err := json.Marshal(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(raw, &recordResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return txApp.FindRecordById(e.Collection, recordResponse.Id)
|
||||
}
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2SubscriptionTopic string = "@oauth2"
|
||||
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
|
||||
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
|
||||
)
|
||||
|
||||
type oauth2RedirectData struct {
|
||||
State string `form:"state" json:"state"`
|
||||
Code string `form:"code" json:"code"`
|
||||
Error string `form:"error" json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func oauth2SubscriptionRedirect(e *core.RequestEvent) error {
|
||||
redirectStatusCode := http.StatusTemporaryRedirect
|
||||
if e.Request.Method != http.MethodGet {
|
||||
redirectStatusCode = http.StatusSeeOther
|
||||
}
|
||||
|
||||
data := oauth2RedirectData{}
|
||||
|
||||
if e.Request.Method == http.MethodPost {
|
||||
if err := e.BindBody(&data); err != nil {
|
||||
e.App.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
} else {
|
||||
query := e.Request.URL.Query()
|
||||
data.State = query.Get("state")
|
||||
data.Code = query.Get("code")
|
||||
data.Error = query.Get("error")
|
||||
}
|
||||
|
||||
if data.State == "" {
|
||||
e.App.Logger().Debug("Missing OAuth2 state parameter")
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
client, err := e.App.SubscriptionsBroker().ClientById(data.State)
|
||||
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
|
||||
e.App.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
defer client.Unsubscribe(oauth2SubscriptionTopic)
|
||||
|
||||
encodedData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
e.App.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: oauth2SubscriptionTopic,
|
||||
Data: encodedData,
|
||||
}
|
||||
|
||||
client.Send(msg)
|
||||
|
||||
if data.Error != "" || data.Code == "" {
|
||||
e.App.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
|
||||
}
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithOAuth2Redirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
c1 := subscriptions.NewDefaultClient()
|
||||
|
||||
c2 := subscriptions.NewDefaultClient()
|
||||
c2.Subscribe("@oauth2")
|
||||
|
||||
c3 := subscriptions.NewDefaultClient()
|
||||
c3.Subscribe("test1", "@oauth2")
|
||||
|
||||
c4 := subscriptions.NewDefaultClient()
|
||||
c4.Subscribe("test1", "test2")
|
||||
|
||||
c5 := subscriptions.NewDefaultClient()
|
||||
c5.Subscribe("@oauth2")
|
||||
c5.Discard()
|
||||
|
||||
clientStubs = append(clientStubs, map[string]subscriptions.Client{
|
||||
"c1": c1,
|
||||
"c2": c2,
|
||||
"c3": c3,
|
||||
"c4": c4,
|
||||
"c5": c5,
|
||||
})
|
||||
}
|
||||
|
||||
checkFailureRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-failure") {
|
||||
t.Fatalf("Expected failure redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
checkSuccessRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-success") {
|
||||
t.Fatalf("Expected success redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
checkClientMessages := func(t testing.TB, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
|
||||
if len(expectedMessages[clientId]) == 0 {
|
||||
t.Fatalf("Unexpected client %q message, got %s:\n%s", clientId, msg.Name, msg.Data)
|
||||
}
|
||||
|
||||
if msg.Name != "@oauth2" {
|
||||
t.Fatalf("Expected @oauth2 msg.Name, got %q", msg.Name)
|
||||
}
|
||||
|
||||
for _, txt := range expectedMessages[clientId] {
|
||||
if !strings.Contains(string(msg.Data), txt) {
|
||||
t.Fatalf("Failed to find %q in \n%s", txt, msg.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
beforeTestFunc := func(
|
||||
clients map[string]subscriptions.Client,
|
||||
expectedMessages map[string][]string,
|
||||
) func(testing.TB, *tests.TestApp, *core.ServeEvent) {
|
||||
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
for _, client := range clients {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
||||
// add to the app store so that it can be cancelled manually after test completion
|
||||
app.Store().Set("cancelFunc", cancelFunc)
|
||||
|
||||
go func() {
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-clients["c1"].Channel():
|
||||
checkClientMessages(t, "c1", msg, expectedMessages)
|
||||
case msg := <-clients["c2"].Channel():
|
||||
checkClientMessages(t, "c2", msg, expectedMessages)
|
||||
case msg := <-clients["c3"].Channel():
|
||||
checkClientMessages(t, "c3", msg, expectedMessages)
|
||||
case msg := <-clients["c4"].Channel():
|
||||
checkClientMessages(t, "c4", msg, expectedMessages)
|
||||
case msg := <-clients["c5"].Channel():
|
||||
checkClientMessages(t, "c5", msg, expectedMessages)
|
||||
case <-ctx.Done():
|
||||
for _, c := range clients {
|
||||
close(c.Channel())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "no state query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "invalid or missing client",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=missing",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "no code query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "error query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "discarded client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client without @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkSuccessRedirect(t, app, res)
|
||||
|
||||
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "(POST) client with @oauth2 subscription",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/oauth2-redirect",
|
||||
Body: strings.NewReader("code=123&state=" + clientStubs[7]["c3"].Id()),
|
||||
Headers: map[string]string{
|
||||
"content-type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[7], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[7]["c3"].Id(), `"code":"123"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusSeeOther,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkSuccessRedirect(t, app, res)
|
||||
|
||||
if clientStubs[7]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,99 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func recordAuthWithOTP(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OTP.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
|
||||
}
|
||||
|
||||
form := &authWithOTPForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthWithOTPRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
// extra validations
|
||||
// (note: returns a generic 400 as a very basic OTPs enumeration protection)
|
||||
// ---
|
||||
event.OTP, err = e.App.FindOTPById(form.OTPId)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired OTP", err)
|
||||
}
|
||||
|
||||
if event.OTP.CollectionRef() != collection.Id {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is for a different collection"))
|
||||
}
|
||||
|
||||
if event.OTP.HasExpired(collection.OTP.DurationTime()) {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is expired"))
|
||||
}
|
||||
|
||||
event.Record, err = e.App.FindRecordById(event.OTP.CollectionRef(), event.OTP.RecordRef())
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired OTP", fmt.Errorf("missing auth record: %w", err))
|
||||
}
|
||||
|
||||
// since otps are usually simple digit numbers we enforce an extra rate limit rule to prevent enumerations
|
||||
err = checkRateLimit(e, "@pb_otp_"+event.OTP.Id+event.Record.Id, core.RateLimitRule{MaxRequests: 4, Duration: 180})
|
||||
if err != nil {
|
||||
return e.TooManyRequestsError("Too many attempts, please try again later with a new OTP.", nil)
|
||||
}
|
||||
|
||||
if !event.OTP.ValidatePassword(form.Password) {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("incorrect password"))
|
||||
}
|
||||
// ---
|
||||
|
||||
return e.App.OnRecordAuthWithOTPRequest().Trigger(event, func(e *core.RecordAuthWithOTPRequestEvent) error {
|
||||
err = RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// try to delete the used otp
|
||||
if e.OTP != nil {
|
||||
err = e.App.Delete(e.OTP)
|
||||
if err != nil {
|
||||
e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id)
|
||||
}
|
||||
}
|
||||
|
||||
// note: we don't update the user verified state the same way as in the password reset confirmation
|
||||
// at the moment because it is not clear whether the otp confirmation came from the user email
|
||||
// (e.g. it could be from an sms or some other channel)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type authWithOTPForm struct {
|
||||
OTPId string `form:"otpId" json:"otpId"`
|
||||
Password string `form:"password" json:"password"`
|
||||
}
|
||||
|
||||
func (form *authWithOTPForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.OTPId, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 71)),
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,438 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithOTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-with-otp",
|
||||
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with disabled otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
usersCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersCol.OTP.Enabled = false
|
||||
|
||||
if err := app.Save(usersCol); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"otpId":{"code":"validation_required"`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid request data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 256) + `",
|
||||
"password":"` + strings.Repeat("a", 72) + `"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"otpId":{"code":"validation_length_out_of_range"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"missing",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "otp for different collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client, err := app.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(client.Collection().Id)
|
||||
otp.SetRecordRef(client.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "otp with wrong password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("1234567890")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired otp with valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
|
||||
otp.SetRaw("created", expiredDate)
|
||||
otp.SetRaw("updated", expiredDate)
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"mfaId":"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
// ---
|
||||
"OnModelValidate": 1,
|
||||
"OnModelCreate": 1, // mfa record
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1, // otp delete
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password (disabled MFA)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err := app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"meta":`,
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// ---
|
||||
"OnModelValidate": 1,
|
||||
"OnModelCreate": 1, // authOrigin
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1, // otp delete
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authWithOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithOTP"},
|
||||
{MaxRequests: 100, Label: "users:auth"},
|
||||
{MaxRequests: 0, Label: "users:authWithOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authWithOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:auth"},
|
||||
{MaxRequests: 0, Label: "*:authWithOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - users:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithOTP"},
|
||||
{MaxRequests: 0, Label: "users:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthWithOTPManualRateLimiterCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var storeCache map[string]any
|
||||
|
||||
otpAId := strings.Repeat("a", 15)
|
||||
otpBId := strings.Repeat("b", 15)
|
||||
|
||||
scenarios := []struct {
|
||||
otpId string
|
||||
password string
|
||||
expectedStatus int
|
||||
}{
|
||||
{otpAId, "12345", 400},
|
||||
{otpAId, "12345", 400},
|
||||
{otpAId, "12345", 400},
|
||||
{otpAId, "12345", 400},
|
||||
{otpAId, "123456", 429},
|
||||
{otpBId, "12345", 400},
|
||||
{otpBId, "123456", 200},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
(&tests.ApiScenario{
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + s.otpId + `",
|
||||
"password":"` + s.password + `"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
for k, v := range storeCache {
|
||||
app.Store().Set(k, v)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err := app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, id := range []string{otpAId, otpBId} {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = id
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: s.expectedStatus,
|
||||
ExpectedContent: []string{`"`}, // it doesn't matter anything non-empty
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
storeCache = app.Store().GetAll()
|
||||
},
|
||||
}).Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func recordAuthWithPassword(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.PasswordAuth.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow password authentication.", nil)
|
||||
}
|
||||
|
||||
form := &authWithPasswordForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(collection); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
var foundRecord *core.Record
|
||||
var foundErr error
|
||||
|
||||
if form.IdentityField != "" {
|
||||
foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, form.IdentityField, form.Identity)
|
||||
} else {
|
||||
// prioritize email lookup
|
||||
isEmail := is.EmailFormat.Validate(form.Identity) == nil
|
||||
if isEmail && list.ExistInSlice(core.FieldNameEmail, collection.PasswordAuth.IdentityFields) {
|
||||
foundRecord, foundErr = e.App.FindAuthRecordByEmail(collection.Id, form.Identity)
|
||||
}
|
||||
|
||||
// search by the other identity fields
|
||||
if !isEmail || foundErr != nil {
|
||||
for _, name := range collection.PasswordAuth.IdentityFields {
|
||||
if !isEmail && name == core.FieldNameEmail {
|
||||
continue // no need to search by the email field if it is not an email
|
||||
}
|
||||
|
||||
foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, name, form.Identity)
|
||||
if foundErr == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ignore not found errors to allow custom record find implementations
|
||||
if foundErr != nil && !errors.Is(foundErr, sql.ErrNoRows) {
|
||||
return e.InternalServerError("", foundErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthWithPasswordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = foundRecord
|
||||
event.Identity = form.Identity
|
||||
event.Password = form.Password
|
||||
event.IdentityField = form.IdentityField
|
||||
|
||||
return e.App.OnRecordAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordRequestEvent) error {
|
||||
if e.Record == nil || !e.Record.ValidatePassword(e.Password) {
|
||||
return e.BadRequestError("Failed to authenticate.", errors.New("invalid login credentials"))
|
||||
}
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodPassword, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type authWithPasswordForm struct {
|
||||
Identity string `form:"identity" json:"identity"`
|
||||
Password string `form:"password" json:"password"`
|
||||
|
||||
// IdentityField specifies the field to use to search for the identity
|
||||
// (leave it empty for "auto" detection).
|
||||
IdentityField string `form:"identityField" json:"identityField"`
|
||||
}
|
||||
|
||||
func (form *authWithPasswordForm) validate(collection *core.Collection) error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Identity, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.IdentityField, validation.In(list.ToInterfaceSlice(collection.PasswordAuth.IdentityFields)...)),
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,514 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "disabled password auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body params",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"","password":""}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"identity":{`,
|
||||
`"password":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAuthWithPasswordRequest error response",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and invalid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"invalid"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field (email) and valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field (username) and valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and valid password with mismatched explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and valid password with matched explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity (unverified) and valid password in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test2@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "already authenticated record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"gk390qegs4y47wn"`,
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with mfa first auth check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{
|
||||
`"mfaId":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
// mfa create
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := len(mfas); v != 1 {
|
||||
t.Fatalf("Expected 1 mfa record to be created, got %d", v)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with mfa second auth check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"mfaId": "` + strings.Repeat("a", 15) + `",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.Id = strings.Repeat("a", 15)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("test")
|
||||
if err := app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 0, // disabled auth email alerts
|
||||
"OnMailerRecordAuthAlertSend": 0,
|
||||
// mfa delete
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with enabled mfa but unsatisfied mfa rule (aka. skip the mfa check)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
users, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users.MFA.Enabled = true
|
||||
users.MFA.Rule = "1=2"
|
||||
|
||||
if err := app.Save(users); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 0, // disabled auth email alerts
|
||||
"OnMailerRecordAuthAlertSend": 0,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := len(mfas); v != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", v)
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authWithPassword",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithPassword"},
|
||||
{MaxRequests: 100, Label: "users:auth"},
|
||||
{MaxRequests: 0, Label: "users:authWithPassword"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authWithPassword",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:auth"},
|
||||
{MaxRequests: 0, Label: "*:authWithPassword"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - users:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithPassword"},
|
||||
{MaxRequests: 0, Label: "users:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,121 +1,123 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/resolvers"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindRecordCrudApi registers the record crud api endpoints and
|
||||
// the corresponding handlers.
|
||||
func bindRecordCrudApi(app core.App, rg *echo.Group) {
|
||||
api := recordApi{app: app}
|
||||
|
||||
subGroup := rg.Group(
|
||||
"/collections/:collection",
|
||||
ActivityLogger(app),
|
||||
)
|
||||
|
||||
subGroup.GET("/records", api.list, LoadCollectionContext(app))
|
||||
subGroup.GET("/records/:id", api.view, LoadCollectionContext(app))
|
||||
subGroup.POST("/records", api.create, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
|
||||
subGroup.PATCH("/records/:id", api.update, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
|
||||
subGroup.DELETE("/records/:id", api.delete, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
|
||||
//
|
||||
// note: the rate limiter is "inlined" because some of the crud actions are also used in the batch APIs
|
||||
func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId)
|
||||
subGroup.GET("", recordsList)
|
||||
subGroup.GET("/{id}", recordView)
|
||||
subGroup.POST("", recordCreate(nil)).Bind(dynamicCollectionBodyLimit(""))
|
||||
subGroup.PATCH("/{id}", recordUpdate(nil)).Bind(dynamicCollectionBodyLimit(""))
|
||||
subGroup.DELETE("/{id}", recordDelete(nil))
|
||||
}
|
||||
|
||||
type recordApi struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (api *recordApi) list(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("", "Missing collection context.")
|
||||
func recordsList(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
requestInfo := RequestInfo(c)
|
||||
|
||||
// forbid users and guests to query special filter/sort fields
|
||||
if err := checkForAdminOnlyRuleFields(requestInfo); err != nil {
|
||||
err = checkCollectionRateLimit(e, collection, "list")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if requestInfo.Admin == nil && collection.ListRule == nil {
|
||||
// only admins can access if the rule is nil
|
||||
return NewForbiddenError("Only admins can perform this action.", nil)
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
fieldsResolver := resolvers.NewRecordFieldResolver(
|
||||
api.app.Dao(),
|
||||
if collection.ListRule == nil && !requestInfo.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
// forbid users and guests to query special filter/sort fields
|
||||
err = checkForSuperuserOnlyRuleFields(requestInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldsResolver := core.NewRecordFieldResolver(
|
||||
e.App,
|
||||
collection,
|
||||
requestInfo,
|
||||
// hidden fields are searchable only by admins
|
||||
requestInfo.Admin != nil,
|
||||
// hidden fields are searchable only by superusers
|
||||
requestInfo.HasSuperuserAuth(),
|
||||
)
|
||||
|
||||
searchProvider := search.NewProvider(fieldsResolver).
|
||||
Query(api.app.Dao().RecordQuery(collection))
|
||||
Query(e.App.RecordQuery(collection))
|
||||
|
||||
if requestInfo.Admin == nil && collection.ListRule != nil {
|
||||
if !requestInfo.HasSuperuserAuth() && collection.ListRule != nil {
|
||||
searchProvider.AddFilter(search.FilterData(*collection.ListRule))
|
||||
}
|
||||
|
||||
records := []*models.Record{}
|
||||
records := []*core.Record{}
|
||||
|
||||
result, err := searchProvider.ParseAndExec(c.QueryParams().Encode(), &records)
|
||||
result, err := searchProvider.ParseAndExec(e.Request.URL.Query().Encode(), &records)
|
||||
if err != nil {
|
||||
return NewBadRequestError("", err)
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordsListEvent)
|
||||
event.HttpContext = c
|
||||
event := new(core.RecordsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Records = records
|
||||
event.Result = result
|
||||
|
||||
return api.app.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
return e.App.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListRequestEvent) error {
|
||||
if err := EnrichRecords(e.RequestEvent, e.Records); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich records", err))
|
||||
}
|
||||
|
||||
if err := EnrichRecords(e.HttpContext, api.app.Dao(), e.Records); err != nil {
|
||||
api.app.Logger().Debug("Failed to enrich list records", slog.String("error", err.Error()))
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Result)
|
||||
return e.JSON(http.StatusOK, e.Result)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *recordApi) view(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("", "Missing collection context.")
|
||||
func recordView(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
recordId := c.PathParam("id")
|
||||
err = checkCollectionRateLimit(e, collection, "view")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo := RequestInfo(c)
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if requestInfo.Admin == nil && collection.ViewRule == nil {
|
||||
// only admins can access if the rule is nil
|
||||
return NewForbiddenError("Only admins can perform this action.", nil)
|
||||
if collection.ViewRule == nil && !requestInfo.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if requestInfo.Admin == nil && collection.ViewRule != nil && *collection.ViewRule != "" {
|
||||
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
|
||||
if !requestInfo.HasSuperuserAuth() && collection.ViewRule != nil && *collection.ViewRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.ViewRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -126,290 +128,472 @@ func (api *recordApi) view(c echo.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc)
|
||||
record, fetchErr := e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if fetchErr != nil || record == nil {
|
||||
return NewNotFoundError("", fetchErr)
|
||||
return firstApiError(err, e.NotFoundError("", fetchErr))
|
||||
}
|
||||
|
||||
event := new(core.RecordViewEvent)
|
||||
event.HttpContext = c
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordViewRequest().Trigger(event, func(e *core.RecordViewEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
return e.App.OnRecordViewRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
if err := EnrichRecord(e.RequestEvent, e.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
|
||||
api.app.Logger().Debug(
|
||||
"Failed to enrich view record",
|
||||
slog.String("id", e.Record.Id),
|
||||
slog.String("collectionName", e.Record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Record)
|
||||
return e.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *recordApi) create(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("", "Missing collection context.")
|
||||
}
|
||||
func recordCreate(optFinalizer func() error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
requestInfo := RequestInfo(c)
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
if requestInfo.Admin == nil && collection.CreateRule == nil {
|
||||
// only admins can access if the rule is nil
|
||||
return NewForbiddenError("Only admins can perform this action.", nil)
|
||||
}
|
||||
err = checkCollectionRateLimit(e, collection, "create")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hasFullManageAccess := requestInfo.Admin != nil
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
// temporary save the record and check it against the create rule
|
||||
if requestInfo.Admin == nil && collection.CreateRule != nil {
|
||||
testRecord := models.NewRecord(collection)
|
||||
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
|
||||
canSkipRuleCheck := hasSuperuserAuth
|
||||
|
||||
// special case for the first superuser creation
|
||||
// ---
|
||||
if !canSkipRuleCheck && collection.Name == core.CollectionNameSuperusers {
|
||||
total, totalErr := e.App.CountRecords(core.CollectionNameSuperusers)
|
||||
canSkipRuleCheck = totalErr == nil && total == 0
|
||||
}
|
||||
// ---
|
||||
|
||||
if !canSkipRuleCheck && collection.CreateRule == nil {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
|
||||
data, err := recordDataFromRequest(e, record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
|
||||
}
|
||||
|
||||
// replace modifiers fields so that the resolved value is always
|
||||
// available when accessing requestInfo.Data using just the field name
|
||||
if requestInfo.HasModifierDataKeys() {
|
||||
requestInfo.Data = testRecord.ReplaceModifers(requestInfo.Data)
|
||||
}
|
||||
// available when accessing requestInfo.Body
|
||||
requestInfo.Body = data
|
||||
|
||||
testForm := forms.NewRecordUpsert(api.app, testRecord)
|
||||
testForm.SetFullManageAccess(true)
|
||||
if err := testForm.LoadRequest(c.Request(), ""); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
form := forms.NewRecordUpsert(e.App, record)
|
||||
if hasSuperuserAuth {
|
||||
form.GrantSuperuserAccess()
|
||||
}
|
||||
form.Load(data)
|
||||
|
||||
// force unset the verified state to prevent ManageRule misuse
|
||||
if !hasFullManageAccess {
|
||||
testForm.Verified = false
|
||||
}
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
createRuleFunc := func(q *dbx.SelectQuery) error {
|
||||
if *collection.CreateRule == "" {
|
||||
return nil // no create rule to resolve
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordCreateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
form.SetApp(e.App)
|
||||
form.SetRecord(e.Record)
|
||||
|
||||
// temporary save the record and check it against the create and manage rules
|
||||
if !canSkipRuleCheck && e.Collection.CreateRule != nil {
|
||||
// temporary grant manager access level
|
||||
form.GrantManagerAccess()
|
||||
|
||||
// manually unset the verified field to prevent manage API rule misuse in case the rule relies on it
|
||||
initialVerified := e.Record.Verified()
|
||||
if initialVerified {
|
||||
e.Record.SetVerified(false)
|
||||
}
|
||||
|
||||
createRuleFunc := func(q *dbx.SelectQuery) error {
|
||||
if *e.Collection.CreateRule == "" {
|
||||
return nil // no create rule to resolve
|
||||
}
|
||||
|
||||
resolver := core.NewRecordFieldResolver(e.App, e.Collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*e.Collection.CreateRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
testErr := form.DrySubmit(func(txApp core.App, drySavedRecord *core.Record) error {
|
||||
foundRecord, err := txApp.FindRecordById(drySavedRecord.Collection(), drySavedRecord.Id, createRuleFunc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DrySubmit create rule failure: %w", err)
|
||||
}
|
||||
|
||||
// reset the form access level in case it satisfies the Manage API rule
|
||||
if !hasAuthManageAccess(txApp, requestInfo, foundRecord) {
|
||||
form.ResetAccess()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if testErr != nil {
|
||||
return e.BadRequestError("Failed to create record.", testErr)
|
||||
}
|
||||
|
||||
// restore initial verified state (it will be further validated on submit)
|
||||
if initialVerified != e.Record.Verified() {
|
||||
e.Record.SetVerified(initialVerified)
|
||||
}
|
||||
}
|
||||
|
||||
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.CreateRule).BuildExpr(resolver)
|
||||
err := form.Submit()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to create record.", err))
|
||||
}
|
||||
|
||||
err = EnrichRecord(e.RequestEvent, e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
err = e.JSON(http.StatusOK, e.Record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
return nil
|
||||
}
|
||||
|
||||
testErr := testForm.DrySubmit(func(txDao *daos.Dao) error {
|
||||
foundRecord, err := txDao.FindRecordById(collection.Id, testRecord.Id, createRuleFunc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DrySubmit create rule failure: %w", err)
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", err))
|
||||
}
|
||||
}
|
||||
hasFullManageAccess = hasAuthManageAccess(txDao, foundRecord, requestInfo)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if testErr != nil {
|
||||
return NewBadRequestError("Failed to create record.", testErr)
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
}
|
||||
|
||||
record := models.NewRecord(collection)
|
||||
form := forms.NewRecordUpsert(api.app, record)
|
||||
form.SetFullManageAccess(hasFullManageAccess)
|
||||
|
||||
// load request
|
||||
if err := form.LoadRequest(c.Request(), ""); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordCreateEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.UploadedFiles = form.FilesToUpload()
|
||||
|
||||
// create the record
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(m *models.Record) error {
|
||||
event.Record = m
|
||||
|
||||
return api.app.OnRecordBeforeCreateRequest().Trigger(event, func(e *core.RecordCreateEvent) error {
|
||||
if err := next(e.Record); err != nil {
|
||||
return NewBadRequestError("Failed to create record.", err)
|
||||
}
|
||||
|
||||
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
|
||||
api.app.Logger().Debug(
|
||||
"Failed to enrich create record",
|
||||
slog.String("id", e.Record.Id),
|
||||
slog.String("collectionName", e.Record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterCreateRequest().Trigger(event, func(e *core.RecordCreateEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
})
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", err))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (api *recordApi) update(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("", "Missing collection context.")
|
||||
func recordUpdate(optFinalizer func() error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "update")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
|
||||
|
||||
if !hasSuperuserAuth && collection.UpdateRule == nil {
|
||||
return firstApiError(err, e.ForbiddenError("Only superusers can perform this action.", nil))
|
||||
}
|
||||
|
||||
// eager fetch the record so that the modifiers field values can be resolved
|
||||
record, err := e.App.FindRecordById(collection, recordId)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.NotFoundError("", err))
|
||||
}
|
||||
|
||||
data, err := recordDataFromRequest(e, record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
|
||||
}
|
||||
|
||||
// replace modifiers fields so that the resolved value is always
|
||||
// available when accessing requestInfo.Body
|
||||
requestInfo.Body = data
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !hasSuperuserAuth && collection.UpdateRule != nil && *collection.UpdateRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refetch with access checks
|
||||
record, err = e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.NotFoundError("", err))
|
||||
}
|
||||
|
||||
form := forms.NewRecordUpsert(e.App, record)
|
||||
if hasSuperuserAuth {
|
||||
form.GrantSuperuserAccess()
|
||||
}
|
||||
form.Load(data)
|
||||
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordUpdateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
form.SetApp(e.App)
|
||||
form.SetRecord(e.Record)
|
||||
if !form.HasManageAccess() && hasAuthManageAccess(e.App, requestInfo, e.Record) {
|
||||
form.GrantManagerAccess()
|
||||
}
|
||||
|
||||
err := form.Submit()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to update record.", err))
|
||||
}
|
||||
|
||||
err = EnrichRecord(e.RequestEvent, e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
err = e.JSON(http.StatusOK, e.Record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
recordId := c.PathParam("id")
|
||||
if recordId == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
func recordDelete(optFinalizer func() error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
requestInfo := RequestInfo(c)
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
if requestInfo.Admin == nil && collection.UpdateRule == nil {
|
||||
// only admins can access if the rule is nil
|
||||
return NewForbiddenError("Only admins can perform this action.", nil)
|
||||
}
|
||||
err = checkCollectionRateLimit(e, collection, "delete")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// eager fetch the record so that the modifier field values are replaced
|
||||
// and available when accessing requestInfo.Data using just the field name
|
||||
if requestInfo.HasModifierDataKeys() {
|
||||
record, err := api.app.Dao().FindRecordById(collection.Id, recordId)
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule == nil {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule != nil && *collection.DeleteRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if err != nil || record == nil {
|
||||
return NewNotFoundError("", err)
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
requestInfo.Data = record.ReplaceModifers(requestInfo.Data)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if requestInfo.Admin == nil && collection.UpdateRule != nil && *collection.UpdateRule != "" {
|
||||
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver)
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordDeleteRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
if err := e.App.Delete(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err))
|
||||
}
|
||||
|
||||
err = e.NoContent(http.StatusNoContent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetch record
|
||||
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc)
|
||||
if fetchErr != nil || record == nil {
|
||||
return NewNotFoundError("", fetchErr)
|
||||
}
|
||||
|
||||
form := forms.NewRecordUpsert(api.app, record)
|
||||
form.SetFullManageAccess(requestInfo.Admin != nil || hasAuthManageAccess(api.app.Dao(), record, requestInfo))
|
||||
|
||||
// load request
|
||||
if err := form.LoadRequest(c.Request(), ""); err != nil {
|
||||
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordUpdateEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.UploadedFiles = form.FilesToUpload()
|
||||
|
||||
// update the record
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
|
||||
return func(m *models.Record) error {
|
||||
event.Record = m
|
||||
|
||||
return api.app.OnRecordBeforeUpdateRequest().Trigger(event, func(e *core.RecordUpdateEvent) error {
|
||||
if err := next(e.Record); err != nil {
|
||||
return NewBadRequestError("Failed to update record.", err)
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
|
||||
}
|
||||
|
||||
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
|
||||
api.app.Logger().Debug(
|
||||
"Failed to enrich update record",
|
||||
slog.String("id", e.Record.Id),
|
||||
slog.String("collectionName", e.Record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterUpdateRequest().Trigger(event, func(e *core.RecordUpdateEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (api *recordApi) delete(c echo.Context) error {
|
||||
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
|
||||
if collection == nil {
|
||||
return NewNotFoundError("", "Missing collection context.")
|
||||
}
|
||||
|
||||
recordId := c.PathParam("id")
|
||||
if recordId == "" {
|
||||
return NewNotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo := RequestInfo(c)
|
||||
|
||||
if requestInfo.Admin == nil && collection.DeleteRule == nil {
|
||||
// only admins can access if the rule is nil
|
||||
return NewForbiddenError("Only admins can perform this action.", nil)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if requestInfo.Admin == nil && collection.DeleteRule != nil && *collection.DeleteRule != "" {
|
||||
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc)
|
||||
if fetchErr != nil || record == nil {
|
||||
return NewNotFoundError("", fetchErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordDeleteEvent)
|
||||
event.HttpContext = c
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return api.app.OnRecordBeforeDeleteRequest().Trigger(event, func(e *core.RecordDeleteEvent) error {
|
||||
// delete the record
|
||||
if err := api.app.Dao().DeleteRecord(e.Record); err != nil {
|
||||
return NewBadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err)
|
||||
}
|
||||
|
||||
return api.app.OnRecordAfterDeleteRequest().Trigger(event, func(e *core.RecordDeleteEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[string]any, error) {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// resolve regular fields
|
||||
result := record.ReplaceModifiers(info.Body)
|
||||
|
||||
// resolve uploaded files
|
||||
uploadedFiles, err := extractUploadedFiles(e.Request, record.Collection(), "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(uploadedFiles) > 0 {
|
||||
for k, v := range uploadedFiles {
|
||||
result[k] = v
|
||||
}
|
||||
result = record.ReplaceModifiers(result)
|
||||
}
|
||||
|
||||
isAuth := record.Collection().IsAuth()
|
||||
|
||||
// unset hidden fields for non-superusers
|
||||
if !info.HasSuperuserAuth() {
|
||||
for _, f := range record.Collection().Fields {
|
||||
if f.GetHidden() {
|
||||
// exception for the auth collection "password" field
|
||||
if isAuth && f.GetName() == core.FieldNamePassword {
|
||||
continue
|
||||
}
|
||||
|
||||
delete(result, f.GetName())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func extractUploadedFiles(request *http.Request, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) {
|
||||
contentType := request.Header.Get("content-type")
|
||||
if !strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return nil, nil // not multipart/form-data request
|
||||
}
|
||||
|
||||
result := map[string][]*filesystem.File{}
|
||||
|
||||
for _, field := range collection.Fields {
|
||||
if field.Type() != core.FieldTypeFile {
|
||||
continue
|
||||
}
|
||||
|
||||
baseKey := field.GetName()
|
||||
|
||||
keys := []string{
|
||||
baseKey,
|
||||
// prepend and append modifiers
|
||||
"+" + baseKey,
|
||||
baseKey + "+",
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
if prefix != "" {
|
||||
k = prefix + "." + k
|
||||
}
|
||||
files, err := FindUploadedFiles(request, k)
|
||||
if err != nil && !errors.Is(err, http.ErrMissingFile) {
|
||||
return nil, err
|
||||
}
|
||||
if len(files) > 0 {
|
||||
result[k] = files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,314 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudAuthOriginList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with authOrigins",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"9r2j0m74260ur8i"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without authOrigins",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"9r2j0m74260ur8i"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"fingerprint": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"fingerprint":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"fingerprint":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"id":"9r2j0m74260ur8i"`,
|
||||
`"fingerprint":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,316 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudExternalAuthList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with externalAuths",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"f1z5b3843pzc964"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without externalAuths",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test2@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"dlmflokuq1xl342"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"provider": "github",
|
||||
"providerId": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
`"providerId":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"providerId": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"id":"dlmflokuq1xl342"`,
|
||||
`"providerId":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,388 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudMFAList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with mfas",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without mfas",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFAView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"user1_0"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFADelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFACreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"method": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFAUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"method":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,388 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudOTPList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with otps",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without otps",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"user1_0"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"password": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"password":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,371 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudSuperuserList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalPages":1`,
|
||||
`"totalItems":4`,
|
||||
`"items":[{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"sywbhecnh46rhm0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 4, // + 3 AuthOrigins
|
||||
"OnModelDeleteExecute": 4,
|
||||
"OnModelAfterDeleteSuccess": 4,
|
||||
"OnRecordDelete": 4,
|
||||
"OnRecordDeleteExecute": 4,
|
||||
"OnRecordAfterDeleteSuccess": 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "delete the last superuser",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// delete all other superusers
|
||||
superusers, err := app.FindAllRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{"id": "sywbhecnh46rhm0"}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, superuser := range superusers {
|
||||
if err = app.Delete(superuser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelAfterDeleteError": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordAfterDeleteError": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"email": "test_new@example.com",
|
||||
"password": "1234567890",
|
||||
"passwordConfirm": "1234567890",
|
||||
"verified": false
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "guest creating first superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// delete all superusers
|
||||
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"collectionName":"_superusers"`,
|
||||
`"verified":true`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// because the action has no auth the email field shouldn't be returned if emailVisibility is not set
|
||||
`"email"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"collectionName":"_superusers"`,
|
||||
`"email":"test_new@example.com"`,
|
||||
`"verified":true`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"email": "test_new@example.com",
|
||||
"verified": true
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"collectionName":"_superusers"`,
|
||||
`"id":"sywbhecnh46rhm0"`,
|
||||
`"email":"test_new@example.com"`,
|
||||
`"verified":true`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,121 +1,111 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/resolvers"
|
||||
"github.com/pocketbase/pocketbase/tokens"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
const ContextRequestInfoKey = "requestInfo"
|
||||
const (
|
||||
expandQueryParam = "expand"
|
||||
fieldsQueryParam = "fields"
|
||||
)
|
||||
|
||||
const expandQueryParam = "expand"
|
||||
const fieldsQueryParam = "fields"
|
||||
|
||||
// Deprecated: Use RequestInfo instead.
|
||||
func RequestData(c echo.Context) *models.RequestInfo {
|
||||
log.Println("RequestData(c) is deprecated and will be removed in the future! You can replace it with RequestInfo(c).")
|
||||
return RequestInfo(c)
|
||||
}
|
||||
|
||||
// RequestInfo exports cached common request data fields
|
||||
// (query, body, logged auth state, etc.) from the provided context.
|
||||
func RequestInfo(c echo.Context) *models.RequestInfo {
|
||||
// return cached to avoid copying the body multiple times
|
||||
if v := c.Get(ContextRequestInfoKey); v != nil {
|
||||
if data, ok := v.(*models.RequestInfo); ok {
|
||||
// refresh auth state
|
||||
data.AuthRecord, _ = c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
data.Admin, _ = c.Get(ContextAdminKey).(*models.Admin)
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
result := &models.RequestInfo{
|
||||
Context: models.RequestInfoContextDefault,
|
||||
Method: c.Request().Method,
|
||||
Query: map[string]any{},
|
||||
Data: map[string]any{},
|
||||
Headers: map[string]any{},
|
||||
}
|
||||
|
||||
// extract the first value of all headers and normalizes the keys
|
||||
// ("X-Token" is converted to "x_token")
|
||||
for k, v := range c.Request().Header {
|
||||
if len(v) > 0 {
|
||||
result.Headers[inflector.Snakecase(k)] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
result.AuthRecord, _ = c.Get(ContextAuthRecordKey).(*models.Record)
|
||||
result.Admin, _ = c.Get(ContextAdminKey).(*models.Admin)
|
||||
echo.BindQueryParams(c, &result.Query)
|
||||
rest.BindBody(c, &result.Data)
|
||||
|
||||
c.Set(ContextRequestInfoKey, result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// RecordAuthResponse writes standardised json record auth response
|
||||
// RecordAuthResponse writes standardized json record auth response
|
||||
// into the specified request context.
|
||||
func RecordAuthResponse(
|
||||
app core.App,
|
||||
c echo.Context,
|
||||
authRecord *models.Record,
|
||||
meta any,
|
||||
finalizers ...func(token string) error,
|
||||
) error {
|
||||
if !authRecord.Verified() && authRecord.Collection().AuthOptions().OnlyVerified {
|
||||
return NewForbiddenError("Please verify your account first.", nil)
|
||||
}
|
||||
|
||||
token, tokenErr := tokens.NewRecordAuthToken(app, authRecord)
|
||||
//
|
||||
// The authMethod argument specify the name of the current authentication method (eg. password, oauth2, etc.)
|
||||
// that it is used primarily as an auth identifier during MFA and for login alerts.
|
||||
//
|
||||
// Set authMethod to empty string if you want to ignore the MFA checks and the login alerts
|
||||
// (can be also adjusted additionally via the OnRecordAuthRequest hook).
|
||||
func RecordAuthResponse(e *core.RequestEvent, authRecord *core.Record, authMethod string, meta any) error {
|
||||
token, tokenErr := authRecord.NewAuthToken()
|
||||
if tokenErr != nil {
|
||||
return NewBadRequestError("Failed to create auth token.", tokenErr)
|
||||
return e.InternalServerError("Failed to create auth token.", tokenErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthEvent)
|
||||
event.HttpContext = c
|
||||
return recordAuthResponse(e, authRecord, token, authMethod, meta)
|
||||
}
|
||||
|
||||
func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token string, authMethod string, meta any) error {
|
||||
originalRequestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ok, err := e.App.CanAccessRecord(authRecord, originalRequestInfo, authRecord.Collection().AuthRule)
|
||||
if !ok {
|
||||
return firstApiError(err, e.ForbiddenError("The request doesn't satisfy the collection requirements to authenticate.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = authRecord.Collection()
|
||||
event.Record = authRecord
|
||||
event.Token = token
|
||||
event.Meta = meta
|
||||
event.AuthMethod = authMethod
|
||||
|
||||
return app.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return e.App.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthRequestEvent) error {
|
||||
if e.Written() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// allow always returning the email address of the authenticated account
|
||||
e.Record.IgnoreEmailVisibility(true)
|
||||
// MFA
|
||||
// ---
|
||||
mfaId, err := checkMFA(e.RequestEvent, e.Record, e.AuthMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// expand record relations
|
||||
expands := strings.Split(c.QueryParam(expandQueryParam), ",")
|
||||
if len(expands) > 0 {
|
||||
// create a copy of the cached request data and adjust it to the current auth record
|
||||
requestInfo := *RequestInfo(e.HttpContext)
|
||||
requestInfo.Admin = nil
|
||||
requestInfo.AuthRecord = e.Record
|
||||
failed := app.Dao().ExpandRecord(
|
||||
e.Record,
|
||||
expands,
|
||||
expandFetch(app.Dao(), &requestInfo),
|
||||
)
|
||||
if len(failed) > 0 {
|
||||
app.Logger().Debug("[RecordAuthResponse] Failed to expand relations", slog.Any("errors", failed))
|
||||
// require additional authentication
|
||||
if mfaId != "" {
|
||||
return e.JSON(http.StatusUnauthorized, map[string]string{
|
||||
"mfaId": mfaId,
|
||||
})
|
||||
}
|
||||
// ---
|
||||
|
||||
// create a shallow copy of the cached request data and adjust it to the current auth record
|
||||
requestInfo := *originalRequestInfo
|
||||
requestInfo.Auth = e.Record
|
||||
|
||||
err = triggerRecordEnrichHooks(e.App, &requestInfo, []*core.Record{e.Record}, func() error {
|
||||
if e.Record.IsSuperuser() {
|
||||
e.Record.Unhide(e.Record.Collection().Fields.FieldNames()...)
|
||||
}
|
||||
|
||||
// allow always returning the email address of the authenticated model
|
||||
e.Record.IgnoreEmailVisibility(true)
|
||||
|
||||
// expand record relations
|
||||
expands := strings.Split(e.Request.URL.Query().Get(expandQueryParam), ",")
|
||||
if len(expands) > 0 {
|
||||
failed := e.App.ExpandRecord(e.Record, expands, expandFetch(e.App, &requestInfo))
|
||||
if len(failed) > 0 {
|
||||
e.App.Logger().Warn("[recordAuthResponse] Failed to expand relations", "error", failed)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.AuthMethod != "" && authRecord.Collection().AuthAlert.Enabled {
|
||||
if err = authAlert(e.RequestEvent, e.Record); err != nil {
|
||||
e.App.Logger().Warn("[recordAuthResponse] Failed to send login alert", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -128,68 +118,254 @@ func RecordAuthResponse(
|
|||
result["meta"] = e.Meta
|
||||
}
|
||||
|
||||
for _, f := range finalizers {
|
||||
if err := f(e.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
return e.JSON(http.StatusOK, result)
|
||||
})
|
||||
}
|
||||
|
||||
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule.
|
||||
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
|
||||
rule := record.Collection().MFA.Rule
|
||||
if rule == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var exists bool
|
||||
|
||||
query := e.App.RecordQuery(record.Collection()).
|
||||
Select("(1)").
|
||||
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
|
||||
|
||||
// parse and apply the access rule filter
|
||||
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
|
||||
expr, err := search.FilterData(rule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resolver.UpdateQuery(query)
|
||||
|
||||
err = query.AndWhere(expr).Limit(1).Row(&exists)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// checkMFA handles any MFA auth checks that needs to be performed for the specified request event.
|
||||
// Returns the mfaId that needs to be written as response to the user.
|
||||
//
|
||||
// (note: all auth methods are treated as equal and there is no requirement for "pairing").
|
||||
func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod string) (string, error) {
|
||||
if !authRecord.Collection().MFA.Enabled || currentAuthMethod == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ok, err := wantsMFA(e, authRecord)
|
||||
if !ok {
|
||||
if err != nil {
|
||||
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, result)
|
||||
})
|
||||
return "", nil // no mfa needed for this auth record
|
||||
}
|
||||
|
||||
// read the mfaId either from the qyery params or request body
|
||||
mfaId := e.Request.URL.Query().Get("mfaId")
|
||||
if mfaId == "" {
|
||||
// check the body
|
||||
data := struct {
|
||||
MfaId string `form:"mfaId" json:"mfaId" xml:"mfaId"`
|
||||
}{}
|
||||
if err := e.BindBody(&data); err != nil {
|
||||
return "", firstApiError(err, e.BadRequestError("Failed to read MFA Id", err))
|
||||
}
|
||||
mfaId = data.MfaId
|
||||
}
|
||||
|
||||
// first-time auth
|
||||
// ---
|
||||
if mfaId == "" {
|
||||
mfa := core.NewMFA(e.App)
|
||||
mfa.SetCollectionRef(authRecord.Collection().Id)
|
||||
mfa.SetRecordRef(authRecord.Id)
|
||||
mfa.SetMethod(currentAuthMethod)
|
||||
if err := e.App.Save(mfa); err != nil {
|
||||
return "", firstApiError(err, e.InternalServerError("Failed to create MFA record", err))
|
||||
}
|
||||
|
||||
return mfa.Id, nil
|
||||
}
|
||||
|
||||
// second-time auth
|
||||
// ---
|
||||
mfa, err := e.App.FindMFAById(mfaId)
|
||||
deleteMFA := func() {
|
||||
// try to delete the expired mfa
|
||||
if mfa != nil {
|
||||
if deleteErr := e.App.Delete(mfa); deleteErr != nil {
|
||||
e.App.Logger().Warn("Failed to delete expired MFA record", "error", deleteErr, "mfaId", mfa.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
|
||||
deleteMFA()
|
||||
return "", firstApiError(err, e.BadRequestError("Invalid or expired MFA session.", err))
|
||||
}
|
||||
|
||||
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {
|
||||
return "", e.BadRequestError("Invalid MFA session.", nil)
|
||||
}
|
||||
|
||||
if mfa.Method() == currentAuthMethod {
|
||||
return "", e.BadRequestError("A different authentication method is required.", nil)
|
||||
}
|
||||
|
||||
deleteMFA()
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// EnrichRecord parses the request context and enrich the provided record:
|
||||
// - expands relations (if defaultExpands and/or ?expand query param is set)
|
||||
// - ensures that the emails of the auth record and its expanded auth relations
|
||||
// are visible only for the current logged admin, record owner or record with manage access
|
||||
func EnrichRecord(c echo.Context, dao *daos.Dao, record *models.Record, defaultExpands ...string) error {
|
||||
return EnrichRecords(c, dao, []*models.Record{record}, defaultExpands...)
|
||||
// are visible only for the current logged superuser, record owner or record with manage access
|
||||
func EnrichRecord(e *core.RequestEvent, record *core.Record, defaultExpands ...string) error {
|
||||
return EnrichRecords(e, []*core.Record{record}, defaultExpands...)
|
||||
}
|
||||
|
||||
// EnrichRecords parses the request context and enriches the provided records:
|
||||
// - expands relations (if defaultExpands and/or ?expand query param is set)
|
||||
// - ensures that the emails of the auth records and their expanded auth relations
|
||||
// are visible only for the current logged admin, record owner or record with manage access
|
||||
func EnrichRecords(c echo.Context, dao *daos.Dao, records []*models.Record, defaultExpands ...string) error {
|
||||
requestInfo := RequestInfo(c)
|
||||
|
||||
if err := autoIgnoreAuthRecordsEmailVisibility(dao, records, requestInfo); err != nil {
|
||||
return fmt.Errorf("failed to resolve email visibility: %w", err)
|
||||
// are visible only for the current logged superuser, record owner or record with manage access
|
||||
//
|
||||
// Note: Expects all records to be from the same collection!
|
||||
func EnrichRecords(e *core.RequestEvent, records []*core.Record, defaultExpands ...string) error {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
expands := defaultExpands
|
||||
if param := c.QueryParam(expandQueryParam); param != "" {
|
||||
expands = append(expands, strings.Split(param, ",")...)
|
||||
}
|
||||
if len(expands) == 0 {
|
||||
return nil // nothing to expand
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errs := dao.ExpandRecords(records, expands, expandFetch(dao, requestInfo))
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to expand: %v", errs)
|
||||
return triggerRecordEnrichHooks(e.App, info, records, func() error {
|
||||
expands := defaultExpands
|
||||
if param := e.Request.URL.Query().Get(expandQueryParam); param != "" {
|
||||
expands = append(expands, strings.Split(param, ",")...)
|
||||
}
|
||||
|
||||
err := defaultEnrichRecords(e.App, info, records, expands...)
|
||||
if err != nil {
|
||||
// only log as it is not critical
|
||||
e.App.Logger().Warn("failed to apply default enriching", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var iterate func(record *core.Record) error
|
||||
|
||||
type iterator[T any] struct {
|
||||
items []T
|
||||
index int
|
||||
}
|
||||
|
||||
func (ri *iterator[T]) next() T {
|
||||
var item T
|
||||
|
||||
if ri.index < len(ri.items) {
|
||||
item = ri.items[ri.index]
|
||||
ri.index++
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
func triggerRecordEnrichHooks(app core.App, requestInfo *core.RequestInfo, records []*core.Record, finalizer func() error) error {
|
||||
it := iterator[*core.Record]{items: records}
|
||||
|
||||
enrichHook := app.OnRecordEnrich()
|
||||
|
||||
event := new(core.RecordEnrichEvent)
|
||||
event.App = app
|
||||
event.RequestInfo = requestInfo
|
||||
|
||||
iterate = func(record *core.Record) error {
|
||||
if record == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
event.Record = record
|
||||
|
||||
return enrichHook.Trigger(event, func(ee *core.RecordEnrichEvent) error {
|
||||
next := it.next()
|
||||
if next == nil {
|
||||
if finalizer != nil {
|
||||
return finalizer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
event.App = ee.App // in case it was replaced with a transaction
|
||||
event.Record = next
|
||||
|
||||
err := iterate(next)
|
||||
|
||||
event.App = app
|
||||
event.Record = record
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
return iterate(it.next())
|
||||
}
|
||||
|
||||
func defaultEnrichRecords(app core.App, requestInfo *core.RequestInfo, records []*core.Record, expands ...string) error {
|
||||
err := autoResolveRecordsFlags(app, records, requestInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve records flags: %w", err)
|
||||
}
|
||||
|
||||
if len(expands) > 0 {
|
||||
expandErrs := app.ExpandRecords(records, expands, expandFetch(app, requestInfo))
|
||||
if len(expandErrs) > 0 {
|
||||
errsSlice := make([]error, 0, len(expandErrs))
|
||||
for key, err := range expandErrs {
|
||||
errsSlice = append(errsSlice, fmt.Errorf("failed to expand %q: %w", key, err))
|
||||
}
|
||||
return fmt.Errorf("failed to expand records: %w", errors.Join(errsSlice...))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandFetch is the records fetch function that is used to expand related records.
|
||||
func expandFetch(
|
||||
dao *daos.Dao,
|
||||
requestInfo *models.RequestInfo,
|
||||
) daos.ExpandFetchFunc {
|
||||
return func(relCollection *models.Collection, relIds []string) ([]*models.Record, error) {
|
||||
records, err := dao.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
|
||||
if requestInfo.Admin != nil {
|
||||
return nil // admins can access everything
|
||||
func expandFetch(app core.App, originalRequestInfo *core.RequestInfo) core.ExpandFetchFunc {
|
||||
requestInfoClone := *originalRequestInfo
|
||||
requestInfoPtr := &requestInfoClone
|
||||
requestInfoPtr.Context = core.RequestInfoContextExpand
|
||||
|
||||
return func(relCollection *core.Collection, relIds []string) ([]*core.Record, error) {
|
||||
records, findErr := app.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
|
||||
if requestInfoPtr.Auth != nil && requestInfoPtr.Auth.IsSuperuser() {
|
||||
return nil // superusers can access everything
|
||||
}
|
||||
|
||||
if relCollection.ViewRule == nil {
|
||||
return fmt.Errorf("only admins can view collection %q records", relCollection.Name)
|
||||
return fmt.Errorf("only superusers can view collection %q records", relCollection.Name)
|
||||
}
|
||||
|
||||
if *relCollection.ViewRule != "" {
|
||||
resolver := resolvers.NewRecordFieldResolver(dao, relCollection, requestInfo, true)
|
||||
resolver := core.NewRecordFieldResolver(app, relCollection, requestInfoPtr, true)
|
||||
expr, err := search.FilterData(*(relCollection.ViewRule)).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -200,50 +376,66 @@ func expandFetch(
|
|||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil && len(records) > 0 {
|
||||
autoIgnoreAuthRecordsEmailVisibility(dao, records, requestInfo)
|
||||
if findErr != nil {
|
||||
return nil, findErr
|
||||
}
|
||||
|
||||
return records, err
|
||||
enrichErr := triggerRecordEnrichHooks(app, requestInfoPtr, records, func() error {
|
||||
if err := autoResolveRecordsFlags(app, records, requestInfoPtr); err != nil {
|
||||
// non-critical error
|
||||
app.Logger().Warn("Failed to apply autoResolveRecordsFlags for the expanded records", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if enrichErr != nil {
|
||||
return nil, enrichErr
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
}
|
||||
|
||||
// autoIgnoreAuthRecordsEmailVisibility ignores the email visibility check for
|
||||
// the provided record if the current auth model is admin, owner or a "manager".
|
||||
// autoResolveRecordsFlags resolves various visibility flags of the provided records.
|
||||
//
|
||||
// Note: Expects all records to be from the same auth collection!
|
||||
func autoIgnoreAuthRecordsEmailVisibility(
|
||||
dao *daos.Dao,
|
||||
records []*models.Record,
|
||||
requestInfo *models.RequestInfo,
|
||||
) error {
|
||||
if len(records) == 0 || !records[0].Collection().IsAuth() {
|
||||
return nil // nothing to check
|
||||
// Currently it enables:
|
||||
// - export of hidden fields if the current auth model is a superuser
|
||||
// - email export ignoring the emailVisibity checks if the current auth model is superuser, owner or a "manager".
|
||||
//
|
||||
// Note: Expects all records to be from the same collection!
|
||||
func autoResolveRecordsFlags(app core.App, records []*core.Record, requestInfo *core.RequestInfo) error {
|
||||
if len(records) == 0 {
|
||||
return nil // nothing to resolve
|
||||
}
|
||||
|
||||
if requestInfo.Admin != nil {
|
||||
if requestInfo.HasSuperuserAuth() {
|
||||
hiddenFields := records[0].Collection().Fields.FieldNames()
|
||||
for _, rec := range records {
|
||||
rec.Unhide(hiddenFields...)
|
||||
rec.IgnoreEmailVisibility(true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// additional emailVisibility checks
|
||||
// ---------------------------------------------------------------
|
||||
if !records[0].Collection().IsAuth() {
|
||||
return nil // not auth collection records
|
||||
}
|
||||
|
||||
collection := records[0].Collection()
|
||||
|
||||
mappedRecords := make(map[string]*models.Record, len(records))
|
||||
mappedRecords := make(map[string]*core.Record, len(records))
|
||||
recordIds := make([]any, len(records))
|
||||
for i, rec := range records {
|
||||
mappedRecords[rec.Id] = rec
|
||||
recordIds[i] = rec.Id
|
||||
}
|
||||
|
||||
if requestInfo != nil && requestInfo.AuthRecord != nil && mappedRecords[requestInfo.AuthRecord.Id] != nil {
|
||||
mappedRecords[requestInfo.AuthRecord.Id].IgnoreEmailVisibility(true)
|
||||
if requestInfo.Auth != nil && mappedRecords[requestInfo.Auth.Id] != nil {
|
||||
mappedRecords[requestInfo.Auth.Id].IgnoreEmailVisibility(true)
|
||||
}
|
||||
|
||||
authOptions := collection.AuthOptions()
|
||||
if authOptions.ManageRule == nil || *authOptions.ManageRule == "" {
|
||||
if collection.ManageRule == nil || *collection.ManageRule == "" {
|
||||
return nil // no manage rule to check
|
||||
}
|
||||
|
||||
|
|
@ -251,12 +443,12 @@ func autoIgnoreAuthRecordsEmailVisibility(
|
|||
// ---
|
||||
managedIds := []string{}
|
||||
|
||||
query := dao.RecordQuery(collection).
|
||||
Select(dao.DB().QuoteSimpleColumnName(collection.Name) + ".id").
|
||||
AndWhere(dbx.In(dao.DB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
|
||||
query := app.RecordQuery(collection).
|
||||
Select(app.DB().QuoteSimpleColumnName(collection.Name) + ".id").
|
||||
AndWhere(dbx.In(app.DB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
|
||||
|
||||
resolver := resolvers.NewRecordFieldResolver(dao, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*authOptions.ManageRule).BuildExpr(resolver)
|
||||
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.ManageRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -278,30 +470,26 @@ func autoIgnoreAuthRecordsEmailVisibility(
|
|||
return nil
|
||||
}
|
||||
|
||||
// hasAuthManageAccess checks whether the client is allowed to have full
|
||||
// hasAuthManageAccess checks whether the client is allowed to have
|
||||
// [forms.RecordUpsert] auth management permissions
|
||||
// (aka. allowing to change system auth fields without oldPassword).
|
||||
func hasAuthManageAccess(
|
||||
dao *daos.Dao,
|
||||
record *models.Record,
|
||||
requestInfo *models.RequestInfo,
|
||||
) bool {
|
||||
// (e.g. allowing to change system auth fields without oldPassword).
|
||||
func hasAuthManageAccess(app core.App, requestInfo *core.RequestInfo, record *core.Record) bool {
|
||||
if !record.Collection().IsAuth() {
|
||||
return false
|
||||
}
|
||||
|
||||
manageRule := record.Collection().AuthOptions().ManageRule
|
||||
manageRule := record.Collection().ManageRule
|
||||
|
||||
if manageRule == nil || *manageRule == "" {
|
||||
return false // only for admins (manageRule can't be empty)
|
||||
return false // only for superusers (manageRule can't be empty)
|
||||
}
|
||||
|
||||
if requestInfo == nil || requestInfo.AuthRecord == nil {
|
||||
if requestInfo == nil || requestInfo.Auth == nil {
|
||||
return false // no auth record
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
resolver := resolvers.NewRecordFieldResolver(dao, record.Collection(), requestInfo, true)
|
||||
resolver := core.NewRecordFieldResolver(app, record.Collection(), requestInfo, true)
|
||||
expr, err := search.FilterData(*manageRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -311,35 +499,118 @@ func hasAuthManageAccess(
|
|||
return nil
|
||||
}
|
||||
|
||||
_, findErr := dao.FindRecordById(record.Collection().Id, record.Id, ruleFunc)
|
||||
_, findErr := app.FindRecordById(record.Collection().Id, record.Id, ruleFunc)
|
||||
|
||||
return findErr == nil
|
||||
}
|
||||
|
||||
var ruleQueryParams = []string{search.FilterQueryParam, search.SortQueryParam}
|
||||
var adminOnlyRuleFields = []string{"@collection.", "@request."}
|
||||
var superuserOnlyRuleFields = []string{"@collection.", "@request."}
|
||||
|
||||
// @todo consider moving the rules check to the RecordFieldResolver.
|
||||
//
|
||||
// checkForAdminOnlyRuleFields loosely checks and returns an error if
|
||||
// the provided RequestInfo contains rule fields that only the admin can use.
|
||||
func checkForAdminOnlyRuleFields(requestInfo *models.RequestInfo) error {
|
||||
if requestInfo.Admin != nil || len(requestInfo.Query) == 0 {
|
||||
return nil // admin or nothing to check
|
||||
// checkForSuperuserOnlyRuleFields loosely checks and returns an error if
|
||||
// the provided RequestInfo contains rule fields that only the superuser can use.
|
||||
func checkForSuperuserOnlyRuleFields(requestInfo *core.RequestInfo) error {
|
||||
if len(requestInfo.Query) == 0 || requestInfo.HasSuperuserAuth() {
|
||||
return nil // superuser or nothing to check
|
||||
}
|
||||
|
||||
for _, param := range ruleQueryParams {
|
||||
v, _ := requestInfo.Query[param].(string)
|
||||
v := requestInfo.Query[param]
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, field := range adminOnlyRuleFields {
|
||||
for _, field := range superuserOnlyRuleFields {
|
||||
if strings.Contains(v, field) {
|
||||
return NewForbiddenError("Only admins can filter by "+field, nil)
|
||||
return router.NewForbiddenError("Only superusers can filter by "+field, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// firstApiError returns the first ApiError from the errors list
|
||||
// (this is used usually to prevent unnecessary wraping and to allow bubling ApiError from nested hooks)
|
||||
//
|
||||
// If no ApiError is found, returns a default "Internal server" error.
|
||||
func firstApiError(errs ...error) *router.ApiError {
|
||||
var apiErr *router.ApiError
|
||||
var ok bool
|
||||
|
||||
for _, err := range errs {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// quick assert to avoid the reflection checks
|
||||
apiErr, ok = err.(*router.ApiError)
|
||||
if ok {
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// nested/wrapped errors
|
||||
if errors.As(err, &apiErr) {
|
||||
return apiErr
|
||||
}
|
||||
}
|
||||
|
||||
return router.NewInternalServerError("", errors.Join(errs...))
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const maxAuthOrigins = 5
|
||||
|
||||
func authAlert(e *core.RequestEvent, authRecord *core.Record) error {
|
||||
// generating fingerprint
|
||||
// ---
|
||||
userAgent := e.Request.UserAgent()
|
||||
if len(userAgent) > 300 {
|
||||
userAgent = userAgent[:300]
|
||||
}
|
||||
fingerprint := security.MD5(e.RealIP() + userAgent)
|
||||
// ---
|
||||
|
||||
origins, err := e.App.FindAllAuthOriginsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isFirstLogin := len(origins) == 0
|
||||
|
||||
var currentOrigin *core.AuthOrigin
|
||||
for _, origin := range origins {
|
||||
if origin.Fingerprint() == fingerprint {
|
||||
currentOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
if currentOrigin == nil {
|
||||
currentOrigin = core.NewAuthOrigin(e.App)
|
||||
currentOrigin.SetCollectionRef(authRecord.Collection().Id)
|
||||
currentOrigin.SetRecordRef(authRecord.Id)
|
||||
currentOrigin.SetFingerprint(fingerprint)
|
||||
}
|
||||
|
||||
// send email alert for the new origin auth (skip first login)
|
||||
if !isFirstLogin && currentOrigin.IsNew() && authRecord.Email() != "" {
|
||||
if err := mails.SendRecordAuthAlert(e.App, authRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// try to keep only up to maxAuthOrigins
|
||||
// (pop the last used ones; it is not executed in a transaction to avoid unnecessary locks)
|
||||
if currentOrigin.IsNew() && len(origins) >= maxAuthOrigins {
|
||||
for i := len(origins) - 1; i >= maxAuthOrigins-1; i-- {
|
||||
if err := e.App.Delete(origins[i]); err != nil {
|
||||
// treat as non-critical error, just log for now
|
||||
e.App.Logger().Warn("Failed to delete old AuthOrigin record", "error", err, "authOriginId", origins[i].Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create/update the origin fingerprint
|
||||
return e.App.Save(currentOrigin)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,231 +6,742 @@ import (
|
|||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRequestInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/?test=123", strings.NewReader(`{"test":456}`))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
req.Header.Set("X-Token-Test", "123")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
dummyRecord := &models.Record{}
|
||||
dummyRecord.Id = "id1"
|
||||
c.Set(apis.ContextAuthRecordKey, dummyRecord)
|
||||
|
||||
dummyAdmin := &models.Admin{}
|
||||
dummyAdmin.Id = "id2"
|
||||
c.Set(apis.ContextAdminKey, dummyAdmin)
|
||||
|
||||
result := apis.RequestInfo(c)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected *models.RequestInfo instance, got nil")
|
||||
}
|
||||
|
||||
if result.Method != http.MethodPost {
|
||||
t.Fatalf("Expected Method %v, got %v", http.MethodPost, result.Method)
|
||||
}
|
||||
|
||||
rawHeaders, _ := json.Marshal(result.Headers)
|
||||
expectedHeaders := `{"content_type":"application/json","x_token_test":"123"}`
|
||||
if v := string(rawHeaders); v != expectedHeaders {
|
||||
t.Fatalf("Expected Query %v, got %v", expectedHeaders, v)
|
||||
}
|
||||
|
||||
rawQuery, _ := json.Marshal(result.Query)
|
||||
expectedQuery := `{"test":"123"}`
|
||||
if v := string(rawQuery); v != expectedQuery {
|
||||
t.Fatalf("Expected Query %v, got %v", expectedQuery, v)
|
||||
}
|
||||
|
||||
rawData, _ := json.Marshal(result.Data)
|
||||
expectedData := `{"test":456}`
|
||||
if v := string(rawData); v != expectedData {
|
||||
t.Fatalf("Expected Data %v, got %v", expectedData, v)
|
||||
}
|
||||
|
||||
if result.AuthRecord == nil || result.AuthRecord.Id != dummyRecord.Id {
|
||||
t.Fatalf("Expected AuthRecord %v, got %v", dummyRecord, result.AuthRecord)
|
||||
}
|
||||
|
||||
if result.Admin == nil || result.Admin.Id != dummyAdmin.Id {
|
||||
t.Fatalf("Expected Admin %v, got %v", dummyAdmin, result.Admin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponse(t *testing.T) {
|
||||
func TestEnrichRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// mock test data
|
||||
// ---
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
dummyAdmin := &models.Admin{}
|
||||
dummyAdmin.Id = "id1"
|
||||
|
||||
nonAuthRecord, err := app.Dao().FindRecordById("demo1", "al1h9ijdeojtsjy")
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authRecord, err := app.Dao().FindRecordById("users", "4q1xlclmfloku33")
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
unverifiedAuthRecord, err := app.Dao().FindRecordById("clients", "o1y0dd0spd786md")
|
||||
usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// temp update the view rule to ensure that request context is set to "expand"
|
||||
demo4, err := app.FindCollectionByNameOrId("demo4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
|
||||
if err := app.Save(demo4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
record *models.Record
|
||||
meta any
|
||||
expectError bool
|
||||
expectedContent []string
|
||||
notExpectedContent []string
|
||||
expectedEvents map[string]int
|
||||
name string
|
||||
auth *core.Record
|
||||
records []*core.Record
|
||||
queryExpand string
|
||||
defaultExpands []string
|
||||
expected []string
|
||||
notExpected []string
|
||||
}{
|
||||
// email visibility checks
|
||||
{
|
||||
name: "non auth record",
|
||||
record: nonAuthRecord,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "valid auth record but with unverified email in onlyVerified collection",
|
||||
record: unverifiedAuthRecord,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "valid auth record - without meta",
|
||||
record: authRecord,
|
||||
expectError: false,
|
||||
expectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"id":"`,
|
||||
`"expand":{"rel":{`,
|
||||
name: "[emailVisibility] guest",
|
||||
auth: nil,
|
||||
records: usersRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
},
|
||||
notExpectedContent: []string{
|
||||
`"meta":`,
|
||||
},
|
||||
expectedEvents: map[string]int{
|
||||
"OnRecordAuthRequest": 1,
|
||||
notExpected: []string{
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid auth record - with meta",
|
||||
record: authRecord,
|
||||
meta: map[string]any{"meta_test": 123},
|
||||
expectError: false,
|
||||
expectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"id":"`,
|
||||
`"expand":{"rel":{`,
|
||||
`"meta":{"meta_test":123`,
|
||||
name: "[emailVisibility] owner",
|
||||
auth: user,
|
||||
records: usersRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
`"test@example.com"`, // owner
|
||||
},
|
||||
expectedEvents: map[string]int{
|
||||
"OnRecordAuthRequest": 1,
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] manager",
|
||||
auth: user,
|
||||
records: nologinRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] superuser",
|
||||
auth: superuser,
|
||||
records: nologinRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
|
||||
auth: user,
|
||||
records: demo1Records,
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"expand":{}`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
|
||||
auth: superuser,
|
||||
records: demo1Records,
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test@example.com"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
|
||||
// expand checks
|
||||
{
|
||||
name: "[expand] guest (query)",
|
||||
auth: nil,
|
||||
records: usersRecords,
|
||||
queryExpand: "rel",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] guest (default expands)",
|
||||
auth: nil,
|
||||
records: usersRecords,
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] @request.context=expand check",
|
||||
auth: nil,
|
||||
records: demo5Records,
|
||||
queryExpand: "rel_one",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{}`,
|
||||
`"expand":{"`,
|
||||
`"rel_many":[{`,
|
||||
`"rel_one":{`,
|
||||
`"id":"i9naidtvr6qsgb4"`,
|
||||
`"id":"qzaqccwrmva4o1n"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
app.ResetEventCalls()
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand=rel", nil)
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Set(apis.ContextAdminKey, dummyAdmin)
|
||||
app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
|
||||
e.Record.WithCustomData(true)
|
||||
e.Record.Set("customField", "123")
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
responseErr := apis.RecordAuthResponse(app, c, s.record, s.meta)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
hasErr := responseErr != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("[%s] Expected hasErr to be %v, got %v (%v)", s.name, s.expectError, hasErr, responseErr)
|
||||
}
|
||||
requestEvent := new(core.RequestEvent)
|
||||
requestEvent.App = app
|
||||
requestEvent.Request = req
|
||||
requestEvent.Response = rec
|
||||
requestEvent.Auth = s.auth
|
||||
|
||||
if len(app.EventCalls) != len(s.expectedEvents) {
|
||||
t.Fatalf("[%s] Expected events \n%v, \ngot \n%v", s.name, s.expectedEvents, app.EventCalls)
|
||||
}
|
||||
for k, v := range s.expectedEvents {
|
||||
if app.EventCalls[k] != v {
|
||||
t.Fatalf("[%s] Expected event %s to be called %d times, got %d", s.name, k, v, app.EventCalls[k])
|
||||
err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
response := rec.Body.String()
|
||||
|
||||
for _, v := range s.expectedContent {
|
||||
if !strings.Contains(response, v) {
|
||||
t.Fatalf("[%s] Missing %v in response \n%v", s.name, v, response)
|
||||
raw, err := json.Marshal(s.records)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
for _, v := range s.notExpectedContent {
|
||||
if strings.Contains(response, v) {
|
||||
t.Fatalf("[%s] Unexpected %v in response \n%v", s.name, v, response)
|
||||
for _, str := range s.expected {
|
||||
if !strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, str := range s.notExpected {
|
||||
if strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrichRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand=rel_many", nil)
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
dummyAdmin := &models.Admin{}
|
||||
dummyAdmin.Id = "test_id"
|
||||
c.Set(apis.ContextAdminKey, dummyAdmin)
|
||||
|
||||
func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
records, err := app.Dao().FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
apis.EnrichRecords(c, app.Dao(), records, "rel_one")
|
||||
scenarios := []struct {
|
||||
name string
|
||||
rule *string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"admin only rule",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty rule",
|
||||
types.Pointer(""),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"false rule",
|
||||
types.Pointer("1=2"),
|
||||
true,
|
||||
},
|
||||
{
|
||||
"true rule",
|
||||
types.Pointer("1=1"),
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
expand := record.Expand()
|
||||
if len(expand) == 0 {
|
||||
t.Fatalf("Expected non-empty expand, got nil for record %v", record)
|
||||
}
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
user.Collection().AuthRule = s.rule
|
||||
|
||||
if len(record.GetStringSlice("rel_one")) != 0 {
|
||||
if _, ok := expand["rel_one"]; !ok {
|
||||
t.Fatalf("Expected rel_one to be expanded for record %v, got \n%v", record, expand)
|
||||
err := apis.RecordAuthResponse(event, user, "", nil)
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(record.GetStringSlice("rel_many")) != 0 {
|
||||
if _, ok := expand["rel_many"]; !ok {
|
||||
t.Fatalf("Expected rel_many to be expanded for record %v, got \n%v", record, expand)
|
||||
// in all cases login alert shouldn't be send because of the empty auth method
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
}
|
||||
|
||||
if !hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
apiErr, ok := err.(*router.ApiError)
|
||||
|
||||
if !ok || apiErr == nil {
|
||||
t.Fatalf("Expected ApiError, got %v", apiErr)
|
||||
}
|
||||
|
||||
if apiErr.Status != http.StatusForbidden {
|
||||
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
|
||||
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
devices []string // mock existing device fingerprints
|
||||
expectDevices []string
|
||||
enabled bool
|
||||
expectEmail bool
|
||||
}{
|
||||
{
|
||||
name: "first login",
|
||||
devices: nil,
|
||||
expectDevices: []string{testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "existing device",
|
||||
devices: []string{"1", testFingerprint},
|
||||
expectDevices: []string{"1", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "new device (< 5)",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "new device (>= 5)",
|
||||
devices: []string{"1", "2", "3", "4", "5"},
|
||||
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "with disabled auth alert collection flag",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2"},
|
||||
enabled: false,
|
||||
expectEmail: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
user.Collection().AuthRule = types.Pointer("")
|
||||
user.Collection().AuthAlert.Enabled = s.enabled
|
||||
|
||||
// ensure that there are no other auth origins
|
||||
err = app.DeleteAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert the mock devices
|
||||
for _, fingerprint := range s.devices {
|
||||
d := core.NewAuthOrigin(app)
|
||||
d.SetCollectionRef(user.Collection().Id)
|
||||
d.SetRecordRef(user.Id)
|
||||
d.SetFingerprint(fingerprint)
|
||||
if err = app.Save(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve auth response: %v", err)
|
||||
}
|
||||
|
||||
var expectTotalSend int
|
||||
if s.expectEmail {
|
||||
expectTotalSend = 1
|
||||
}
|
||||
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
|
||||
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
|
||||
}
|
||||
|
||||
devices, err := app.FindAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve auth origins: %v", err)
|
||||
}
|
||||
|
||||
if len(devices) != len(s.expectDevices) {
|
||||
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
|
||||
}
|
||||
|
||||
for _, fingerprint := range s.expectDevices {
|
||||
var exists bool
|
||||
fingerprints := make([]string, 0, len(devices))
|
||||
for _, d := range devices {
|
||||
if d.Fingerprint() == fingerprint {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
fingerprints = append(fingerprints, d.Fingerprint())
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseMFACheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = rec
|
||||
|
||||
resetMFAs := func(authRecord *core.Record) {
|
||||
// ensure that mfa is enabled
|
||||
user.Collection().MFA.Enabled = true
|
||||
user.Collection().MFA.Duration = 5
|
||||
user.Collection().MFA.Rule = ""
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
for _, mfa := range mfas {
|
||||
if err := app.Delete(mfa); err != nil {
|
||||
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// reset response
|
||||
rec = httptest.NewRecorder()
|
||||
event.Response = rec
|
||||
}
|
||||
|
||||
totalMFAs := func(authRecord *core.Record) int {
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
return len(mfas)
|
||||
}
|
||||
|
||||
t.Run("no collection MFA enabled", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no explicit auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=2"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=1"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa first-time", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
|
||||
event.Request.Header.Add("content-type", "application/json")
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected 0 mfa records, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
|
||||
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if totalMFAs(user) != 0 {
|
||||
t.Fatal("Expected the expired mfa record to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa for different auth record", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user2.Collection().Id)
|
||||
mfa.SetRecordRef(user2.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no user mfas, got %d", total)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user2); total != 1 {
|
||||
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
269
apis/serve.go
269
apis/serve.go
|
|
@ -3,6 +3,7 @@ package apis
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
|
@ -12,14 +13,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/migrations"
|
||||
"github.com/pocketbase/pocketbase/migrations/logs"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/migrate"
|
||||
"github.com/pocketbase/pocketbase/ui"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
|
@ -29,10 +26,16 @@ type ServeConfig struct {
|
|||
// ShowStartBanner indicates whether to show or hide the server start console message.
|
||||
ShowStartBanner bool
|
||||
|
||||
// HttpAddr is the TCP address to listen for the HTTP server (eg. `127.0.0.1:80`).
|
||||
// DashboardPath specifies the route path to the superusers dashboard interface
|
||||
// (default to "/_/{path...}").
|
||||
//
|
||||
// Note: Must include the "{path...}" wildcard parameter.
|
||||
DashboardPath string
|
||||
|
||||
// HttpAddr is the TCP address to listen for the HTTP server (eg. "127.0.0.1:80").
|
||||
HttpAddr string
|
||||
|
||||
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. `127.0.0.1:443`).
|
||||
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. "127.0.0.1:443").
|
||||
HttpsAddr string
|
||||
|
||||
// Optional domains list to use when issuing the TLS certificate.
|
||||
|
|
@ -58,36 +61,43 @@ type ServeConfig struct {
|
|||
// HttpAddr: "127.0.0.1:8080",
|
||||
// ShowStartBanner: false,
|
||||
// })
|
||||
func Serve(app core.App, config ServeConfig) (*http.Server, error) {
|
||||
func Serve(app core.App, config ServeConfig) error {
|
||||
if len(config.AllowedOrigins) == 0 {
|
||||
config.AllowedOrigins = []string{"*"}
|
||||
}
|
||||
|
||||
if config.DashboardPath == "" {
|
||||
config.DashboardPath = "/_/{path...}"
|
||||
} else if !strings.HasSuffix(config.DashboardPath, "{path...}") {
|
||||
return errors.New("invalid dashboard path - missing {path...} wildcard")
|
||||
}
|
||||
|
||||
// ensure that the latest migrations are applied before starting the server
|
||||
if err := runMigrations(app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// reload app settings in case a new default value was set with a migration
|
||||
// (or if this is the first time the init migration was executed)
|
||||
if err := app.RefreshSettings(); err != nil {
|
||||
color.Yellow("=====================================")
|
||||
color.Yellow("WARNING: Settings load error! \n%v", err)
|
||||
color.Yellow("Fallback to the application defaults.")
|
||||
color.Yellow("=====================================")
|
||||
}
|
||||
|
||||
router, err := InitApi(app)
|
||||
err := app.RunAllMigrations()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// configure cors
|
||||
router.Use(middleware.CORSWithConfig(middleware.CORSConfig{
|
||||
Skipper: middleware.DefaultSkipper,
|
||||
AllowOrigins: config.AllowedOrigins,
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}))
|
||||
pbRouter, err := NewRouter(app)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pbRouter.Bind(&hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultCorsMiddlewareId,
|
||||
Func: CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: config.AllowedOrigins,
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}),
|
||||
Priority: DefaultCorsMiddlewarePriority,
|
||||
})
|
||||
|
||||
pbRouter.BindFunc(installerRedirect(app, config.DashboardPath))
|
||||
|
||||
pbRouter.GET(config.DashboardPath, Static(ui.DistDirFS, false)).
|
||||
BindFunc(dashboardRemoveInstallerParam()).
|
||||
BindFunc(dashboardCacheControl()).
|
||||
BindFunc(Gzip())
|
||||
|
||||
// start http server
|
||||
// ---
|
||||
|
|
@ -118,25 +128,12 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
|
|||
|
||||
// implicit www->non-www redirect(s)
|
||||
if len(wwwRedirects) > 0 {
|
||||
router.Pre(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
host := c.Request().Host
|
||||
|
||||
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, wwwRedirects) {
|
||||
return c.Redirect(
|
||||
http.StatusTemporaryRedirect,
|
||||
(c.Scheme() + "://" + host[4:] + c.Request().RequestURI),
|
||||
)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
pbRouter.Bind(wwwRedirect(wwwRedirects))
|
||||
}
|
||||
|
||||
certManager := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(filepath.Join(app.DataDir(), ".autocert_cache")),
|
||||
Cache: autocert.DirCache(filepath.Join(app.DataDir(), core.LocalAutocertCacheDirName)),
|
||||
HostPolicy: autocert.HostWhitelist(hostNames...),
|
||||
}
|
||||
|
||||
|
|
@ -151,24 +148,96 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
|
|||
GetCertificate: certManager.GetCertificate,
|
||||
NextProtos: []string{acme.ALPNProto},
|
||||
},
|
||||
ReadTimeout: 10 * time.Minute,
|
||||
// higher defaults to accommodate large file uploads/downloads
|
||||
WriteTimeout: 3 * time.Minute,
|
||||
ReadTimeout: 3 * time.Minute,
|
||||
ReadHeaderTimeout: 30 * time.Second,
|
||||
// WriteTimeout: 60 * time.Second, // breaks sse!
|
||||
Handler: router,
|
||||
Addr: mainAddr,
|
||||
Addr: mainAddr,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return baseCtx
|
||||
},
|
||||
ErrorLog: log.New(&serverErrorLogWriter{app: app}, "", 0),
|
||||
}
|
||||
|
||||
serveEvent := &core.ServeEvent{
|
||||
App: app,
|
||||
Router: router,
|
||||
Server: server,
|
||||
CertManager: certManager,
|
||||
}
|
||||
if err := app.OnBeforeServe().Trigger(serveEvent); err != nil {
|
||||
return nil, err
|
||||
serveEvent := new(core.ServeEvent)
|
||||
serveEvent.App = app
|
||||
serveEvent.Router = pbRouter
|
||||
serveEvent.Server = server
|
||||
serveEvent.CertManager = certManager
|
||||
|
||||
var listener net.Listener
|
||||
|
||||
// graceful shutdown
|
||||
// ---------------------------------------------------------------
|
||||
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
|
||||
// Note that the WaitGroup would do nothing if the app.OnTerminate() hook isn't triggered.
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// try to gracefully shutdown the server on app termination
|
||||
app.OnTerminate().Bind(&hook.Handler[*core.TerminateEvent]{
|
||||
Id: "pbGracefulShutdown",
|
||||
Func: func(te *core.TerminateEvent) error {
|
||||
cancelBaseCtx()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
_ = server.Shutdown(ctx)
|
||||
|
||||
if te.IsRestart {
|
||||
// wait for execve and other handlers up to 3 seconds before exit
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
wg.Done()
|
||||
})
|
||||
} else {
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
return te.Next()
|
||||
},
|
||||
Priority: -9999,
|
||||
})
|
||||
|
||||
// wait for the graceful shutdown to complete before exit
|
||||
defer func() {
|
||||
wg.Wait()
|
||||
|
||||
if listener != nil {
|
||||
_ = listener.Close()
|
||||
}
|
||||
}()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// trigger the OnServe hook and start the tcp listener
|
||||
serveHookErr := app.OnServe().Trigger(serveEvent, func(e *core.ServeEvent) error {
|
||||
handler, err := e.Router.BuildMux()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Server.Handler = handler
|
||||
|
||||
addr := e.Server.Addr
|
||||
|
||||
// fallback similar to the std Server.ListenAndServe/ListenAndServeTLS
|
||||
if addr == "" {
|
||||
if config.HttpsAddr != "" {
|
||||
addr = ":https"
|
||||
} else {
|
||||
addr = ":http"
|
||||
}
|
||||
}
|
||||
|
||||
var lnErr error
|
||||
|
||||
listener, lnErr = net.Listen("tcp", addr)
|
||||
|
||||
return lnErr
|
||||
})
|
||||
if serveHookErr != nil {
|
||||
return serveHookErr
|
||||
}
|
||||
|
||||
if config.ShowStartBanner {
|
||||
|
|
@ -198,80 +267,32 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
|
|||
regular.Printf("└─ Admin UI: %s\n", color.CyanString("%s://%s/_/", schema, addr))
|
||||
}
|
||||
|
||||
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
|
||||
// Note that the WaitGroup would not do anything if the app.OnTerminate() hook isn't triggered.
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// try to gracefully shutdown the server on app termination
|
||||
app.OnTerminate().Add(func(e *core.TerminateEvent) error {
|
||||
cancelBaseCtx()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
server.Shutdown(ctx)
|
||||
if e.IsRestart {
|
||||
// wait for execve and other handlers up to 5 seconds before exit
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
wg.Done()
|
||||
})
|
||||
} else {
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// wait for the graceful shutdown to complete before exit
|
||||
defer wg.Wait()
|
||||
|
||||
// ---
|
||||
// @todo consider removing the server return value because it is
|
||||
// not really useful when combined with the blocking serve calls
|
||||
// ---
|
||||
|
||||
// start HTTPS server
|
||||
var serveErr error
|
||||
if config.HttpsAddr != "" {
|
||||
// if httpAddr is set, start an HTTP server to redirect the traffic to the HTTPS version
|
||||
if config.HttpAddr != "" {
|
||||
// start an additional HTTP server for redirecting the traffic to the HTTPS version
|
||||
go http.ListenAndServe(config.HttpAddr, certManager.HTTPHandler(nil))
|
||||
}
|
||||
|
||||
return server, server.ListenAndServeTLS("", "")
|
||||
// start HTTPS server
|
||||
serveErr = server.ServeTLS(listener, "", "")
|
||||
} else {
|
||||
// OR start HTTP server
|
||||
serveErr = server.Serve(listener)
|
||||
}
|
||||
|
||||
// OR start HTTP server
|
||||
return server, server.ListenAndServe()
|
||||
}
|
||||
|
||||
type migrationsConnection struct {
|
||||
DB *dbx.DB
|
||||
MigrationsList migrate.MigrationsList
|
||||
}
|
||||
|
||||
func runMigrations(app core.App) error {
|
||||
connections := []migrationsConnection{
|
||||
{
|
||||
DB: app.DB(),
|
||||
MigrationsList: migrations.AppMigrations,
|
||||
},
|
||||
{
|
||||
DB: app.LogsDB(),
|
||||
MigrationsList: logs.LogsMigrations,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range connections {
|
||||
runner, err := migrate.NewRunner(c.DB, c.MigrationsList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := runner.Up(); err != nil {
|
||||
return err
|
||||
}
|
||||
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
return serveErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type serverErrorLogWriter struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (s *serverErrorLogWriter) Write(p []byte) (int, error) {
|
||||
s.app.Logger().Debug(strings.TrimSpace(string(p)))
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
|
|
|||
141
apis/settings.go
141
apis/settings.go
|
|
@ -4,136 +4,121 @@ import (
|
|||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/models/settings"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindSettingsApi registers the settings api endpoints.
|
||||
func bindSettingsApi(app core.App, rg *echo.Group) {
|
||||
api := settingsApi{app: app}
|
||||
|
||||
subGroup := rg.Group("/settings", ActivityLogger(app), RequireAdminAuth())
|
||||
subGroup.GET("", api.list)
|
||||
subGroup.PATCH("", api.set)
|
||||
subGroup.POST("/test/s3", api.testS3)
|
||||
subGroup.POST("/test/email", api.testEmail)
|
||||
subGroup.POST("/apple/generate-client-secret", api.generateAppleClientSecret)
|
||||
func bindSettingsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/settings").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", settingsList)
|
||||
subGroup.PATCH("", settingsSet)
|
||||
subGroup.POST("/test/s3", settingsTestS3)
|
||||
subGroup.POST("/test/email", settingsTestEmail)
|
||||
subGroup.POST("/apple/generate-client-secret", settingsGenerateAppleClientSecret)
|
||||
}
|
||||
|
||||
type settingsApi struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (api *settingsApi) list(c echo.Context) error {
|
||||
settings, err := api.app.Settings().RedactClone()
|
||||
func settingsList(e *core.RequestEvent) error {
|
||||
clone, err := e.App.Settings().Clone()
|
||||
if err != nil {
|
||||
return NewBadRequestError("", err)
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
|
||||
event := new(core.SettingsListEvent)
|
||||
event.HttpContext = c
|
||||
event.RedactedSettings = settings
|
||||
event := new(core.SettingsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Settings = clone
|
||||
|
||||
return api.app.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, e.RedactedSettings)
|
||||
return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error {
|
||||
return e.JSON(http.StatusOK, e.Settings)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *settingsApi) set(c echo.Context) error {
|
||||
form := forms.NewSettingsUpsert(api.app)
|
||||
func settingsSet(e *core.RequestEvent) error {
|
||||
event := new(core.SettingsUpdateRequestEvent)
|
||||
event.RequestEvent = e
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
if clone, err := e.App.Settings().Clone(); err == nil {
|
||||
event.OldSettings = clone
|
||||
} else {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
event := new(core.SettingsUpdateEvent)
|
||||
event.HttpContext = c
|
||||
event.OldSettings = api.app.Settings()
|
||||
if clone, err := e.App.Settings().Clone(); err == nil {
|
||||
event.NewSettings = clone
|
||||
} else {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
// update the settings
|
||||
return form.Submit(func(next forms.InterceptorNextFunc[*settings.Settings]) forms.InterceptorNextFunc[*settings.Settings] {
|
||||
return func(s *settings.Settings) error {
|
||||
event.NewSettings = s
|
||||
if err := e.BindBody(&event.NewSettings); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
return api.app.OnSettingsBeforeUpdateRequest().Trigger(event, func(e *core.SettingsUpdateEvent) error {
|
||||
if err := next(e.NewSettings); err != nil {
|
||||
return NewBadRequestError("An error occurred while submitting the form.", err)
|
||||
}
|
||||
|
||||
return api.app.OnSettingsAfterUpdateRequest().Trigger(event, func(e *core.SettingsUpdateEvent) error {
|
||||
if e.HttpContext.Response().Committed {
|
||||
return nil
|
||||
}
|
||||
|
||||
redactedSettings, err := api.app.Settings().RedactClone()
|
||||
if err != nil {
|
||||
return NewBadRequestError("", err)
|
||||
}
|
||||
|
||||
return e.HttpContext.JSON(http.StatusOK, redactedSettings)
|
||||
})
|
||||
})
|
||||
return e.App.OnSettingsUpdateRequest().Trigger(event, func(e *core.SettingsUpdateRequestEvent) error {
|
||||
err := e.App.Save(e.NewSettings)
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while saving the new settings.", err)
|
||||
}
|
||||
|
||||
appSettings, err := e.App.Settings().Clone()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to clone app settings.", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, appSettings)
|
||||
})
|
||||
}
|
||||
|
||||
func (api *settingsApi) testS3(c echo.Context) error {
|
||||
form := forms.NewTestS3Filesystem(api.app)
|
||||
func settingsTestS3(e *core.RequestEvent) error {
|
||||
form := forms.NewTestS3Filesystem(e.App)
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// send
|
||||
if err := form.Submit(); err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return NewBadRequestError("Failed to test the S3 filesystem.", fErr)
|
||||
return e.BadRequestError("Failed to test the S3 filesystem.", fErr)
|
||||
}
|
||||
|
||||
// mailer error
|
||||
return NewBadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
|
||||
return e.BadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *settingsApi) testEmail(c echo.Context) error {
|
||||
form := forms.NewTestEmailSend(api.app)
|
||||
func settingsTestEmail(e *core.RequestEvent) error {
|
||||
form := forms.NewTestEmailSend(e.App)
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// send
|
||||
if err := form.Submit(); err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return NewBadRequestError("Failed to send the test email.", fErr)
|
||||
return e.BadRequestError("Failed to send the test email.", fErr)
|
||||
}
|
||||
|
||||
// mailer error
|
||||
return NewBadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
|
||||
return e.BadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *settingsApi) generateAppleClientSecret(c echo.Context) error {
|
||||
form := forms.NewAppleClientSecretCreate(api.app)
|
||||
func settingsGenerateAppleClientSecret(e *core.RequestEvent) error {
|
||||
form := forms.NewAppleClientSecretCreate(e.App)
|
||||
|
||||
// load request
|
||||
if err := c.Bind(form); err != nil {
|
||||
return NewBadRequestError("An error occurred while loading the submitted data.", err)
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// generate
|
||||
|
|
@ -141,14 +126,14 @@ func (api *settingsApi) generateAppleClientSecret(c echo.Context) error {
|
|||
if err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return NewBadRequestError("Invalid client secret data.", fErr)
|
||||
return e.BadRequestError("Invalid client secret data.", fErr)
|
||||
}
|
||||
|
||||
// secret generation error
|
||||
return NewBadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
|
||||
return e.BadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
return e.JSON(http.StatusOK, map[string]string{
|
||||
"secret": secret,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,14 +6,11 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
|
|
@ -24,26 +21,28 @@ func TestSettingsList(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/settings",
|
||||
URL: "/api/settings",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/settings",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin",
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/settings",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
|
|
@ -52,44 +51,10 @@ func TestSettingsList(t *testing.T) {
|
|||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"adminAuthToken":{`,
|
||||
`"adminPasswordResetToken":{`,
|
||||
`"adminFileToken":{`,
|
||||
`"recordAuthToken":{`,
|
||||
`"recordPasswordResetToken":{`,
|
||||
`"recordEmailChangeToken":{`,
|
||||
`"recordVerificationToken":{`,
|
||||
`"recordFileToken":{`,
|
||||
`"emailAuth":{`,
|
||||
`"googleAuth":{`,
|
||||
`"facebookAuth":{`,
|
||||
`"githubAuth":{`,
|
||||
`"gitlabAuth":{`,
|
||||
`"twitterAuth":{`,
|
||||
`"discordAuth":{`,
|
||||
`"microsoftAuth":{`,
|
||||
`"spotifyAuth":{`,
|
||||
`"kakaoAuth":{`,
|
||||
`"twitchAuth":{`,
|
||||
`"stravaAuth":{`,
|
||||
`"giteeAuth":{`,
|
||||
`"livechatAuth":{`,
|
||||
`"giteaAuth":{`,
|
||||
`"oidcAuth":{`,
|
||||
`"oidc2Auth":{`,
|
||||
`"oidc3Auth":{`,
|
||||
`"appleAuth":{`,
|
||||
`"instagramAuth":{`,
|
||||
`"vkAuth":{`,
|
||||
`"yandexAuth":{`,
|
||||
`"patreonAuth":{`,
|
||||
`"mailcowAuth":{`,
|
||||
`"bitbucketAuth":{`,
|
||||
`"planningcenterAuth":{`,
|
||||
`"secret":"******"`,
|
||||
`"clientSecret":"******"`,
|
||||
`"batch":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnSettingsListRequest": 1,
|
||||
},
|
||||
},
|
||||
|
|
@ -103,35 +68,41 @@ func TestSettingsList(t *testing.T) {
|
|||
func TestSettingsSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validData := `{"meta":{"appName":"update_test"}}`
|
||||
validData := `{
|
||||
"meta":{"appName":"update_test"},
|
||||
"s3":{"secret": "s3_secret"},
|
||||
"backups":{"s3":{"secret":"backups_s3_secret"}}
|
||||
}`
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/settings",
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/settings",
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin submitting empty data",
|
||||
Name: "authorized as superuser submitting empty data",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/settings",
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(``),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
|
|
@ -140,71 +111,46 @@ func TestSettingsSet(t *testing.T) {
|
|||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"adminAuthToken":{`,
|
||||
`"adminPasswordResetToken":{`,
|
||||
`"adminFileToken":{`,
|
||||
`"recordAuthToken":{`,
|
||||
`"recordPasswordResetToken":{`,
|
||||
`"recordEmailChangeToken":{`,
|
||||
`"recordVerificationToken":{`,
|
||||
`"recordFileToken":{`,
|
||||
`"emailAuth":{`,
|
||||
`"googleAuth":{`,
|
||||
`"facebookAuth":{`,
|
||||
`"githubAuth":{`,
|
||||
`"gitlabAuth":{`,
|
||||
`"discordAuth":{`,
|
||||
`"microsoftAuth":{`,
|
||||
`"spotifyAuth":{`,
|
||||
`"kakaoAuth":{`,
|
||||
`"twitchAuth":{`,
|
||||
`"stravaAuth":{`,
|
||||
`"giteeAuth":{`,
|
||||
`"livechatAuth":{`,
|
||||
`"giteaAuth":{`,
|
||||
`"oidcAuth":{`,
|
||||
`"oidc2Auth":{`,
|
||||
`"oidc3Auth":{`,
|
||||
`"appleAuth":{`,
|
||||
`"instagramAuth":{`,
|
||||
`"vkAuth":{`,
|
||||
`"yandexAuth":{`,
|
||||
`"patreonAuth":{`,
|
||||
`"mailcowAuth":{`,
|
||||
`"bitbucketAuth":{`,
|
||||
`"planningcenterAuth":{`,
|
||||
`"secret":"******"`,
|
||||
`"clientSecret":"******"`,
|
||||
`"appName":"acme_test"`,
|
||||
`"batch":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnSettingsBeforeUpdateRequest": 1,
|
||||
"OnSettingsAfterUpdateRequest": 1,
|
||||
"*": 0,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsReload": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin submitting invalid data",
|
||||
Name: "authorized as superuser submitting invalid data",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/settings",
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(`{"meta":{"appName":""}}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"meta":{"appName":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelAfterUpdateError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin submitting valid data",
|
||||
Name: "authorized as superuser submitting valid data",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/settings",
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
|
|
@ -213,71 +159,21 @@ func TestSettingsSet(t *testing.T) {
|
|||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"adminAuthToken":{`,
|
||||
`"adminPasswordResetToken":{`,
|
||||
`"adminFileToken":{`,
|
||||
`"recordAuthToken":{`,
|
||||
`"recordPasswordResetToken":{`,
|
||||
`"recordEmailChangeToken":{`,
|
||||
`"recordVerificationToken":{`,
|
||||
`"recordFileToken":{`,
|
||||
`"emailAuth":{`,
|
||||
`"googleAuth":{`,
|
||||
`"facebookAuth":{`,
|
||||
`"githubAuth":{`,
|
||||
`"gitlabAuth":{`,
|
||||
`"twitterAuth":{`,
|
||||
`"discordAuth":{`,
|
||||
`"microsoftAuth":{`,
|
||||
`"spotifyAuth":{`,
|
||||
`"kakaoAuth":{`,
|
||||
`"twitchAuth":{`,
|
||||
`"stravaAuth":{`,
|
||||
`"giteeAuth":{`,
|
||||
`"livechatAuth":{`,
|
||||
`"giteaAuth":{`,
|
||||
`"oidcAuth":{`,
|
||||
`"oidc2Auth":{`,
|
||||
`"oidc3Auth":{`,
|
||||
`"appleAuth":{`,
|
||||
`"instagramAuth":{`,
|
||||
`"vkAuth":{`,
|
||||
`"yandexAuth":{`,
|
||||
`"patreonAuth":{`,
|
||||
`"mailcowAuth":{`,
|
||||
`"bitbucketAuth":{`,
|
||||
`"planningcenterAuth":{`,
|
||||
`"secret":"******"`,
|
||||
`"clientSecret":"******"`,
|
||||
`"batch":{`,
|
||||
`"appName":"update_test"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"secret",
|
||||
"password",
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnSettingsBeforeUpdateRequest": 1,
|
||||
"OnSettingsAfterUpdateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnSettingsAfterUpdateRequest error response",
|
||||
Method: http.MethodPatch,
|
||||
Url: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
},
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.OnSettingsAfterUpdateRequest().Add(func(e *core.SettingsUpdateEvent) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnModelBeforeUpdate": 1,
|
||||
"OnModelAfterUpdate": 1,
|
||||
"OnSettingsBeforeUpdateRequest": 1,
|
||||
"OnSettingsAfterUpdateRequest": 1,
|
||||
"*": 0,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsReload": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -294,59 +190,64 @@ func TestSettingsTestS3(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/s3",
|
||||
URL: "/api/settings/test/s3",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/s3",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/settings/test/s3",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (missing body + no s3)",
|
||||
Name: "authorized as superuser (missing body + no s3)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/s3",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
URL: "/api/settings/test/s3",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"filesystem":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (invalid filesystem)",
|
||||
Name: "authorized as superuser (invalid filesystem)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/s3",
|
||||
URL: "/api/settings/test/s3",
|
||||
Body: strings.NewReader(`{"filesystem":"invalid"}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"filesystem":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (valid filesystem and no s3)",
|
||||
Name: "authorized as superuser (valid filesystem and no s3)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/s3",
|
||||
URL: "/api/settings/test/s3",
|
||||
Body: strings.NewReader(`{"filesystem":"storage"}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -362,156 +263,199 @@ func TestSettingsTestEmail(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/email",
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/email",
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (invalid body)",
|
||||
Name: "authorized as superuser (invalid body)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/email",
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (empty json)",
|
||||
Name: "authorized as superuser (empty json)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/email",
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"email":{"code":"validation_required"`,
|
||||
`"template":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (verifiation template)",
|
||||
Name: "authorized as superuser (verifiation template)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/email",
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend != 1 {
|
||||
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend)
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage.To) != 1 {
|
||||
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To)
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address)
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Verify") {
|
||||
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML)
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Verify") {
|
||||
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnMailerBeforeRecordVerificationSend": 1,
|
||||
"OnMailerAfterRecordVerificationSend": 1,
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordVerificationSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (password reset template)",
|
||||
Name: "authorized as superuser (password reset template)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/email",
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "password-reset",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend)
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage.To) != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To)
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address)
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Reset password") {
|
||||
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML)
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Reset password") {
|
||||
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnMailerBeforeRecordResetPasswordSend": 1,
|
||||
"OnMailerAfterRecordResetPasswordSend": 1,
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordPasswordResetSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (email change)",
|
||||
Name: "authorized as superuser (email change)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/test/email",
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "email-change",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend)
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage.To) != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To)
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address)
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Confirm new email") {
|
||||
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML)
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Confirm new email") {
|
||||
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"OnMailerBeforeRecordChangeEmailSend": 1,
|
||||
"OnMailerAfterRecordChangeEmailSend": 1,
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordEmailChangeSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (otp)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "otp",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[otp] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[otp] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[otp] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "one-time password") {
|
||||
t.Fatalf("[otp] Expected to sent OTP email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -545,38 +489,41 @@ func TestGenerateAppleClientSecret(t *testing.T) {
|
|||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/apple/generate-client-secret",
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as auth record",
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/apple/generate-client-secret",
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (invalid body)",
|
||||
Name: "authorized as superuser (invalid body)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/apple/generate-client-secret",
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (empty json)",
|
||||
Name: "authorized as superuser (empty json)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/apple/generate-client-secret",
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
|
|
@ -586,11 +533,12 @@ func TestGenerateAppleClientSecret(t *testing.T) {
|
|||
`"privateKey":{"code":"validation_required"`,
|
||||
`"duration":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (invalid data)",
|
||||
Name: "authorized as superuser (invalid data)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/apple/generate-client-secret",
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "",
|
||||
"teamId": "123456789",
|
||||
|
|
@ -598,8 +546,8 @@ func TestGenerateAppleClientSecret(t *testing.T) {
|
|||
"privateKey": "invalid",
|
||||
"duration": -1
|
||||
}`),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
|
|
@ -609,11 +557,12 @@ func TestGenerateAppleClientSecret(t *testing.T) {
|
|||
`"privateKey":{"code":"validation_match_invalid"`,
|
||||
`"duration":{"code":"validation_min_greater_equal_than_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as admin (valid data)",
|
||||
Name: "authorized as superuser (valid data)",
|
||||
Method: http.MethodPost,
|
||||
Url: "/api/settings/apple/generate-client-secret",
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(fmt.Sprintf(`{
|
||||
"clientId": "123",
|
||||
"teamId": "1234567890",
|
||||
|
|
@ -621,13 +570,14 @@ func TestGenerateAppleClientSecret(t *testing.T) {
|
|||
"privateKey": %q,
|
||||
"duration": 1
|
||||
}`, privatePem)),
|
||||
RequestHeaders: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"secret":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
141
cmd/admin.go
141
cmd/admin.go
|
|
@ -1,141 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// NewAdminCommand creates and returns new command for managing
|
||||
// admin accounts (create, update, delete).
|
||||
func NewAdminCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "admin",
|
||||
Short: "Manages admin accounts",
|
||||
}
|
||||
|
||||
command.AddCommand(adminCreateCommand(app))
|
||||
command.AddCommand(adminUpdateCommand(app))
|
||||
command.AddCommand(adminDeleteCommand(app))
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func adminCreateCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "create",
|
||||
Example: "admin create test@example.com 1234567890",
|
||||
Short: "Creates a new admin account",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("Missing email and password arguments.")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("Missing or invalid email address.")
|
||||
}
|
||||
|
||||
if len(args[1]) < 8 {
|
||||
return errors.New("The password must be at least 8 chars long.")
|
||||
}
|
||||
|
||||
admin := &models.Admin{}
|
||||
admin.Email = args[0]
|
||||
admin.SetPassword(args[1])
|
||||
|
||||
if !app.Dao().HasTable(admin.TableName()) {
|
||||
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
|
||||
}
|
||||
|
||||
if err := app.Dao().SaveAdmin(admin); err != nil {
|
||||
return fmt.Errorf("Failed to create new admin account: %v", err)
|
||||
}
|
||||
|
||||
color.Green("Successfully created new admin %s!", admin.Email)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func adminUpdateCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "update",
|
||||
Example: "admin update test@example.com 1234567890",
|
||||
Short: "Changes the password of a single admin account",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("Missing email and password arguments.")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("Missing or invalid email address.")
|
||||
}
|
||||
|
||||
if len(args[1]) < 8 {
|
||||
return errors.New("The new password must be at least 8 chars long.")
|
||||
}
|
||||
|
||||
if !app.Dao().HasTable((&models.Admin{}).TableName()) {
|
||||
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
|
||||
}
|
||||
|
||||
admin, err := app.Dao().FindAdminByEmail(args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Admin with email %s doesn't exist.", args[0])
|
||||
}
|
||||
|
||||
admin.SetPassword(args[1])
|
||||
|
||||
if err := app.Dao().SaveAdmin(admin); err != nil {
|
||||
return fmt.Errorf("Failed to change admin %s password: %v", admin.Email, err)
|
||||
}
|
||||
|
||||
color.Green("Successfully changed admin %s password!", admin.Email)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func adminDeleteCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "delete",
|
||||
Example: "admin delete test@example.com",
|
||||
Short: "Deletes an existing admin account",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("Invalid or missing email address.")
|
||||
}
|
||||
|
||||
if !app.Dao().HasTable((&models.Admin{}).TableName()) {
|
||||
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
|
||||
}
|
||||
|
||||
admin, err := app.Dao().FindAdminByEmail(args[0])
|
||||
if err != nil {
|
||||
color.Yellow("Admin %s is already deleted.", args[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := app.Dao().DeleteAdmin(admin); err != nil {
|
||||
return fmt.Errorf("Failed to delete admin %s: %v", admin.Email, err)
|
||||
}
|
||||
|
||||
color.Green("Successfully deleted admin %s!", admin.Email)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
|
@ -1,221 +0,0 @@
|
|||
package cmd_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/cmd"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestAdminCreateCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"duplicated email",
|
||||
"test@example.com",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid email and password",
|
||||
"test_new@example.com",
|
||||
"12345678",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
command := cmd.NewAdminCommand(app)
|
||||
command.SetArgs([]string{"create", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
// check whether the admin account was actually created
|
||||
admin, err := app.Dao().FindAdminByEmail(s.email)
|
||||
if err != nil {
|
||||
t.Errorf("[%s] Failed to fetch created admin %s: %v", s.name, s.email, err)
|
||||
} else if !admin.ValidatePassword(s.password) {
|
||||
t.Errorf("[%s] Expected the admin password to match", s.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminUpdateCommand(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"nonexisting admin",
|
||||
"test_missing@example.com",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid email and password",
|
||||
"test@example.com",
|
||||
"12345678",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
command := cmd.NewAdminCommand(app)
|
||||
command.SetArgs([]string{"update", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
// check whether the admin password was actually changed
|
||||
admin, err := app.Dao().FindAdminByEmail(s.email)
|
||||
if err != nil {
|
||||
t.Errorf("[%s] Failed to fetch admin %s: %v", s.name, s.email, err)
|
||||
} else if !admin.ValidatePassword(s.password) {
|
||||
t.Errorf("[%s] Expected the admin password to match", s.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminDeleteCommand(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"nonexisting admin",
|
||||
"test_missing@example.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"existing admin",
|
||||
"test@example.com",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
command := cmd.NewAdminCommand(app)
|
||||
command.SetArgs([]string{"delete", s.email})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
// check whether the admin account was actually deleted
|
||||
if _, err := app.Dao().FindAdminByEmail(s.email); err == nil {
|
||||
t.Errorf("[%s] Expected the admin account to be deleted", s.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
11
cmd/serve.go
11
cmd/serve.go
|
|
@ -15,6 +15,7 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
|
|||
var allowedOrigins []string
|
||||
var httpAddr string
|
||||
var httpsAddr string
|
||||
var dashboardPath string
|
||||
|
||||
command := &cobra.Command{
|
||||
Use: "serve [domain(s)]",
|
||||
|
|
@ -36,9 +37,10 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
|
|||
}
|
||||
}
|
||||
|
||||
_, err := apis.Serve(app, apis.ServeConfig{
|
||||
err := apis.Serve(app, apis.ServeConfig{
|
||||
HttpAddr: httpAddr,
|
||||
HttpsAddr: httpsAddr,
|
||||
DashboardPath: dashboardPath,
|
||||
ShowStartBanner: showStartBanner,
|
||||
AllowedOrigins: allowedOrigins,
|
||||
CertificateDomains: args,
|
||||
|
|
@ -73,5 +75,12 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
|
|||
"TCP address to listen for the HTTPS server\n(if domain args are specified - default to 0.0.0.0:443, otherwise - default to empty string, aka. no TLS)\nThe incoming HTTP traffic also will be auto redirected to the HTTPS version",
|
||||
)
|
||||
|
||||
command.PersistentFlags().StringVar(
|
||||
&dashboardPath,
|
||||
"dashboard",
|
||||
"/_/{path...}",
|
||||
"The route path to the superusers dashboard; must include the '{path...}' wildcard parameter",
|
||||
)
|
||||
|
||||
return command
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,166 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// NewSuperuserCommand creates and returns new command for managing
|
||||
// superuser accounts (create, update, delete).
|
||||
func NewSuperuserCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "superuser",
|
||||
Short: "Manages superuser accounts",
|
||||
}
|
||||
|
||||
command.AddCommand(superuserUpsertCommand(app))
|
||||
command.AddCommand(superuserCreateCommand(app))
|
||||
command.AddCommand(superuserUpdateCommand(app))
|
||||
command.AddCommand(superuserDeleteCommand(app))
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserUpsertCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "upsert",
|
||||
Example: "superuser upsert test@example.com 1234567890",
|
||||
Short: "Creates, or updates if email exists, a single superuser account",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("Missing email and password arguments.")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("Missing or invalid email address.")
|
||||
}
|
||||
|
||||
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to fetch %q collection: %w.", core.CollectionNameSuperusers, err)
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(superusersCol, args[0])
|
||||
if err != nil {
|
||||
superuser = core.NewRecord(superusersCol)
|
||||
}
|
||||
|
||||
superuser.SetEmail(args[0])
|
||||
superuser.SetPassword(args[1])
|
||||
|
||||
if err := app.Save(superuser); err != nil {
|
||||
return fmt.Errorf("Failed to upsert superuser account: %w.", err)
|
||||
}
|
||||
|
||||
color.Green("Successfully saved superuser %q!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserCreateCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "create",
|
||||
Example: "superuser create test@example.com 1234567890",
|
||||
Short: "Creates a new superuser account",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("Missing email and password arguments.")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("Missing or invalid email address.")
|
||||
}
|
||||
|
||||
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to fetch %q collection: %w.", core.CollectionNameSuperusers, err)
|
||||
}
|
||||
|
||||
superuser := core.NewRecord(superusersCol)
|
||||
superuser.SetEmail(args[0])
|
||||
superuser.SetPassword(args[1])
|
||||
|
||||
if err := app.Save(superuser); err != nil {
|
||||
return fmt.Errorf("Failed to create new superuser account: %w.", err)
|
||||
}
|
||||
|
||||
color.Green("Successfully created new superuser %q!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserUpdateCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "update",
|
||||
Example: "superuser update test@example.com 1234567890",
|
||||
Short: "Changes the password of a single superuser account",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("Missing email and password arguments.")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("Missing or invalid email address.")
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Superuser with email %q doesn't exist.", args[0])
|
||||
}
|
||||
|
||||
superuser.SetPassword(args[1])
|
||||
|
||||
if err := app.Save(superuser); err != nil {
|
||||
return fmt.Errorf("Failed to change superuser %q password: %w.", superuser.Email(), err)
|
||||
}
|
||||
|
||||
color.Green("Successfully changed superuser %q password!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserDeleteCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "delete",
|
||||
Example: "superuser delete test@example.com",
|
||||
Short: "Deletes an existing superuser account",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("Invalid or missing email address.")
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
|
||||
if err != nil {
|
||||
color.Yellow("Superuser %q is missing or already deleted.", args[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := app.Delete(superuser); err != nil {
|
||||
return fmt.Errorf("Failed to delete superuser %q: %w.", superuser.Email(), err)
|
||||
}
|
||||
|
||||
color.Green("Successfully deleted superuser %q!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
|
@ -0,0 +1,310 @@
|
|||
package cmd_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/cmd"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestSuperuserUpsertCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"existing user",
|
||||
"test@example.com",
|
||||
"1234567890!",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"new user",
|
||||
"test_new@example.com",
|
||||
"1234567890!",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"upsert", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
// check whether the superuser account was actually upserted
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
|
||||
} else if !superuser.ValidatePassword(s.password) {
|
||||
t.Fatal("Expected the superuser password to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperuserCreateCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"duplicated email",
|
||||
"test@example.com",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid email and password",
|
||||
"test_new@example.com",
|
||||
"12345678",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"create", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
// check whether the superuser account was actually created
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch created superuser %s: %v", s.email, err)
|
||||
} else if !superuser.ValidatePassword(s.password) {
|
||||
t.Fatal("Expected the superuser password to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperuserUpdateCommand(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"nonexisting superuser",
|
||||
"test_missing@example.com",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid email and password",
|
||||
"test@example.com",
|
||||
"12345678",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"update", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
// check whether the superuser password was actually changed
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
|
||||
} else if !superuser.ValidatePassword(s.password) {
|
||||
t.Fatal("Expected the superuser password to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperuserDeleteCommand(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"nonexisting superuser",
|
||||
"test_missing@example.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"existing superuser",
|
||||
"test@example.com",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"delete", s.email})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email); err == nil {
|
||||
t.Fatal("Expected the superuser account to be deleted")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1667
core/app.go
1667
core/app.go
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,239 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
const CollectionNameAuthOrigins = "_authOrigins"
|
||||
|
||||
var (
|
||||
_ Model = (*AuthOrigin)(nil)
|
||||
_ PreValidator = (*AuthOrigin)(nil)
|
||||
_ RecordProxy = (*AuthOrigin)(nil)
|
||||
)
|
||||
|
||||
// AuthOrigin defines a Record proxy for working with the authOrigins collection.
|
||||
type AuthOrigin struct {
|
||||
*Record
|
||||
}
|
||||
|
||||
// NewAuthOrigin instantiates and returns a new blank *AuthOrigin model.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// origin := core.NewOrigin(app)
|
||||
// origin.SetRecordRef(user.Id)
|
||||
// origin.SetCollectionRef(user.Collection().Id)
|
||||
// origin.SetFingerprint("...")
|
||||
// app.Save(origin)
|
||||
func NewAuthOrigin(app App) *AuthOrigin {
|
||||
m := &AuthOrigin{}
|
||||
|
||||
c, err := app.FindCachedCollectionByNameOrId(CollectionNameAuthOrigins)
|
||||
if err != nil {
|
||||
// this is just to make tests easier since authOrigins is a system collection and it is expected to be always accessible
|
||||
// (note: the loaded record is further checked on AuthOrigin.PreValidate())
|
||||
c = NewBaseCollection("@___invalid___")
|
||||
}
|
||||
|
||||
m.Record = NewRecord(c)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// PreValidate implements the [PreValidator] interface and checks
|
||||
// whether the proxy is properly loaded.
|
||||
func (m *AuthOrigin) PreValidate(ctx context.Context, app App) error {
|
||||
if m.Record == nil || m.Record.Collection().Name != CollectionNameAuthOrigins {
|
||||
return errors.New("missing or invalid AuthOrigin ProxyRecord")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProxyRecord returns the proxied Record model.
|
||||
func (m *AuthOrigin) ProxyRecord() *Record {
|
||||
return m.Record
|
||||
}
|
||||
|
||||
// SetProxyRecord loads the specified record model into the current proxy.
|
||||
func (m *AuthOrigin) SetProxyRecord(record *Record) {
|
||||
m.Record = record
|
||||
}
|
||||
|
||||
// CollectionRef returns the "collectionRef" field value.
|
||||
func (m *AuthOrigin) CollectionRef() string {
|
||||
return m.GetString("collectionRef")
|
||||
}
|
||||
|
||||
// SetCollectionRef updates the "collectionRef" record field value.
|
||||
func (m *AuthOrigin) SetCollectionRef(collectionId string) {
|
||||
m.Set("collectionRef", collectionId)
|
||||
}
|
||||
|
||||
// RecordRef returns the "recordRef" record field value.
|
||||
func (m *AuthOrigin) RecordRef() string {
|
||||
return m.GetString("recordRef")
|
||||
}
|
||||
|
||||
// SetRecordRef updates the "recordRef" record field value.
|
||||
func (m *AuthOrigin) SetRecordRef(recordId string) {
|
||||
m.Set("recordRef", recordId)
|
||||
}
|
||||
|
||||
// Fingerprint returns the "fingerprint" record field value.
|
||||
func (m *AuthOrigin) Fingerprint() string {
|
||||
return m.GetString("fingerprint")
|
||||
}
|
||||
|
||||
// SetFingerprint updates the "fingerprint" record field value.
|
||||
func (m *AuthOrigin) SetFingerprint(fingerprint string) {
|
||||
m.Set("fingerprint", fingerprint)
|
||||
}
|
||||
|
||||
// Created returns the "created" record field value.
|
||||
func (m *AuthOrigin) Created() types.DateTime {
|
||||
return m.GetDateTime("created")
|
||||
}
|
||||
|
||||
// Updated returns the "updated" record field value.
|
||||
func (m *AuthOrigin) Updated() types.DateTime {
|
||||
return m.GetDateTime("updated")
|
||||
}
|
||||
|
||||
func (app *BaseApp) registerAuthOriginHooks() {
|
||||
recordRefHooks[*AuthOrigin](app, CollectionNameAuthOrigins, CollectionTypeAuth)
|
||||
|
||||
// delete existing auth origins on password change
|
||||
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil || !e.Record.Collection().IsAuth() {
|
||||
return err
|
||||
}
|
||||
|
||||
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
|
||||
new := e.Record.GetString(FieldNamePassword + ":hash")
|
||||
if old != new {
|
||||
err = e.App.DeleteAllAuthOriginsByRecord(e.Record)
|
||||
if err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"Failed to delete all previous auth origin fingerprints",
|
||||
"error", err,
|
||||
"recordId", e.Record.Id,
|
||||
"collectionId", e.Record.Collection().Id,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// recordRefHooks registers common hooks that are usually used with record proxies
|
||||
// that have polymorphic record relations (aka. "collectionRef" and "recordRef" fields).
|
||||
func recordRefHooks[T RecordProxy](app App, collectionName string, optCollectionTypes ...string) {
|
||||
app.OnRecordValidate(collectionName).Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
collectionId := e.Record.GetString("collectionRef")
|
||||
err := validation.Validate(collectionId, validation.Required, validation.By(validateCollectionId(e.App, optCollectionTypes...)))
|
||||
if err != nil {
|
||||
return validation.Errors{"collectionRef": err}
|
||||
}
|
||||
|
||||
recordId := e.Record.GetString("recordRef")
|
||||
err = validation.Validate(recordId, validation.Required, validation.By(validateRecordId(e.App, collectionId)))
|
||||
if err != nil {
|
||||
return validation.Errors{"recordRef": err}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
// delete on collection ref delete
|
||||
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Func: func(e *CollectionEvent) error {
|
||||
if e.Collection.Name == collectionName || (len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Collection.Type)) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{"collectionRef": e.Collection.Id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, mfa := range rels {
|
||||
if err := txApp.Delete(mfa); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.App = originalApp
|
||||
|
||||
return txErr
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
// delete on record ref delete
|
||||
app.OnRecordDeleteExecute().Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
if e.Record.Collection().Name == collectionName ||
|
||||
(len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Record.Collection().Type)) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{
|
||||
"collectionRef": e.Record.Collection().Id,
|
||||
"recordRef": e.Record.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rel := range rels {
|
||||
if err := txApp.Delete(rel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.App = originalApp
|
||||
|
||||
return txErr
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNewAuthOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if origin.Collection().Name != core.CollectionNameAuthOrigins {
|
||||
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameAuthOrigins, origin.Collection().Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginProxyRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.Id = "test_id"
|
||||
|
||||
origin := core.AuthOrigin{}
|
||||
origin.SetProxyRecord(record)
|
||||
|
||||
if origin.ProxyRecord() == nil || origin.ProxyRecord().Id != record.Id {
|
||||
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, origin.ProxyRecord())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginRecordRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetRecordRef(testValue)
|
||||
|
||||
if v := origin.RecordRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("recordRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginCollectionRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetCollectionRef(testValue)
|
||||
|
||||
if v := origin.CollectionRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("collectionRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginFingerprint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetFingerprint(testValue)
|
||||
|
||||
if v := origin.Fingerprint(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("fingerprint"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginCreated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if v := origin.Created().String(); v != "" {
|
||||
t.Fatalf("Expected empty created, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
origin.SetRaw("created", now)
|
||||
|
||||
if v := origin.Created().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q created, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginUpdated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if v := origin.Updated().String(); v != "" {
|
||||
t.Fatalf("Expected empty updated, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
origin.SetRaw("updated", now)
|
||||
|
||||
if v := origin.Updated().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q updated, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginPreValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
originsCol, err := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("no proxy record", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
|
||||
if err := app.Validate(origin); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-AuthOrigin collection", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
origin.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetFingerprint("abc")
|
||||
|
||||
if err := app.Validate(origin); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AuthOrigin collection", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
origin.SetProxyRecord(core.NewRecord(originsCol))
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetFingerprint("abc")
|
||||
|
||||
if err := app.Validate(origin); err != nil {
|
||||
t.Fatalf("Expected nil validation error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthOriginValidateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
origin func() *core.AuthOrigin
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
func() *core.AuthOrigin {
|
||||
return core.NewAuthOrigin(app)
|
||||
},
|
||||
[]string{"collectionRef", "recordRef", "fingerprint"},
|
||||
},
|
||||
{
|
||||
"non-auth collection",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(demo1.Collection().Id)
|
||||
origin.SetRecordRef(demo1.Id)
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{"collectionRef"},
|
||||
},
|
||||
{
|
||||
"missing record id",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetRecordRef("missing")
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{"recordRef"},
|
||||
},
|
||||
{
|
||||
"valid ref",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := app.Validate(s.origin())
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginPasswordChangeDeletion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// no auth origin associated with it
|
||||
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{user1, nil},
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
s.record.SetPassword("new_password")
|
||||
|
||||
err := app.Save(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// FindAllAuthOriginsByRecord returns all AuthOrigin models linked to the provided auth record (in DESC order).
|
||||
func (app *BaseApp) FindAllAuthOriginsByRecord(authRecord *Record) ([]*AuthOrigin, error) {
|
||||
result := []*AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAllAuthOriginsByCollection returns all AuthOrigin models linked to the provided collection (in DESC order).
|
||||
func (app *BaseApp) FindAllAuthOriginsByCollection(collection *Collection) ([]*AuthOrigin, error) {
|
||||
result := []*AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAuthOriginById returns a single AuthOrigin model by its id.
|
||||
func (app *BaseApp) FindAuthOriginById(id string) (*AuthOrigin, error) {
|
||||
result := &AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAuthOriginByRecordAndFingerprint returns a single AuthOrigin model
|
||||
// by its authRecord relation and fingerprint.
|
||||
func (app *BaseApp) FindAuthOriginByRecordAndFingerprint(authRecord *Record, fingerprint string) (*AuthOrigin, error) {
|
||||
result := &AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
"fingerprint": fingerprint,
|
||||
}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteAllAuthOriginsByRecord deletes all AuthOrigin models associated with the provided record.
|
||||
//
|
||||
// Returns a combined error with the failed deletes.
|
||||
func (app *BaseApp) DeleteAllAuthOriginsByRecord(authRecord *Record) error {
|
||||
models, err := app.FindAllAuthOriginsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, m := range models {
|
||||
if err := app.Delete(m); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,268 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestFindAllAuthOriginsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{superuser4, nil},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
|
||||
result, err := app.FindAllAuthOriginsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAllAuthOriginsByCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clients, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
collection *core.Collection
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superusers, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib", "5f29jy38bf5zm3f"}},
|
||||
{clients, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.collection.Name, func(t *testing.T) {
|
||||
result, err := app.FindAllAuthOriginsByCollection(s.collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAuthOriginById(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"84nmscqy84lsi1t", true}, // non-origin id
|
||||
{"9r2j0m74260ur8i", false},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.id, func(t *testing.T) {
|
||||
result, err := app.FindAuthOriginById(s.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Id != s.id {
|
||||
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAuthOriginByRecordAndFingerprint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
fingerprint string
|
||||
expectError bool
|
||||
}{
|
||||
{demo1, "6afbfe481c31c08c55a746cccb88ece0", true},
|
||||
{superuser2, "", true},
|
||||
{superuser2, "abc", true},
|
||||
{superuser2, "22bbbcbed36e25321f384ccf99f60057", false}, // fingerprint from different origin
|
||||
{superuser2, "6afbfe481c31c08c55a746cccb88ece0", false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Id, s.fingerprint), func(t *testing.T) {
|
||||
result, err := app.FindAuthOriginByRecordAndFingerprint(s.record, s.fingerprint)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Fingerprint() != s.fingerprint {
|
||||
t.Fatalf("Expected origin with fingerprint %q, got %q", s.fingerprint, result.Fingerprint())
|
||||
}
|
||||
|
||||
if result.RecordRef() != s.record.Id || result.CollectionRef() != s.record.Collection().Id {
|
||||
t.Fatalf("Expected record %q (%q), got %q (%q)", s.record.Id, s.record.Collection().Id, result.RecordRef(), result.CollectionRef())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAllAuthOriginsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{demo1, nil}, // non-auth record
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{superuser4, nil},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
err := app.DeleteAllAuthOriginsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1705
core/base.go
1705
core/base.go
File diff suppressed because it is too large
Load Diff
|
|
@ -12,20 +12,16 @@ import (
|
|||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/archive"
|
||||
"github.com/pocketbase/pocketbase/tools/cron"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/osutils"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// Deprecated: Replaced with StoreKeyActiveBackup.
|
||||
const CacheKeyActiveBackup string = "@activeBackup"
|
||||
|
||||
const StoreKeyActiveBackup string = "@activeBackup"
|
||||
const (
|
||||
StoreKeyActiveBackup = "@activeBackup"
|
||||
)
|
||||
|
||||
// CreateBackup creates a new backup of the current app pb_data directory.
|
||||
//
|
||||
|
|
@ -50,61 +46,67 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
|
|||
return errors.New("try again later - another backup/restore operation has already been started")
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
name = app.generateBackupName("pb_backup_")
|
||||
}
|
||||
|
||||
app.Store().Set(StoreKeyActiveBackup, name)
|
||||
defer app.Store().Remove(StoreKeyActiveBackup)
|
||||
|
||||
// root dir entries to exclude from the backup generation
|
||||
exclude := []string{LocalBackupsDirName, LocalTempDirName}
|
||||
event := new(BackupEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Name = name
|
||||
// default root dir entries to exclude from the backup generation
|
||||
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
|
||||
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(app.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
return app.OnBackupCreate().Trigger(event, func(e *BackupEvent) error {
|
||||
// generate a default name if missing
|
||||
if e.Name == "" {
|
||||
e.Name = generateBackupName(e.App, "pb_backup_")
|
||||
}
|
||||
|
||||
// Archive pb_data in a temp directory, exluding the "backups" and the temp dirs.
|
||||
//
|
||||
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection).
|
||||
// ---
|
||||
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(4))
|
||||
createErr := app.Dao().RunInTransaction(func(dataTXDao *daos.Dao) error {
|
||||
return app.LogsDao().RunInTransaction(func(logsTXDao *daos.Dao) error {
|
||||
// @todo consider experimenting with temp switching the readonly pragma after the db interface change
|
||||
return archive.Create(app.DataDir(), tempPath, exclude...)
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
|
||||
// archive pb_data in a temp directory, exluding the "backups" and the temp dirs
|
||||
//
|
||||
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection).
|
||||
// ---
|
||||
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(6))
|
||||
createErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
return txApp.AuxRunInTransaction(func(txApp App) error {
|
||||
return archive.Create(txApp.DataDir(), tempPath, e.Exclude...)
|
||||
})
|
||||
})
|
||||
if createErr != nil {
|
||||
return createErr
|
||||
}
|
||||
defer os.Remove(tempPath)
|
||||
|
||||
// persist the backup in the backups filesystem
|
||||
// ---
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(e.Context)
|
||||
|
||||
file, err := filesystem.NewFileFromPath(tempPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.OriginalName = e.Name
|
||||
file.Name = file.OriginalName
|
||||
|
||||
if err := fsys.UploadFile(file, file.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if createErr != nil {
|
||||
return createErr
|
||||
}
|
||||
defer os.Remove(tempPath)
|
||||
|
||||
// Persist the backup in the backups filesystem.
|
||||
// ---
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
file, err := filesystem.NewFileFromPath(tempPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.OriginalName = name
|
||||
file.Name = file.OriginalName
|
||||
|
||||
if err := fsys.UploadFile(file, file.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreBackup restores the backup with the specified name and restarts
|
||||
|
|
@ -136,10 +138,6 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
|
|||
// If a failure occure during the restore process the dir changes are reverted.
|
||||
// If for whatever reason the revert is not possible, it panics.
|
||||
func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
|
||||
if runtime.GOOS == "windows" {
|
||||
return errors.New("restore is not supported on windows")
|
||||
}
|
||||
|
||||
if app.Store().Has(StoreKeyActiveBackup) {
|
||||
return errors.New("try again later - another backup/restore operation has already been started")
|
||||
}
|
||||
|
|
@ -147,131 +145,129 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
|
|||
app.Store().Set(StoreKeyActiveBackup, name)
|
||||
defer app.Store().Remove(StoreKeyActiveBackup)
|
||||
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
event := new(BackupEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Name = name
|
||||
// default root dir entries to exclude from the backup restore
|
||||
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
// fetch the backup file in a temp location
|
||||
br, err := fsys.GetFile(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer br.Close()
|
||||
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(app.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
|
||||
// create a temp zip file from the blob.Reader and try to extract it
|
||||
tempZip, err := os.CreateTemp(localTempDir, "pb_restore_zip")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempZip.Name())
|
||||
|
||||
if _, err := io.Copy(tempZip, br); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
extractedDataDir := filepath.Join(localTempDir, "pb_restore_"+security.PseudorandomString(4))
|
||||
defer os.RemoveAll(extractedDataDir)
|
||||
if err := archive.Extract(tempZip.Name(), extractedDataDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure that a database file exists
|
||||
extractedDB := filepath.Join(extractedDataDir, "data.db")
|
||||
if _, err := os.Stat(extractedDB); err != nil {
|
||||
return fmt.Errorf("data.db file is missing or invalid: %w", err)
|
||||
}
|
||||
|
||||
// remove the extracted zip file since we no longer need it
|
||||
// (this is in case the app restarts and the defer calls are not called)
|
||||
if err := os.Remove(tempZip.Name()); err != nil {
|
||||
app.Logger().Debug(
|
||||
"[RestoreBackup] Failed to remove the temp zip backup file",
|
||||
slog.String("file", tempZip.Name()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
// root dir entries to exclude from the backup restore
|
||||
exclude := []string{LocalBackupsDirName, LocalTempDirName}
|
||||
|
||||
// move the current pb_data content to a special temp location
|
||||
// that will hold the old data between dirs replace
|
||||
// (the temp dir will be automatically removed on the next app start)
|
||||
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(4))
|
||||
if err := osutils.MoveDirContent(app.DataDir(), oldTempDataDir, exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
|
||||
}
|
||||
|
||||
// move the extracted archive content to the app's pb_data
|
||||
if err := osutils.MoveDirContent(extractedDataDir, app.DataDir(), exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
|
||||
}
|
||||
|
||||
revertDataDirChanges := func() error {
|
||||
if err := osutils.MoveDirContent(app.DataDir(), extractedDataDir, exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
|
||||
return app.OnBackupRestore().Trigger(event, func(e *BackupEvent) error {
|
||||
if runtime.GOOS == "windows" {
|
||||
return errors.New("restore is not supported on Windows")
|
||||
}
|
||||
|
||||
if err := osutils.MoveDirContent(oldTempDataDir, app.DataDir(), exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert old pb_data dir change: %w", err)
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(e.Context)
|
||||
|
||||
// fetch the backup file in a temp location
|
||||
br, err := fsys.GetFile(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer br.Close()
|
||||
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
// create a temp zip file from the blob.Reader and try to extract it
|
||||
tempZip, err := os.CreateTemp(localTempDir, "pb_restore_zip")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempZip.Name())
|
||||
|
||||
// restart the app
|
||||
if err := app.Restart(); err != nil {
|
||||
if revertErr := revertDataDirChanges(); revertErr != nil {
|
||||
panic(revertErr)
|
||||
if _, err := io.Copy(tempZip, br); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to restart the app process: %w", err)
|
||||
}
|
||||
extractedDataDir := filepath.Join(localTempDir, "pb_restore_"+security.PseudorandomString(4))
|
||||
defer os.RemoveAll(extractedDataDir)
|
||||
if err := archive.Extract(tempZip.Name(), extractedDataDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
// ensure that a database file exists
|
||||
extractedDB := filepath.Join(extractedDataDir, "data.db")
|
||||
if _, err := os.Stat(extractedDB); err != nil {
|
||||
return fmt.Errorf("data.db file is missing or invalid: %w", err)
|
||||
}
|
||||
|
||||
// initAutobackupHooks registers the autobackup app serve hooks.
|
||||
func (app *BaseApp) initAutobackupHooks() error {
|
||||
c := cron.New()
|
||||
isServe := false
|
||||
|
||||
loadJob := func() {
|
||||
c.Stop()
|
||||
|
||||
// make sure that app.Settings() is always up to date
|
||||
//
|
||||
// @todo remove with the refactoring as core.App and daos.Dao will be one.
|
||||
if err := app.RefreshSettings(); err != nil {
|
||||
app.Logger().Debug(
|
||||
"[Backup cron] Failed to get the latest app settings",
|
||||
// remove the extracted zip file since we no longer need it
|
||||
// (this is in case the app restarts and the defer calls are not called)
|
||||
if err := os.Remove(tempZip.Name()); err != nil {
|
||||
e.App.Logger().Debug(
|
||||
"[RestoreBackup] Failed to remove the temp zip backup file",
|
||||
slog.String("file", tempZip.Name()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
// move the current pb_data content to a special temp location
|
||||
// that will hold the old data between dirs replace
|
||||
// (the temp dir will be automatically removed on the next app start)
|
||||
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(4))
|
||||
if err := osutils.MoveDirContent(e.App.DataDir(), oldTempDataDir, e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
|
||||
}
|
||||
|
||||
// move the extracted archive content to the app's pb_data
|
||||
if err := osutils.MoveDirContent(extractedDataDir, e.App.DataDir(), e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
|
||||
}
|
||||
|
||||
revertDataDirChanges := func() error {
|
||||
if err := osutils.MoveDirContent(e.App.DataDir(), extractedDataDir, e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
|
||||
}
|
||||
|
||||
if err := osutils.MoveDirContent(oldTempDataDir, e.App.DataDir(), e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert old pb_data dir change: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restart the app
|
||||
if err := e.App.Restart(); err != nil {
|
||||
if revertErr := revertDataDirChanges(); revertErr != nil {
|
||||
panic(revertErr)
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to restart the app process: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// registerAutobackupHooks registers the autobackup app serve hooks.
|
||||
func (app *BaseApp) registerAutobackupHooks() {
|
||||
const jobId = "__auto_pb_backup__"
|
||||
|
||||
loadJob := func() {
|
||||
rawSchedule := app.Settings().Backups.Cron
|
||||
if rawSchedule == "" || !isServe || !app.IsBootstrapped() {
|
||||
if rawSchedule == "" {
|
||||
app.Cron().Remove(jobId)
|
||||
return
|
||||
}
|
||||
|
||||
c.Add("@autobackup", rawSchedule, func() {
|
||||
app.Cron().Add(jobId, rawSchedule, func() {
|
||||
const autoPrefix = "@auto_pb_backup_"
|
||||
|
||||
name := app.generateBackupName(autoPrefix)
|
||||
name := generateBackupName(app, autoPrefix)
|
||||
|
||||
if err := app.CreateBackup(context.Background(), name); err != nil {
|
||||
app.Logger().Debug(
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to create backup",
|
||||
slog.String("name", name),
|
||||
slog.String("error", err.Error()),
|
||||
|
|
@ -286,7 +282,7 @@ func (app *BaseApp) initAutobackupHooks() error {
|
|||
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to initialize the backup filesystem",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
|
|
@ -296,7 +292,7 @@ func (app *BaseApp) initAutobackupHooks() error {
|
|||
|
||||
files, err := fsys.List(autoPrefix)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to list autogenerated backups",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
|
|
@ -317,7 +313,7 @@ func (app *BaseApp) initAutobackupHooks() error {
|
|||
|
||||
for _, f := range toRemove {
|
||||
if err := fsys.Delete(f.Key); err != nil {
|
||||
app.Logger().Debug(
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to remove old autogenerated backup",
|
||||
slog.String("key", f.Key),
|
||||
slog.String("error", err.Error()),
|
||||
|
|
@ -325,29 +321,11 @@ func (app *BaseApp) initAutobackupHooks() error {
|
|||
}
|
||||
}
|
||||
})
|
||||
|
||||
// restart the ticker
|
||||
c.Start()
|
||||
}
|
||||
|
||||
// load on app serve
|
||||
app.OnBeforeServe().Add(func(e *ServeEvent) error {
|
||||
isServe = true
|
||||
loadJob()
|
||||
return nil
|
||||
})
|
||||
|
||||
// stop the ticker on app termination
|
||||
app.OnTerminate().Add(func(e *TerminateEvent) error {
|
||||
c.Stop()
|
||||
return nil
|
||||
})
|
||||
|
||||
// reload on app settings change
|
||||
app.OnModelAfterUpdate((&models.Param{}).TableName()).Add(func(e *ModelEvent) error {
|
||||
p := e.Model.(*models.Param)
|
||||
if p == nil || p.Key != models.ParamAppSettings {
|
||||
return nil
|
||||
app.OnBootstrap().BindFunc(func(e *BootstrapEvent) error {
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loadJob()
|
||||
|
|
@ -355,10 +333,18 @@ func (app *BaseApp) initAutobackupHooks() error {
|
|||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
app.OnSettingsReload().BindFunc(func(e *SettingsReloadEvent) error {
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loadJob()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (app *BaseApp) generateBackupName(prefix string) string {
|
||||
func generateBackupName(app App, prefix string) string {
|
||||
appName := inflector.Snakecase(app.Settings().Meta.AppName)
|
||||
if len(appName) > 50 {
|
||||
appName = appName[:50]
|
||||
|
|
|
|||
|
|
@ -128,9 +128,9 @@ func verifyBackupContent(app core.App, path string) error {
|
|||
"data.db",
|
||||
"data.db-shm",
|
||||
"data.db-wal",
|
||||
"logs.db",
|
||||
"logs.db-shm",
|
||||
"logs.db-wal",
|
||||
"aux.db",
|
||||
"aux.db-shm",
|
||||
"aux.db-wal",
|
||||
".gitignore",
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,63 +0,0 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestBaseAppRefreshSettings(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// cleanup all stored settings
|
||||
if _, err := app.DB().NewQuery("DELETE from _params;").Execute(); err != nil {
|
||||
t.Fatalf("Failed to delete all test settings: %v", err)
|
||||
}
|
||||
|
||||
// check if the new settings are saved in the db
|
||||
app.ResetEventCalls()
|
||||
if err := app.RefreshSettings(); err != nil {
|
||||
t.Fatalf("Failed to refresh the settings after delete: %v", err)
|
||||
}
|
||||
testEventCalls(t, app, map[string]int{
|
||||
"OnModelBeforeCreate": 1,
|
||||
"OnModelAfterCreate": 1,
|
||||
})
|
||||
param, err := app.Dao().FindParamByKey(models.ParamAppSettings)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected new settings to be persisted, got %v", err)
|
||||
}
|
||||
|
||||
// change the db entry and refresh the app settings (ensure that there was no db update)
|
||||
param.Value = types.JsonRaw([]byte(`{"example": 123}`))
|
||||
if err := app.Dao().SaveParam(param.Key, param.Value); err != nil {
|
||||
t.Fatalf("Failed to update the test settings: %v", err)
|
||||
}
|
||||
app.ResetEventCalls()
|
||||
if err := app.RefreshSettings(); err != nil {
|
||||
t.Fatalf("Failed to refresh the app settings: %v", err)
|
||||
}
|
||||
testEventCalls(t, app, nil)
|
||||
|
||||
// try to refresh again without doing any changes
|
||||
app.ResetEventCalls()
|
||||
if err := app.RefreshSettings(); err != nil {
|
||||
t.Fatalf("Failed to refresh the app settings without change: %v", err)
|
||||
}
|
||||
testEventCalls(t, app, nil)
|
||||
}
|
||||
|
||||
func testEventCalls(t *testing.T, app *tests.TestApp, events map[string]int) {
|
||||
if len(events) != len(app.EventCalls) {
|
||||
t.Fatalf("Expected events doesn't match: \n%v, \ngot \n%v", events, app.EventCalls)
|
||||
}
|
||||
|
||||
for name, total := range events {
|
||||
if v, ok := app.EventCalls[name]; !ok || v != total {
|
||||
t.Fatalf("Expected events doesn't exist or match: \n%v, \ngot \n%v", events, app.EventCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,59 +1,56 @@
|
|||
package core
|
||||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/migrations"
|
||||
"github.com/pocketbase/pocketbase/migrations/logs"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/logger"
|
||||
"github.com/pocketbase/pocketbase/tools/mailer"
|
||||
"github.com/pocketbase/pocketbase/tools/migrate"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNewBaseApp(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := NewBaseApp(BaseAppConfig{
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "test_env",
|
||||
IsDev: true,
|
||||
})
|
||||
|
||||
if app.dataDir != testDataDir {
|
||||
t.Fatalf("expected dataDir %q, got %q", testDataDir, app.dataDir)
|
||||
if app.DataDir() != testDataDir {
|
||||
t.Fatalf("expected DataDir %q, got %q", testDataDir, app.DataDir())
|
||||
}
|
||||
|
||||
if app.encryptionEnv != "test_env" {
|
||||
t.Fatalf("expected encryptionEnv test_env, got %q", app.dataDir)
|
||||
if app.EncryptionEnv() != "test_env" {
|
||||
t.Fatalf("expected EncryptionEnv test_env, got %q", app.EncryptionEnv())
|
||||
}
|
||||
|
||||
if !app.isDev {
|
||||
t.Fatalf("expected isDev true, got %v", app.isDev)
|
||||
if !app.IsDev() {
|
||||
t.Fatalf("expected IsDev true, got %v", app.IsDev())
|
||||
}
|
||||
|
||||
if app.store == nil {
|
||||
t.Fatal("expected store to be set, got nil")
|
||||
if app.Store() == nil {
|
||||
t.Fatal("expected Store to be set, got nil")
|
||||
}
|
||||
|
||||
if app.settings == nil {
|
||||
t.Fatal("expected settings to be set, got nil")
|
||||
if app.Settings() == nil {
|
||||
t.Fatal("expected Settings to be set, got nil")
|
||||
}
|
||||
|
||||
if app.subscriptionsBroker == nil {
|
||||
t.Fatal("expected subscriptionsBroker to be set, got nil")
|
||||
if app.SubscriptionsBroker() == nil {
|
||||
t.Fatal("expected SubscriptionsBroker to be set, got nil")
|
||||
}
|
||||
|
||||
if app.Cron() == nil {
|
||||
t.Fatal("expected Cron to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -61,9 +58,8 @@ func TestBaseAppBootstrap(t *testing.T) {
|
|||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := NewBaseApp(BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "pb_test_env",
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
|
|
@ -83,72 +79,59 @@ func TestBaseAppBootstrap(t *testing.T) {
|
|||
t.Fatal("Expected test data directory to be created.")
|
||||
}
|
||||
|
||||
if app.dao == nil {
|
||||
t.Fatal("Expected app.dao to be initialized, got nil.")
|
||||
type nilCheck struct {
|
||||
name string
|
||||
value any
|
||||
expectNil bool
|
||||
}
|
||||
|
||||
if app.dao.BeforeCreateFunc == nil {
|
||||
t.Fatal("Expected app.dao.BeforeCreateFunc to be set, got nil.")
|
||||
runNilChecks := func(checks []nilCheck) {
|
||||
for _, check := range checks {
|
||||
t.Run(check.name, func(t *testing.T) {
|
||||
isNil := check.value == nil
|
||||
if isNil != check.expectNil {
|
||||
t.Fatalf("Expected isNil %v, got %v", check.expectNil, isNil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if app.dao.AfterCreateFunc == nil {
|
||||
t.Fatal("Expected app.dao.AfterCreateFunc to be set, got nil.")
|
||||
nilChecksBeforeReset := []nilCheck{
|
||||
{"[before] concurrentDB", app.DB(), false},
|
||||
{"[before] nonconcurrentDB", app.NonconcurrentDB(), false},
|
||||
{"[before] auxConcurrentDB", app.AuxDB(), false},
|
||||
{"[before] auxNonconcurrentDB", app.AuxNonconcurrentDB(), false},
|
||||
{"[before] settings", app.Settings(), false},
|
||||
{"[before] logger", app.Logger(), false},
|
||||
{"[before] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
|
||||
}
|
||||
|
||||
if app.dao.BeforeUpdateFunc == nil {
|
||||
t.Fatal("Expected app.dao.BeforeUpdateFunc to be set, got nil.")
|
||||
}
|
||||
|
||||
if app.dao.AfterUpdateFunc == nil {
|
||||
t.Fatal("Expected app.dao.AfterUpdateFunc to be set, got nil.")
|
||||
}
|
||||
|
||||
if app.dao.BeforeDeleteFunc == nil {
|
||||
t.Fatal("Expected app.dao.BeforeDeleteFunc to be set, got nil.")
|
||||
}
|
||||
|
||||
if app.dao.AfterDeleteFunc == nil {
|
||||
t.Fatal("Expected app.dao.AfterDeleteFunc to be set, got nil.")
|
||||
}
|
||||
|
||||
if app.logsDao == nil {
|
||||
t.Fatal("Expected app.logsDao to be initialized, got nil.")
|
||||
}
|
||||
|
||||
if app.settings == nil {
|
||||
t.Fatal("Expected app.settings to be initialized, got nil.")
|
||||
}
|
||||
|
||||
if app.logger == nil {
|
||||
t.Fatal("Expected app.logger to be initialized, got nil.")
|
||||
}
|
||||
|
||||
if _, ok := app.logger.Handler().(*logger.BatchHandler); !ok {
|
||||
t.Fatal("Expected app.logger handler to be initialized.")
|
||||
}
|
||||
runNilChecks(nilChecksBeforeReset)
|
||||
|
||||
// reset
|
||||
if err := app.ResetBootstrapState(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if app.dao != nil {
|
||||
t.Fatalf("Expected app.dao to be nil, got %v.", app.dao)
|
||||
nilChecksAfterReset := []nilCheck{
|
||||
{"[after] concurrentDB", app.DB(), true},
|
||||
{"[after] nonconcurrentDB", app.NonconcurrentDB(), true},
|
||||
{"[after] auxConcurrentDB", app.AuxDB(), true},
|
||||
{"[after] auxNonconcurrentDB", app.AuxNonconcurrentDB(), true},
|
||||
{"[after] settings", app.Settings(), false},
|
||||
{"[after] logger", app.Logger(), false},
|
||||
{"[after] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
|
||||
}
|
||||
|
||||
if app.logsDao != nil {
|
||||
t.Fatalf("Expected app.logsDao to be nil, got %v.", app.logsDao)
|
||||
}
|
||||
runNilChecks(nilChecksAfterReset)
|
||||
}
|
||||
|
||||
func TestBaseAppGetters(t *testing.T) {
|
||||
func TestNewBaseAppIsTransactional(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := NewBaseApp(BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "pb_test_env",
|
||||
IsDev: true,
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
|
|
@ -156,81 +139,58 @@ func TestBaseAppGetters(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if app.dao != app.Dao() {
|
||||
t.Fatalf("Expected app.Dao %v, got %v", app.Dao(), app.dao)
|
||||
if app.IsTransactional() {
|
||||
t.Fatalf("Didn't expect the app to be transactional")
|
||||
}
|
||||
|
||||
if app.dao.ConcurrentDB() != app.DB() {
|
||||
t.Fatalf("Expected app.DB %v, got %v", app.DB(), app.dao.ConcurrentDB())
|
||||
}
|
||||
app.RunInTransaction(func(txApp core.App) error {
|
||||
if !txApp.IsTransactional() {
|
||||
t.Fatalf("Expected the app to be transactional")
|
||||
}
|
||||
|
||||
if app.logsDao != app.LogsDao() {
|
||||
t.Fatalf("Expected app.LogsDao %v, got %v", app.LogsDao(), app.logsDao)
|
||||
}
|
||||
|
||||
if app.logsDao.ConcurrentDB() != app.LogsDB() {
|
||||
t.Fatalf("Expected app.LogsDB %v, got %v", app.LogsDB(), app.logsDao.ConcurrentDB())
|
||||
}
|
||||
|
||||
if app.dataDir != app.DataDir() {
|
||||
t.Fatalf("Expected app.DataDir %v, got %v", app.DataDir(), app.dataDir)
|
||||
}
|
||||
|
||||
if app.encryptionEnv != app.EncryptionEnv() {
|
||||
t.Fatalf("Expected app.EncryptionEnv %v, got %v", app.EncryptionEnv(), app.encryptionEnv)
|
||||
}
|
||||
|
||||
if app.isDev != app.IsDev() {
|
||||
t.Fatalf("Expected app.IsDev %v, got %v", app.IsDev(), app.isDev)
|
||||
}
|
||||
|
||||
if app.settings != app.Settings() {
|
||||
t.Fatalf("Expected app.Settings %v, got %v", app.Settings(), app.settings)
|
||||
}
|
||||
|
||||
if app.store != app.Store() {
|
||||
t.Fatalf("Expected app.Store %v, got %v", app.Store(), app.store)
|
||||
}
|
||||
|
||||
if app.logger != app.Logger() {
|
||||
t.Fatalf("Expected app.Logger %v, got %v", app.Logger(), app.logger)
|
||||
}
|
||||
|
||||
if app.subscriptionsBroker != app.SubscriptionsBroker() {
|
||||
t.Fatalf("Expected app.SubscriptionsBroker %v, got %v", app.SubscriptionsBroker(), app.subscriptionsBroker)
|
||||
}
|
||||
|
||||
if app.onBeforeServe != app.OnBeforeServe() || app.OnBeforeServe() == nil {
|
||||
t.Fatalf("Getter app.OnBeforeServe does not match or nil (%v vs %v)", app.OnBeforeServe(), app.onBeforeServe)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestBaseAppNewMailClient(t *testing.T) {
|
||||
app, cleanup, err := initTestBaseApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "pb_test_env",
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
client1 := app.NewMailClient()
|
||||
if val, ok := client1.(*mailer.Sendmail); !ok {
|
||||
t.Fatalf("Expected mailer.Sendmail instance, got %v", val)
|
||||
m1, ok := client1.(*mailer.Sendmail)
|
||||
if !ok {
|
||||
t.Fatalf("Expected mailer.Sendmail instance, got %v", m1)
|
||||
}
|
||||
if m1.OnSend() == nil || m1.OnSend().Length() == 0 {
|
||||
t.Fatal("Expected OnSend hook to be registered")
|
||||
}
|
||||
|
||||
app.Settings().Smtp.Enabled = true
|
||||
app.Settings().SMTP.Enabled = true
|
||||
|
||||
client2 := app.NewMailClient()
|
||||
if val, ok := client2.(*mailer.SmtpClient); !ok {
|
||||
t.Fatalf("Expected mailer.SmtpClient instance, got %v", val)
|
||||
m2, ok := client2.(*mailer.SMTPClient)
|
||||
if !ok {
|
||||
t.Fatalf("Expected mailer.SMTPClient instance, got %v", m2)
|
||||
}
|
||||
if m2.OnSend() == nil || m2.OnSend().Length() == 0 {
|
||||
t.Fatal("Expected OnSend hook to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppNewFilesystem(t *testing.T) {
|
||||
app, cleanup, err := initTestBaseApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
// local
|
||||
local, localErr := app.NewFilesystem()
|
||||
|
|
@ -253,11 +213,13 @@ func TestBaseAppNewFilesystem(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBaseAppNewBackupsFilesystem(t *testing.T) {
|
||||
app, cleanup, err := initTestBaseApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
// local
|
||||
local, localErr := app.NewBackupsFilesystem()
|
||||
|
|
@ -280,18 +242,22 @@ func TestBaseAppNewBackupsFilesystem(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBaseAppLoggerWrites(t *testing.T) {
|
||||
app, cleanup, err := initTestBaseApp()
|
||||
if err != nil {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// reset
|
||||
if err := app.DeleteOldLogs(time.Now()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
const logsThreshold = 200
|
||||
|
||||
totalLogs := func(app App, t *testing.T) int {
|
||||
totalLogs := func(app core.App, t *testing.T) int {
|
||||
var total int
|
||||
|
||||
err := app.LogsDao().LogQuery().Select("count(*)").Row(&total)
|
||||
err := app.LogQuery().Select("count(*)").Row(&total)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch total logs: %v", err)
|
||||
}
|
||||
|
|
@ -338,106 +304,9 @@ func TestBaseAppLoggerWrites(t *testing.T) {
|
|||
t.Fatalf("Expected %d logs, got %d", logsThreshold+1, total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test batch logs delete", func(t *testing.T) {
|
||||
app.Settings().Logs.MaxDays = 2
|
||||
|
||||
deleteQueries := 0
|
||||
|
||||
// reset
|
||||
app.Store().Set("lastLogsDeletedAt", time.Now())
|
||||
if err := app.LogsDao().DeleteOldLogs(time.Now()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db := app.LogsDao().NonconcurrentDB().(*dbx.DB)
|
||||
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
if strings.Contains(sql, "DELETE") {
|
||||
deleteQueries++
|
||||
}
|
||||
}
|
||||
|
||||
// trigger batch write (A)
|
||||
expectedLogs := logsThreshold
|
||||
for i := 0; i < expectedLogs; i++ {
|
||||
app.Logger().Error("testA")
|
||||
}
|
||||
|
||||
if total := totalLogs(app, t); total != expectedLogs {
|
||||
t.Fatalf("[batch write A] Expected %d logs, got %d", expectedLogs, total)
|
||||
}
|
||||
|
||||
// mark the A inserted logs as 2-day expired
|
||||
aExpiredDate, err := types.ParseDateTime(time.Now().AddDate(0, 0, -2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = app.LogsDao().NonconcurrentDB().NewQuery("UPDATE _logs SET created={:date}, updated={:date}").Bind(dbx.Params{
|
||||
"date": aExpiredDate.String(),
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to mock logs timestamp fields: %v", err)
|
||||
}
|
||||
|
||||
// simulate recently deleted logs
|
||||
app.Store().Set("lastLogsDeletedAt", time.Now().Add(-5*time.Hour))
|
||||
|
||||
// trigger batch write (B)
|
||||
for i := 0; i < logsThreshold; i++ {
|
||||
app.Logger().Error("testB")
|
||||
}
|
||||
|
||||
expectedLogs = 2 * logsThreshold
|
||||
|
||||
// note: even though there are expired logs it shouldn't perform the delete operation because of the lastLogsDeledAt time
|
||||
if total := totalLogs(app, t); total != expectedLogs {
|
||||
t.Fatalf("[batch write B] Expected %d logs, got %d", expectedLogs, total)
|
||||
}
|
||||
|
||||
// mark the B inserted logs as 1-day expired to ensure that they will not be deleted
|
||||
bExpiredDate, err := types.ParseDateTime(time.Now().AddDate(0, 0, -1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = app.LogsDao().NonconcurrentDB().NewQuery("UPDATE _logs SET created={:date}, updated={:date} where message='testB'").Bind(dbx.Params{
|
||||
"date": bExpiredDate.String(),
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to mock logs timestamp fields: %v", err)
|
||||
}
|
||||
|
||||
// should trigger delete on the next batch write
|
||||
app.Store().Set("lastLogsDeletedAt", time.Now().Add(-6*time.Hour))
|
||||
|
||||
// trigger batch write (C)
|
||||
for i := 0; i < logsThreshold; i++ {
|
||||
app.Logger().Error("testC")
|
||||
}
|
||||
|
||||
expectedLogs = 2 * logsThreshold // only B and C logs should remain
|
||||
|
||||
if total := totalLogs(app, t); total != expectedLogs {
|
||||
t.Fatalf("[batch write C] Expected %d logs, got %d", expectedLogs, total)
|
||||
}
|
||||
|
||||
if deleteQueries != 1 {
|
||||
t.Fatalf("Expected DeleteOldLogs to be called %d, got %d", 1, deleteQueries)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
|
||||
app, cleanup, err := initTestBaseApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
isDev bool
|
||||
|
|
@ -469,173 +338,35 @@ func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
|
|||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app.isDev = s.isDev
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
IsDev: s.isDev,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
|
||||
}
|
||||
|
||||
app.Settings().Logs.MinLevel = s.level
|
||||
|
||||
if err := app.Dao().SaveSettings(app.Settings()); err != nil {
|
||||
if err := app.Save(app.Settings()); err != nil {
|
||||
t.Fatalf("Failed to save settings: %v", err)
|
||||
}
|
||||
|
||||
if err := app.RefreshSettings(); err != nil {
|
||||
t.Fatalf("Failed to refresh app settings: %v", err)
|
||||
}
|
||||
|
||||
for level, enabled := range s.expectations {
|
||||
if v := handler.Enabled(nil, slog.Level(level)); v != enabled {
|
||||
if v := handler.Enabled(context.Background(), slog.Level(level)); v != enabled {
|
||||
t.Fatalf("Expected level %d Enabled() to be %v, got %v", level, enabled, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppLoggerLevelDevPrint(t *testing.T) {
|
||||
app, cleanup, err := initTestBaseApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
testLogLevel := 4
|
||||
|
||||
app.Settings().Logs.MinLevel = testLogLevel
|
||||
if err := app.Dao().SaveSettings(app.Settings()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
isDev bool
|
||||
levels []int
|
||||
printedLevels []int
|
||||
persistedLevels []int
|
||||
}{
|
||||
{
|
||||
"dev mode",
|
||||
true,
|
||||
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
|
||||
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
|
||||
[]int{testLogLevel, testLogLevel + 1},
|
||||
},
|
||||
{
|
||||
"nondev mode",
|
||||
false,
|
||||
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
|
||||
[]int{},
|
||||
[]int{testLogLevel, testLogLevel + 1},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
var printedLevels []int
|
||||
var persistedLevels []int
|
||||
|
||||
app.isDev = s.isDev
|
||||
|
||||
// trigger slog handler min level refresh
|
||||
if err := app.RefreshSettings(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// track printed logs
|
||||
originalPrintLog := printLog
|
||||
defer func() {
|
||||
printLog = originalPrintLog
|
||||
}()
|
||||
printLog = func(log *logger.Log) {
|
||||
printedLevels = append(printedLevels, int(log.Level))
|
||||
}
|
||||
|
||||
// track persisted logs
|
||||
app.LogsDao().AfterCreateFunc = func(eventDao *daos.Dao, m models.Model) error {
|
||||
l, ok := m.(*models.Log)
|
||||
if ok {
|
||||
persistedLevels = append(persistedLevels, l.Level)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// write and persist logs
|
||||
for _, l := range s.levels {
|
||||
app.Logger().Log(nil, slog.Level(l), "test")
|
||||
}
|
||||
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
|
||||
}
|
||||
if err := handler.WriteAll(nil); err != nil {
|
||||
t.Fatalf("Failed to write all logs: %v", err)
|
||||
}
|
||||
|
||||
// check persisted log levels
|
||||
if len(s.persistedLevels) != len(persistedLevels) {
|
||||
t.Fatalf("Expected persisted levels \n%v\ngot\n%v", s.persistedLevels, persistedLevels)
|
||||
}
|
||||
for _, l := range persistedLevels {
|
||||
if !list.ExistInSlice(l, s.persistedLevels) {
|
||||
t.Fatalf("Missing expected persisted level %v in %v", l, persistedLevels)
|
||||
}
|
||||
}
|
||||
|
||||
// check printed log levels
|
||||
if len(s.printedLevels) != len(printedLevels) {
|
||||
t.Fatalf("Expected printed levels \n%v\ngot\n%v", s.printedLevels, printedLevels)
|
||||
}
|
||||
for _, l := range printedLevels {
|
||||
if !list.ExistInSlice(l, s.printedLevels) {
|
||||
t.Fatalf("Missing expected printed level %v in %v", l, printedLevels)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// note: make sure to call `defer cleanup()` when the app is no longer needed.
|
||||
func initTestBaseApp() (app *BaseApp, cleanup func(), err error) {
|
||||
testDataDir, err := os.MkdirTemp("", "test_base_app")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
os.RemoveAll(testDataDir)
|
||||
}
|
||||
|
||||
app = NewBaseApp(BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
|
||||
initErr := func() error {
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
return fmt.Errorf("bootstrap error: %w", err)
|
||||
}
|
||||
|
||||
logsRunner, err := migrate.NewRunner(app.LogsDB(), logs.LogsMigrations)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logsRunner error: %w", err)
|
||||
}
|
||||
if _, err := logsRunner.Up(); err != nil {
|
||||
return fmt.Errorf("logsRunner migrations execution error: %w", err)
|
||||
}
|
||||
|
||||
dataRunner, err := migrate.NewRunner(app.DB(), migrations.AppMigrations)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logsRunner error: %w", err)
|
||||
}
|
||||
if _, err := dataRunner.Up(); err != nil {
|
||||
return fmt.Errorf("dataRunner migrations execution error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if initErr != nil {
|
||||
cleanup()
|
||||
return nil, nil, initErr
|
||||
}
|
||||
|
||||
return app, cleanup, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,194 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// ImportCollectionsByMarshaledJSON is the same as [ImportCollections]
|
||||
// but accept marshaled json array as import data (usually used for the autogenerated snapshots).
|
||||
func (app *BaseApp) ImportCollectionsByMarshaledJSON(rawSliceOfMaps []byte, deleteMissing bool) error {
|
||||
data := []map[string]any{}
|
||||
|
||||
err := json.Unmarshal(rawSliceOfMaps, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return app.ImportCollections(data, deleteMissing)
|
||||
}
|
||||
|
||||
// ImportCollections imports the provided collections data in a single transaction.
|
||||
//
|
||||
// For existing matching collections, the imported data is unmarshaled on top of the existing model.
|
||||
//
|
||||
// NB! If deleteMissing is true, ALL NON-SYSTEM COLLECTIONS AND SCHEMA FIELDS,
|
||||
// that are not present in the imported configuration, WILL BE DELETED
|
||||
// (this includes their related records data).
|
||||
func (app *BaseApp) ImportCollections(toImport []map[string]any, deleteMissing bool) error {
|
||||
if len(toImport) == 0 {
|
||||
// prevent accidentally deleting all collections
|
||||
return errors.New("no collections to import")
|
||||
}
|
||||
|
||||
importedCollections := make([]*Collection, len(toImport))
|
||||
mappedImported := make(map[string]*Collection, len(toImport))
|
||||
|
||||
// normalize imported collections data to ensure that all
|
||||
// collection fields are present and properly initialized
|
||||
for i, data := range toImport {
|
||||
var imported *Collection
|
||||
|
||||
identifier := cast.ToString(data["id"])
|
||||
if identifier == "" {
|
||||
identifier = cast.ToString(data["name"])
|
||||
}
|
||||
|
||||
existing, err := app.FindCollectionByNameOrId(identifier)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
// refetch for deep copy
|
||||
imported, err = app.FindCollectionByNameOrId(existing.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure that the fields will be cleared
|
||||
if data["fields"] == nil && deleteMissing {
|
||||
data["fields"] = []map[string]any{}
|
||||
}
|
||||
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the imported data
|
||||
err = json.Unmarshal(rawData, imported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// extend with the existing fields if necessary
|
||||
for _, f := range existing.Fields {
|
||||
if !f.GetSystem() && deleteMissing {
|
||||
continue
|
||||
}
|
||||
if imported.Fields.GetById(f.GetId()) == nil {
|
||||
imported.Fields.Add(f)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
imported = &Collection{}
|
||||
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the imported data
|
||||
err = json.Unmarshal(rawData, imported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
imported.IntegrityChecks(false)
|
||||
|
||||
importedCollections[i] = imported
|
||||
mappedImported[imported.Id] = imported
|
||||
}
|
||||
|
||||
// reorder views last since the view query could depend on some of the other collections
|
||||
slices.SortStableFunc(importedCollections, func(a, b *Collection) int {
|
||||
cmpA := -1
|
||||
if a.IsView() {
|
||||
cmpA = 1
|
||||
}
|
||||
|
||||
cmpB := -1
|
||||
if b.IsView() {
|
||||
cmpB = 1
|
||||
}
|
||||
|
||||
res := cmp.Compare(cmpA, cmpB)
|
||||
if res == 0 {
|
||||
res = a.Created.Compare(b.Created)
|
||||
if res == 0 {
|
||||
res = a.Updated.Compare(b.Updated)
|
||||
}
|
||||
}
|
||||
return res
|
||||
})
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
existingCollections := []*Collection{}
|
||||
if err := txApp.CollectionQuery().OrderBy("updated ASC").All(&existingCollections); err != nil {
|
||||
return err
|
||||
}
|
||||
mappedExisting := make(map[string]*Collection, len(existingCollections))
|
||||
for _, existing := range existingCollections {
|
||||
existing.IntegrityChecks(false)
|
||||
mappedExisting[existing.Id] = existing
|
||||
}
|
||||
|
||||
// delete old collections not available in the new configuration
|
||||
// (before saving the imports in case a deleted collection name is being reused)
|
||||
if deleteMissing {
|
||||
for _, existing := range existingCollections {
|
||||
if mappedImported[existing.Id] != nil || existing.System {
|
||||
continue // exist or system
|
||||
}
|
||||
|
||||
// delete collection
|
||||
if err := txApp.Delete(existing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// upsert imported collections
|
||||
for _, imported := range importedCollections {
|
||||
if err := txApp.SaveNoValidate(imported); err != nil {
|
||||
return fmt.Errorf("failed to save collection %q: %w", imported.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// run validations
|
||||
for _, imported := range importedCollections {
|
||||
original := mappedExisting[imported.Id]
|
||||
if original == nil {
|
||||
original = imported
|
||||
}
|
||||
|
||||
validator := newCollectionValidator(
|
||||
context.Background(),
|
||||
txApp,
|
||||
imported,
|
||||
original,
|
||||
)
|
||||
if err := validator.run(); err != nil {
|
||||
// serialize the validation error(s)
|
||||
serializedErr, _ := json.MarshalIndent(err, "", " ")
|
||||
|
||||
return validation.Errors{"collections": validation.NewError(
|
||||
"validation_collections_import_failure",
|
||||
fmt.Sprintf("Data validations failed for collection %q (%s):\n%s", imported.Name, imported.Id, serializedErr),
|
||||
)}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,476 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestImportCollections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
var regularCollections []*core.Collection
|
||||
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(®ularCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var systemCollections []*core.Collection
|
||||
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalRegularCollections := len(regularCollections)
|
||||
totalSystemCollections := len(systemCollections)
|
||||
totalCollections := totalRegularCollections + totalSystemCollections
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data []map[string]any
|
||||
deleteMissing bool
|
||||
expectError bool
|
||||
expectCollectionsCount int
|
||||
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
|
||||
}{
|
||||
{
|
||||
name: "empty collections",
|
||||
data: []map[string]any{},
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (with missing system fields)",
|
||||
data: []map[string]any{
|
||||
{"name": "import_test1", "type": "auth"},
|
||||
{
|
||||
"name": "import_test2", "fields": []map[string]any{
|
||||
{"name": "test", "type": "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalCollections + 2,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (trigger collection model validations)",
|
||||
data: []map[string]any{
|
||||
{"name": ""},
|
||||
{
|
||||
"name": "import_test2", "fields": []map[string]any{
|
||||
{"name": "test", "type": "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (trigger field settings validation)",
|
||||
data: []map[string]any{
|
||||
{"name": "import_test", "fields": []map[string]any{{"name": "test", "type": "text", "min": -1}}},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "new + update + delete (system collections delete should be ignored)",
|
||||
data: []map[string]any{
|
||||
{
|
||||
"id": "wsmn24bux7wo113",
|
||||
"name": "demo",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
},
|
||||
"indexes": []string{},
|
||||
},
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"name": "active",
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: true,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalSystemCollections + 2,
|
||||
},
|
||||
{
|
||||
name: "test with deleteMissing: false",
|
||||
data: []map[string]any{
|
||||
{
|
||||
// "id": "wsmn24bux7wo113", // test update with only name as identifier
|
||||
"name": "demo1",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "field_with_duplicate_id",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"unique": false,
|
||||
"min": 4,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
{
|
||||
"id": "abcd_import",
|
||||
"name": "new_field",
|
||||
"type": "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "new_import",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "abcd_import",
|
||||
"name": "active",
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalCollections + 1,
|
||||
afterTestFunc: func(testApp *tests.TestApp, resultCollections []*core.Collection) {
|
||||
expectedCollectionFields := map[string]int{
|
||||
core.CollectionNameAuthOrigins: 6,
|
||||
"nologin": 10,
|
||||
"demo1": 18,
|
||||
"demo2": 5,
|
||||
"demo3": 5,
|
||||
"demo4": 16,
|
||||
"demo5": 9,
|
||||
"new_import": 2,
|
||||
}
|
||||
for name, expectedCount := range expectedCollectionFields {
|
||||
collection, err := testApp.FindCollectionByNameOrId(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if totalFields := len(collection.Fields); totalFields != expectedCount {
|
||||
t.Errorf("Expected %d %q fields, got %d", expectedCount, collection.Name, totalFields)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollections(s.data, s.deleteMissing)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// check collections count
|
||||
collections := []*core.Collection{}
|
||||
if err := testApp.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(collections) != s.expectCollectionsCount {
|
||||
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
|
||||
}
|
||||
|
||||
if s.afterTestFunc != nil {
|
||||
s.afterTestFunc(testApp, collections)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsByMarshaledJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
var regularCollections []*core.Collection
|
||||
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(®ularCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var systemCollections []*core.Collection
|
||||
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalRegularCollections := len(regularCollections)
|
||||
totalSystemCollections := len(systemCollections)
|
||||
totalCollections := totalRegularCollections + totalSystemCollections
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data string
|
||||
deleteMissing bool
|
||||
expectError bool
|
||||
expectCollectionsCount int
|
||||
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
|
||||
}{
|
||||
{
|
||||
name: "invalid json array",
|
||||
data: `{"test":123}`,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "new + update + delete (system collections delete should be ignored)",
|
||||
data: `[
|
||||
{
|
||||
"id": "wsmn24bux7wo113",
|
||||
"name": "demo",
|
||||
"fields": [
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": null,
|
||||
"pattern": ""
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
},
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"name": "active",
|
||||
"type": "bool"
|
||||
}
|
||||
]
|
||||
}
|
||||
]`,
|
||||
deleteMissing: true,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalSystemCollections + 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollectionsByMarshaledJSON([]byte(s.data), s.deleteMissing)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// check collections count
|
||||
collections := []*core.Collection{}
|
||||
if err := testApp.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(collections) != s.expectCollectionsCount {
|
||||
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
|
||||
}
|
||||
|
||||
if s.afterTestFunc != nil {
|
||||
s.afterTestFunc(testApp, collections)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsUpdateRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
deleteMissing bool
|
||||
}{
|
||||
{
|
||||
"extend existing by name (without deleteMissing)",
|
||||
map[string]any{"name": "clients", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"extend existing by id (without deleteMissing)",
|
||||
map[string]any{"id": "v851q4r790rhknl", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"extend with delete missing",
|
||||
map[string]any{
|
||||
"id": "v851q4r790rhknl",
|
||||
"authToken": map[string]any{"duration": 100},
|
||||
"fields": []map[string]any{{"name": "test", "type": "text"}},
|
||||
"passwordAuth": map[string]any{"identityFields": []string{"email"}},
|
||||
"indexes": []string{
|
||||
// min required system fields indexes
|
||||
"CREATE UNIQUE INDEX `_v851q4r790rhknl_email_idx` ON `clients` (email) WHERE email != ''",
|
||||
"CREATE UNIQUE INDEX `_v851q4r790rhknl_tokenKey_idx` ON `clients` (tokenKey)",
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
beforeCollection, err := testApp.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = testApp.ImportCollections([]map[string]any{s.data}, s.deleteMissing)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
afterCollection, err := testApp.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if afterCollection.AuthToken.Duration != 100 {
|
||||
t.Fatalf("Expected AuthToken duration to be %d, got %d", 100, afterCollection.AuthToken.Duration)
|
||||
}
|
||||
if beforeCollection.AuthToken.Secret != afterCollection.AuthToken.Secret {
|
||||
t.Fatalf("Expected AuthToken secrets to remain the same, got\n%q\nVS\n%q", beforeCollection.AuthToken.Secret, afterCollection.AuthToken.Secret)
|
||||
}
|
||||
if beforeCollection.Name != afterCollection.Name {
|
||||
t.Fatalf("Expected Name to remain the same, got\n%q\nVS\n%q", beforeCollection.Name, afterCollection.Name)
|
||||
}
|
||||
if beforeCollection.Id != afterCollection.Id {
|
||||
t.Fatalf("Expected Id to remain the same, got\n%q\nVS\n%q", beforeCollection.Id, afterCollection.Id)
|
||||
}
|
||||
|
||||
if !s.deleteMissing {
|
||||
totalExpectedFields := len(beforeCollection.Fields) + 1
|
||||
if v := len(afterCollection.Fields); v != totalExpectedFields {
|
||||
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
|
||||
}
|
||||
|
||||
if afterCollection.Fields.GetByName("test") == nil {
|
||||
t.Fatalf("Missing new field %q", "test")
|
||||
}
|
||||
|
||||
// ensure that the old fields still exist
|
||||
oldFields := beforeCollection.Fields.FieldNames()
|
||||
for _, name := range oldFields {
|
||||
if afterCollection.Fields.GetByName(name) == nil {
|
||||
t.Fatalf("Missing expected old field %q", name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
totalExpectedFields := 1
|
||||
for _, f := range beforeCollection.Fields {
|
||||
if f.GetSystem() {
|
||||
totalExpectedFields++
|
||||
}
|
||||
}
|
||||
|
||||
if v := len(afterCollection.Fields); v != totalExpectedFields {
|
||||
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
|
||||
}
|
||||
|
||||
if afterCollection.Fields.GetByName("test") == nil {
|
||||
t.Fatalf("Missing new field %q", "test")
|
||||
}
|
||||
|
||||
// ensure that the old system fields still exist
|
||||
for _, f := range beforeCollection.Fields {
|
||||
if f.GetSystem() && afterCollection.Fields.GetByName(f.GetName()) == nil {
|
||||
t.Fatalf("Missing expected old field %q", f.GetName())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsCreateRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollections([]map[string]any{
|
||||
{"name": "new_test", "type": "auth", "authToken": map[string]any{"duration": 123}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
collection, err := testApp.FindCollectionByNameOrId("new_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
expectedParts := []string{
|
||||
`"name":"new_test"`,
|
||||
`"fields":[`,
|
||||
`"name":"id"`,
|
||||
`"name":"email"`,
|
||||
`"name":"tokenKey"`,
|
||||
`"name":"password"`,
|
||||
`"name":"test"`,
|
||||
`"indexes":[`,
|
||||
`CREATE UNIQUE INDEX`,
|
||||
`"duration":123`,
|
||||
}
|
||||
|
||||
for _, part := range expectedParts {
|
||||
if !strings.Contains(rawStr, part) {
|
||||
t.Errorf("Missing %q in\n%s", part, rawStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,949 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Model = (*Collection)(nil)
|
||||
_ DBExporter = (*Collection)(nil)
|
||||
_ FilesManager = (*Collection)(nil)
|
||||
)
|
||||
|
||||
const (
|
||||
CollectionTypeBase = "base"
|
||||
CollectionTypeAuth = "auth"
|
||||
CollectionTypeView = "view"
|
||||
)
|
||||
|
||||
const systemHookIdCollection = "__pbCollectionSystemHook__"
|
||||
|
||||
func (app *BaseApp) registerCollectionHooks() {
|
||||
app.OnModelValidate().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionValidate().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelCreate().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionCreate().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelCreateExecute().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionCreateExecute().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterCreateSuccess().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionAfterCreateSuccess().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterCreateError().Bind(&hook.Handler[*ModelErrorEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelErrorEvent) error {
|
||||
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
|
||||
return me.App.OnCollectionAfterCreateError().Trigger(ce, func(ce *CollectionErrorEvent) error {
|
||||
syncModelErrorEventWithCollectionErrorEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelUpdate().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionUpdate().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelUpdateExecute().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionUpdateExecute().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionAfterUpdateSuccess().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterUpdateError().Bind(&hook.Handler[*ModelErrorEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelErrorEvent) error {
|
||||
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
|
||||
return me.App.OnCollectionAfterUpdateError().Trigger(ce, func(ce *CollectionErrorEvent) error {
|
||||
syncModelErrorEventWithCollectionErrorEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelDelete().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionDelete().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelDeleteExecute().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionDeleteExecute().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*ModelEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelEvent) error {
|
||||
if ce, ok := newCollectionEventFromModelEvent(me); ok {
|
||||
return me.App.OnCollectionAfterDeleteSuccess().Trigger(ce, func(ce *CollectionEvent) error {
|
||||
syncModelEventWithCollectionEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterDeleteError().Bind(&hook.Handler[*ModelErrorEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(me *ModelErrorEvent) error {
|
||||
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
|
||||
return me.App.OnCollectionAfterDeleteError().Trigger(ce, func(ce *CollectionErrorEvent) error {
|
||||
syncModelErrorEventWithCollectionErrorEvent(me, ce)
|
||||
return me.Next()
|
||||
})
|
||||
}
|
||||
|
||||
return me.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
// --------------------------------------------------------------
|
||||
|
||||
app.OnCollectionValidate().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onCollectionValidate,
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
app.OnCollectionCreate().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onCollectionSave,
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnCollectionUpdate().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onCollectionSave,
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnCollectionCreateExecute().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onCollectionSaveExecute,
|
||||
// execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
app.OnCollectionUpdateExecute().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onCollectionSaveExecute,
|
||||
Priority: 99, // execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
|
||||
})
|
||||
|
||||
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onCollectionDeleteExecute,
|
||||
Priority: 99, // execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
|
||||
})
|
||||
|
||||
// reload cache on failure
|
||||
// ---
|
||||
onErrorReloadCachedCollections := func(ce *CollectionErrorEvent) error {
|
||||
if err := ce.App.ReloadCachedCollections(); err != nil {
|
||||
ce.App.Logger().Warn("Failed to reload collections cache", "error", err)
|
||||
}
|
||||
|
||||
return ce.Next()
|
||||
}
|
||||
app.OnCollectionAfterCreateError().Bind(&hook.Handler[*CollectionErrorEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onErrorReloadCachedCollections,
|
||||
Priority: -99,
|
||||
})
|
||||
app.OnCollectionAfterUpdateError().Bind(&hook.Handler[*CollectionErrorEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onErrorReloadCachedCollections,
|
||||
Priority: -99,
|
||||
})
|
||||
app.OnCollectionAfterDeleteError().Bind(&hook.Handler[*CollectionErrorEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: onErrorReloadCachedCollections,
|
||||
Priority: -99,
|
||||
})
|
||||
// ---
|
||||
|
||||
app.OnBootstrap().Bind(&hook.Handler[*BootstrapEvent]{
|
||||
Id: systemHookIdCollection,
|
||||
Func: func(e *BootstrapEvent) error {
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.App.ReloadCachedCollections(); err != nil {
|
||||
return fmt.Errorf("failed to load collections cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Priority: 99, // execute as latest as possible
|
||||
})
|
||||
}
|
||||
|
||||
// @todo experiment eventually replacing the rules *string with a struct?
|
||||
type baseCollection struct {
|
||||
BaseModel
|
||||
|
||||
disableIntegrityChecks bool
|
||||
|
||||
ListRule *string `db:"listRule" json:"listRule" form:"listRule"`
|
||||
ViewRule *string `db:"viewRule" json:"viewRule" form:"viewRule"`
|
||||
CreateRule *string `db:"createRule" json:"createRule" form:"createRule"`
|
||||
UpdateRule *string `db:"updateRule" json:"updateRule" form:"updateRule"`
|
||||
DeleteRule *string `db:"deleteRule" json:"deleteRule" form:"deleteRule"`
|
||||
|
||||
// RawOptions represents the raw serialized collection option loaded from the DB.
|
||||
// NB! This field shouldn't be modified manually. It is automatically updated
|
||||
// with the collection type specific option before save.
|
||||
RawOptions types.JSONRaw `db:"options" json:"-" xml:"-" form:"-"`
|
||||
|
||||
Name string `db:"name" json:"name" form:"name"`
|
||||
Type string `db:"type" json:"type" form:"type"`
|
||||
Fields FieldsList `db:"fields" json:"fields" form:"fields"`
|
||||
Indexes types.JSONArray[string] `db:"indexes" json:"indexes" form:"indexes"`
|
||||
System bool `db:"system" json:"system" form:"system"`
|
||||
Created types.DateTime `db:"created" json:"created"`
|
||||
Updated types.DateTime `db:"updated" json:"updated"`
|
||||
}
|
||||
|
||||
// Collection defines the table, fields and various options related to a set of records.
|
||||
type Collection struct {
|
||||
baseCollection
|
||||
collectionAuthOptions
|
||||
collectionViewOptions
|
||||
}
|
||||
|
||||
// NewCollection initializes and returns a new Collection model with the specified type and name.
|
||||
func NewCollection(typ, name string) *Collection {
|
||||
switch typ {
|
||||
case CollectionTypeAuth:
|
||||
return NewAuthCollection(name)
|
||||
case CollectionTypeView:
|
||||
return NewViewCollection(name)
|
||||
default:
|
||||
return NewBaseCollection(name)
|
||||
}
|
||||
}
|
||||
|
||||
// NewBaseCollection initializes and returns a new "base" Collection model.
|
||||
func NewBaseCollection(name string) *Collection {
|
||||
m := &Collection{}
|
||||
m.Name = name
|
||||
m.Type = CollectionTypeBase
|
||||
m.initDefaultId()
|
||||
m.initDefaultFields()
|
||||
return m
|
||||
}
|
||||
|
||||
// NewViewCollection initializes and returns a new "view" Collection model.
|
||||
func NewViewCollection(name string) *Collection {
|
||||
m := &Collection{}
|
||||
m.Name = name
|
||||
m.Type = CollectionTypeView
|
||||
m.initDefaultId()
|
||||
m.initDefaultFields()
|
||||
return m
|
||||
}
|
||||
|
||||
// NewAuthCollection initializes and returns a new "auth" Collection model.
|
||||
func NewAuthCollection(name string) *Collection {
|
||||
m := &Collection{}
|
||||
m.Name = name
|
||||
m.Type = CollectionTypeAuth
|
||||
m.initDefaultId()
|
||||
m.initDefaultFields()
|
||||
m.setDefaultAuthOptions()
|
||||
return m
|
||||
}
|
||||
|
||||
// TableName returns the Collection model SQL table name.
|
||||
func (m *Collection) TableName() string {
|
||||
return "_collections"
|
||||
}
|
||||
|
||||
// BaseFilesPath returns the storage dir path used by the collection.
|
||||
func (m *Collection) BaseFilesPath() string {
|
||||
return m.Id
|
||||
}
|
||||
|
||||
// IsBase checks if the current collection has "base" type.
|
||||
func (m *Collection) IsBase() bool {
|
||||
return m.Type == CollectionTypeBase
|
||||
}
|
||||
|
||||
// IsAuth checks if the current collection has "auth" type.
|
||||
func (m *Collection) IsAuth() bool {
|
||||
return m.Type == CollectionTypeAuth
|
||||
}
|
||||
|
||||
// IsView checks if the current collection has "view" type.
|
||||
func (m *Collection) IsView() bool {
|
||||
return m.Type == CollectionTypeView
|
||||
}
|
||||
|
||||
// IntegrityChecks toggles the current collection integrity checks (ex. checking references on delete).
|
||||
func (m *Collection) IntegrityChecks(enable bool) {
|
||||
m.disableIntegrityChecks = !enable
|
||||
}
|
||||
|
||||
// PostScan implements the [dbx.PostScanner] interface to auto unmarshal
|
||||
// the raw serialized options into the concrete type specific fields.
|
||||
func (m *Collection) PostScan() error {
|
||||
if err := m.BaseModel.PostScan(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.unmarshalRawOptions()
|
||||
}
|
||||
|
||||
func (m *Collection) unmarshalRawOptions() error {
|
||||
raw, err := m.RawOptions.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case CollectionTypeView:
|
||||
return json.Unmarshal(raw, &m.collectionViewOptions)
|
||||
case CollectionTypeAuth:
|
||||
return json.Unmarshal(raw, &m.collectionAuthOptions)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the [json.Unmarshaler] interface.
|
||||
//
|
||||
// For new/"blank" Collection models it replaces the model with a factory
|
||||
// instance and then unmarshal the provided data one on top of it.
|
||||
func (m *Collection) UnmarshalJSON(b []byte) error {
|
||||
type alias *Collection
|
||||
|
||||
// initialize the default fields
|
||||
// (e.g. in case the collection was NOT created using the designated factories)
|
||||
if m.IsNew() && m.Type == "" {
|
||||
minimal := &struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, minimal); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
blank := NewCollection(minimal.Type, minimal.Name)
|
||||
*m = *blank
|
||||
}
|
||||
|
||||
return json.Unmarshal(b, alias(m))
|
||||
}
|
||||
|
||||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
//
|
||||
// Note that non-type related fields are ignored from the serialization
|
||||
// (ex. for "view" colections the "auth" fields are skipped).
|
||||
func (m Collection) MarshalJSON() ([]byte, error) {
|
||||
switch m.Type {
|
||||
case CollectionTypeView:
|
||||
return json.Marshal(struct {
|
||||
baseCollection
|
||||
collectionViewOptions
|
||||
}{m.baseCollection, m.collectionViewOptions})
|
||||
case CollectionTypeAuth:
|
||||
alias := struct {
|
||||
baseCollection
|
||||
collectionAuthOptions
|
||||
}{m.baseCollection, m.collectionAuthOptions}
|
||||
|
||||
// ensure that it is always returned as array
|
||||
if alias.OAuth2.Providers == nil {
|
||||
alias.OAuth2.Providers = []OAuth2ProviderConfig{}
|
||||
}
|
||||
|
||||
// hide secret keys from the serialization
|
||||
alias.AuthToken.Secret = ""
|
||||
alias.FileToken.Secret = ""
|
||||
alias.PasswordResetToken.Secret = ""
|
||||
alias.EmailChangeToken.Secret = ""
|
||||
alias.VerificationToken.Secret = ""
|
||||
for i := range alias.OAuth2.Providers {
|
||||
alias.OAuth2.Providers[i].ClientSecret = ""
|
||||
}
|
||||
|
||||
return json.Marshal(alias)
|
||||
default:
|
||||
return json.Marshal(m.baseCollection)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the current collection.
|
||||
func (m Collection) String() string {
|
||||
raw, _ := json.Marshal(m)
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
// DBExport prepares and exports the current collection data for db persistence.
|
||||
func (m *Collection) DBExport(app App) (map[string]any, error) {
|
||||
result := map[string]any{
|
||||
"id": m.Id,
|
||||
"type": m.Type,
|
||||
"listRule": m.ListRule,
|
||||
"viewRule": m.ViewRule,
|
||||
"createRule": m.CreateRule,
|
||||
"updateRule": m.UpdateRule,
|
||||
"deleteRule": m.DeleteRule,
|
||||
"name": m.Name,
|
||||
"fields": m.Fields,
|
||||
"indexes": m.Indexes,
|
||||
"system": m.System,
|
||||
"created": m.Created,
|
||||
"updated": m.Updated,
|
||||
"options": `{}`,
|
||||
}
|
||||
|
||||
switch m.Type {
|
||||
case CollectionTypeView:
|
||||
if raw, err := types.ParseJSONRaw(m.collectionViewOptions); err == nil {
|
||||
result["options"] = raw
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
case CollectionTypeAuth:
|
||||
if raw, err := types.ParseJSONRaw(m.collectionAuthOptions); err == nil {
|
||||
result["options"] = raw
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetIndex returns s single Collection index expression by its name.
|
||||
func (m *Collection) GetIndex(name string) string {
|
||||
for _, idx := range m.Indexes {
|
||||
if strings.EqualFold(dbutils.ParseIndex(idx).IndexName, name) {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// AddIndex adds a new index into the current collection.
|
||||
//
|
||||
// If the collection has an existing index matching the new name it will be replaced with the new one.
|
||||
func (m *Collection) AddIndex(name string, unique bool, columnsExpr string, optWhereExpr string) {
|
||||
m.RemoveIndex(name)
|
||||
|
||||
var idx strings.Builder
|
||||
|
||||
idx.WriteString("CREATE ")
|
||||
if unique {
|
||||
idx.WriteString("UNIQUE ")
|
||||
}
|
||||
idx.WriteString("INDEX `")
|
||||
idx.WriteString(name)
|
||||
idx.WriteString("` ")
|
||||
idx.WriteString("ON `")
|
||||
idx.WriteString(m.Name)
|
||||
idx.WriteString("` (")
|
||||
idx.WriteString(columnsExpr)
|
||||
idx.WriteString(")")
|
||||
if optWhereExpr != "" {
|
||||
idx.WriteString(" WHERE ")
|
||||
idx.WriteString(optWhereExpr)
|
||||
}
|
||||
|
||||
m.Indexes = append(m.Indexes, idx.String())
|
||||
}
|
||||
|
||||
// RemoveIndex removes a single index with the specified name from the current collection.
|
||||
func (m *Collection) RemoveIndex(name string) {
|
||||
for i, idx := range m.Indexes {
|
||||
if strings.EqualFold(dbutils.ParseIndex(idx).IndexName, name) {
|
||||
m.Indexes = append(m.Indexes[:i], m.Indexes[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// delete hook
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func onCollectionDeleteExecute(e *CollectionEvent) error {
|
||||
if e.Collection.System {
|
||||
return fmt.Errorf("[%s] system collections cannot be deleted", e.Collection.Name)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := e.App.ReloadCachedCollections(); err != nil {
|
||||
e.App.Logger().Warn("Failed to reload collections cache", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !e.Collection.disableIntegrityChecks {
|
||||
// ensure that there aren't any existing references.
|
||||
// note: the select is outside of the transaction to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
|
||||
references, err := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[%s] failed to check collection references: %w", e.Collection.Name, err)
|
||||
}
|
||||
if total := len(references); total > 0 {
|
||||
names := make([]string, 0, len(references))
|
||||
for ref := range references {
|
||||
names = append(names, ref.Name)
|
||||
}
|
||||
return fmt.Errorf("[%s] failed to delete due to existing relation references: %s", e.Collection.Name, strings.Join(names, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
// delete the related view or records table
|
||||
if e.Collection.IsView() {
|
||||
if err := txApp.DeleteView(e.Collection.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := txApp.DeleteTable(e.Collection.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !e.Collection.disableIntegrityChecks {
|
||||
// trigger views resave to check for dependencies
|
||||
if err := resaveViewsWithChangedFields(txApp, e.Collection.Id); err != nil {
|
||||
return fmt.Errorf("[%s] failed to delete due to existing view dependency: %w", e.Collection.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// delete
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
e.App = originalApp
|
||||
|
||||
return txErr
|
||||
}
|
||||
|
||||
// save hook
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func (c *Collection) initDefaultId() {
|
||||
if c.Id == "" && c.Name != "" {
|
||||
c.Id = "_pbc_" + crc32Checksum(c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Collection) savePrepare() error {
|
||||
if c.Type == "" {
|
||||
c.Type = CollectionTypeBase
|
||||
}
|
||||
|
||||
if c.IsNew() {
|
||||
c.initDefaultId()
|
||||
c.Created = types.NowDateTime()
|
||||
}
|
||||
|
||||
c.Updated = types.NowDateTime()
|
||||
|
||||
// recreate the fields list to ensure that all normalizations
|
||||
// like default field id are applied
|
||||
c.Fields = NewFieldsList(c.Fields...)
|
||||
|
||||
c.initDefaultFields()
|
||||
|
||||
if c.IsAuth() {
|
||||
c.unsetMissingOAuth2MappedFields()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func onCollectionSave(e *CollectionEvent) error {
|
||||
if err := e.Collection.savePrepare(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
func onCollectionSaveExecute(e *CollectionEvent) error {
|
||||
defer func() {
|
||||
if err := e.App.ReloadCachedCollections(); err != nil {
|
||||
e.App.Logger().Warn("Failed to reload collections cache", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var oldCollection *Collection
|
||||
if !e.Collection.IsNew() {
|
||||
var err error
|
||||
oldCollection, err = e.App.FindCachedCollectionByNameOrId(e.Collection.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// invalidate previously issued auth tokens on auth rule change
|
||||
if oldCollection.AuthRule != e.Collection.AuthRule &&
|
||||
cast.ToString(oldCollection.AuthRule) != cast.ToString(e.Collection.AuthRule) {
|
||||
e.Collection.AuthToken.Secret = security.RandomString(50)
|
||||
}
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
isView := e.Collection.IsView()
|
||||
|
||||
// ensures that the view collection shema is properly loaded
|
||||
if isView {
|
||||
query := e.Collection.ViewQuery
|
||||
|
||||
// generate collection fields list from the query
|
||||
viewFields, err := e.App.CreateViewFields(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete old renamed view
|
||||
if oldCollection != nil {
|
||||
if err := e.App.DeleteView(oldCollection.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// wrap view query if necessary
|
||||
query, err = normalizeViewQueryId(e.App, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to normalize view query id: %w", err)
|
||||
}
|
||||
|
||||
// (re)create the view
|
||||
if err := e.App.SaveView(e.Collection.Name, query); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// updates newCollection.Fields based on the generated view table info and query
|
||||
e.Collection.Fields = viewFields
|
||||
}
|
||||
|
||||
// save the Collection model
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sync the changes with the related records table
|
||||
if !isView {
|
||||
if err := e.App.SyncRecordTableSchema(e.Collection, oldCollection); err != nil {
|
||||
// note: don't wrap to allow propagating indexes validation.Errors
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.App = originalApp
|
||||
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
|
||||
// trigger an update for all views with changed fields as a result of the current collection save
|
||||
// (ignoring view errors to allow users to update the query from the UI)
|
||||
resaveViewsWithChangedFields(e.App, e.Collection.Id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Collection) initDefaultFields() {
|
||||
switch m.Type {
|
||||
case CollectionTypeBase:
|
||||
m.initIdField()
|
||||
case CollectionTypeAuth:
|
||||
m.initIdField()
|
||||
m.initPasswordField()
|
||||
m.initTokenKeyField()
|
||||
m.initEmailField()
|
||||
m.initEmailVisibilityField()
|
||||
m.initVerifiedField()
|
||||
case CollectionTypeView:
|
||||
// view fields are autogenerated
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) initIdField() {
|
||||
field, _ := m.Fields.GetByName(FieldNameId).(*TextField)
|
||||
if field == nil {
|
||||
// create default field
|
||||
field = &TextField{
|
||||
Name: FieldNameId,
|
||||
System: true,
|
||||
PrimaryKey: true,
|
||||
Required: true,
|
||||
Min: 15,
|
||||
Max: 15,
|
||||
Pattern: `^[a-z0-9]+$`,
|
||||
AutogeneratePattern: `[a-z0-9]{15}`,
|
||||
}
|
||||
|
||||
// prepend it
|
||||
m.Fields = NewFieldsList(append([]Field{field}, m.Fields...)...)
|
||||
} else {
|
||||
// enforce system defaults
|
||||
field.System = true
|
||||
field.Required = true
|
||||
field.PrimaryKey = true
|
||||
field.Hidden = false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) initPasswordField() {
|
||||
field, _ := m.Fields.GetByName(FieldNamePassword).(*PasswordField)
|
||||
if field == nil {
|
||||
// load default field
|
||||
m.Fields.Add(&PasswordField{
|
||||
Name: FieldNamePassword,
|
||||
System: true,
|
||||
Hidden: true,
|
||||
Required: true,
|
||||
Min: 8,
|
||||
})
|
||||
} else {
|
||||
// enforce system defaults
|
||||
field.System = true
|
||||
field.Hidden = true
|
||||
field.Required = true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) initTokenKeyField() {
|
||||
field, _ := m.Fields.GetByName(FieldNameTokenKey).(*TextField)
|
||||
if field == nil {
|
||||
// load default field
|
||||
m.Fields.Add(&TextField{
|
||||
Name: FieldNameTokenKey,
|
||||
System: true,
|
||||
Hidden: true,
|
||||
Min: 30,
|
||||
Max: 60,
|
||||
Required: true,
|
||||
AutogeneratePattern: `[a-zA-Z0-9]{50}`,
|
||||
})
|
||||
} else {
|
||||
// enforce system defaults
|
||||
field.System = true
|
||||
field.Hidden = true
|
||||
field.Required = true
|
||||
}
|
||||
|
||||
// ensure that there is a unique index for the field
|
||||
if !dbutils.HasSingleColumnUniqueIndex(FieldNameTokenKey, m.Indexes) {
|
||||
m.Indexes = append(m.Indexes, fmt.Sprintf(
|
||||
"CREATE UNIQUE INDEX `%s` ON `%s` (`%s`)",
|
||||
m.fieldIndexName(FieldNameTokenKey),
|
||||
m.Name,
|
||||
FieldNameTokenKey,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) initEmailField() {
|
||||
field, _ := m.Fields.GetByName(FieldNameEmail).(*EmailField)
|
||||
if field == nil {
|
||||
// load default field
|
||||
m.Fields.Add(&EmailField{
|
||||
Name: FieldNameEmail,
|
||||
System: true,
|
||||
Required: true,
|
||||
})
|
||||
} else {
|
||||
// enforce system defaults
|
||||
field.System = true
|
||||
field.Hidden = false // managed by the emailVisibility flag
|
||||
}
|
||||
|
||||
// ensure that there is a unique index for the email field
|
||||
if !dbutils.HasSingleColumnUniqueIndex(FieldNameEmail, m.Indexes) {
|
||||
m.Indexes = append(m.Indexes, fmt.Sprintf(
|
||||
"CREATE UNIQUE INDEX `%s` ON `%s` (`%s`) WHERE `%s` != ''",
|
||||
m.fieldIndexName(FieldNameEmail),
|
||||
m.Name,
|
||||
FieldNameEmail,
|
||||
FieldNameEmail,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) initEmailVisibilityField() {
|
||||
field, _ := m.Fields.GetByName(FieldNameEmailVisibility).(*BoolField)
|
||||
if field == nil {
|
||||
// load default field
|
||||
m.Fields.Add(&BoolField{
|
||||
Name: FieldNameEmailVisibility,
|
||||
System: true,
|
||||
})
|
||||
} else {
|
||||
// enforce system defaults
|
||||
field.System = true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) initVerifiedField() {
|
||||
field, _ := m.Fields.GetByName(FieldNameVerified).(*BoolField)
|
||||
if field == nil {
|
||||
// load default field
|
||||
m.Fields.Add(&BoolField{
|
||||
Name: FieldNameVerified,
|
||||
System: true,
|
||||
})
|
||||
} else {
|
||||
// enforce system defaults
|
||||
field.System = true
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) fieldIndexName(field string) string {
|
||||
name := "idx_" + field + "_"
|
||||
|
||||
if m.Id != "" {
|
||||
name += m.Id
|
||||
} else if m.Name != "" {
|
||||
name += m.Name
|
||||
} else {
|
||||
name += security.PseudorandomString(10)
|
||||
}
|
||||
|
||||
if len(name) > 64 {
|
||||
return name[:64]
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
|
@ -0,0 +1,535 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func (m *Collection) unsetMissingOAuth2MappedFields() {
|
||||
if !m.IsAuth() {
|
||||
return
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Id != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Id) == nil {
|
||||
m.OAuth2.MappedFields.Id = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Name != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Name) == nil {
|
||||
m.OAuth2.MappedFields.Name = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Username != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Username) == nil {
|
||||
m.OAuth2.MappedFields.Username = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.AvatarURL != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.AvatarURL) == nil {
|
||||
m.OAuth2.MappedFields.AvatarURL = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) setDefaultAuthOptions() {
|
||||
m.collectionAuthOptions = collectionAuthOptions{
|
||||
VerificationTemplate: defaultVerificationTemplate,
|
||||
ResetPasswordTemplate: defaultResetPasswordTemplate,
|
||||
ConfirmEmailChangeTemplate: defaultConfirmEmailChangeTemplate,
|
||||
AuthRule: types.Pointer(""),
|
||||
AuthAlert: AuthAlertConfig{
|
||||
Enabled: true,
|
||||
EmailTemplate: defaultAuthAlertTemplate,
|
||||
},
|
||||
PasswordAuth: PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
IdentityFields: []string{FieldNameEmail},
|
||||
},
|
||||
MFA: MFAConfig{
|
||||
Enabled: false,
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
OTP: OTPConfig{
|
||||
Enabled: false,
|
||||
Duration: 180, // 3min
|
||||
Length: 8,
|
||||
EmailTemplate: defaultOTPTemplate,
|
||||
},
|
||||
AuthToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 604800, // 7 days
|
||||
},
|
||||
PasswordResetToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
EmailChangeToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
VerificationToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 259200, // 3days
|
||||
},
|
||||
FileToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 180, // 3min
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ optionsValidator = (*collectionAuthOptions)(nil)
|
||||
|
||||
// collectionAuthOptions defines the options for the "auth" type collection.
|
||||
type collectionAuthOptions struct {
|
||||
// AuthRule could be used to specify additional record constraints
|
||||
// applied after record authentication and right before returning the
|
||||
// auth token response to the client.
|
||||
//
|
||||
// For example, to allow only verified users you could set it to
|
||||
// "verified = true".
|
||||
//
|
||||
// Set it to empty string to allow any Auth collection record to authenticate.
|
||||
//
|
||||
// Set it to nil to disallow authentication altogether for the collection
|
||||
// (that includes password, OAuth2, etc.).
|
||||
AuthRule *string `form:"authRule" json:"authRule"`
|
||||
|
||||
// ManageRule gives admin-like permissions to allow fully managing
|
||||
// the auth record(s), eg. changing the password without requiring
|
||||
// to enter the old one, directly updating the verified state and email, etc.
|
||||
//
|
||||
// This rule is executed in addition to the Create and Update API rules.
|
||||
ManageRule *string `form:"manageRule" json:"manageRule"`
|
||||
|
||||
// AuthAlert defines options related to the auth alerts on new device login.
|
||||
AuthAlert AuthAlertConfig `form:"authAlert" json:"authAlert"`
|
||||
|
||||
// OAuth2 specifies whether OAuth2 auth is enabled for the collection
|
||||
// and which OAuth2 providers are allowed.
|
||||
OAuth2 OAuth2Config `form:"oauth2" json:"oauth2"`
|
||||
|
||||
PasswordAuth PasswordAuthConfig `form:"passwordAuth" json:"passwordAuth"`
|
||||
|
||||
MFA MFAConfig `form:"mfa" json:"mfa"`
|
||||
|
||||
OTP OTPConfig `form:"otp" json:"otp"`
|
||||
|
||||
// Various token configurations
|
||||
// ---
|
||||
AuthToken TokenConfig `form:"authToken" json:"authToken"`
|
||||
PasswordResetToken TokenConfig `form:"passwordResetToken" json:"passwordResetToken"`
|
||||
EmailChangeToken TokenConfig `form:"emailChangeToken" json:"emailChangeToken"`
|
||||
VerificationToken TokenConfig `form:"verificationToken" json:"verificationToken"`
|
||||
FileToken TokenConfig `form:"fileToken" json:"fileToken"`
|
||||
|
||||
// default email templates
|
||||
// ---
|
||||
VerificationTemplate EmailTemplate `form:"verificationTemplate" json:"verificationTemplate"`
|
||||
ResetPasswordTemplate EmailTemplate `form:"resetPasswordTemplate" json:"resetPasswordTemplate"`
|
||||
ConfirmEmailChangeTemplate EmailTemplate `form:"confirmEmailChangeTemplate" json:"confirmEmailChangeTemplate"`
|
||||
}
|
||||
|
||||
func (o *collectionAuthOptions) validate(cv *collectionValidator) error {
|
||||
err := validation.ValidateStruct(o,
|
||||
validation.Field(
|
||||
&o.AuthRule,
|
||||
validation.By(cv.checkRule),
|
||||
validation.By(cv.ensureNoSystemRuleChange(cv.original.AuthRule)),
|
||||
),
|
||||
validation.Field(
|
||||
&o.ManageRule,
|
||||
validation.NilOrNotEmpty,
|
||||
validation.By(cv.checkRule),
|
||||
validation.By(cv.ensureNoSystemRuleChange(cv.original.ManageRule)),
|
||||
),
|
||||
validation.Field(&o.AuthAlert),
|
||||
validation.Field(&o.PasswordAuth),
|
||||
validation.Field(&o.OAuth2),
|
||||
validation.Field(&o.OTP),
|
||||
validation.Field(&o.MFA),
|
||||
validation.Field(&o.AuthToken),
|
||||
validation.Field(&o.PasswordResetToken),
|
||||
validation.Field(&o.EmailChangeToken),
|
||||
validation.Field(&o.VerificationToken),
|
||||
validation.Field(&o.FileToken),
|
||||
validation.Field(&o.VerificationTemplate, validation.Required),
|
||||
validation.Field(&o.ResetPasswordTemplate, validation.Required),
|
||||
validation.Field(&o.ConfirmEmailChangeTemplate, validation.Required),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if o.MFA.Enabled {
|
||||
// if MFA is enabled require at least 2 auth methods
|
||||
//
|
||||
// @todo maybe consider disabling the check because if custom auth methods
|
||||
// are registered it may fail since we don't have mechanism to detect them at the moment
|
||||
authsEnabled := 0
|
||||
if o.PasswordAuth.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if o.OAuth2.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if o.OTP.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if authsEnabled < 2 {
|
||||
return validation.Errors{
|
||||
"mfa": validation.Errors{
|
||||
"enabled": validation.NewError("validation_mfa_not_enough_auths", "MFA requires at least 2 auth methods to be enabled."),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if o.MFA.Rule != "" {
|
||||
mfaRuleValidators := []validation.RuleFunc{
|
||||
cv.checkRule,
|
||||
cv.ensureNoSystemRuleChange(&cv.original.MFA.Rule),
|
||||
}
|
||||
|
||||
for _, validator := range mfaRuleValidators {
|
||||
err := validator(&o.MFA.Rule)
|
||||
if err != nil {
|
||||
return validation.Errors{
|
||||
"mfa": validation.Errors{
|
||||
"rule": err,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extra check to ensure that only unique identity fields are used
|
||||
if o.PasswordAuth.Enabled {
|
||||
err = validation.Validate(o.PasswordAuth.IdentityFields, validation.By(cv.checkFieldsForUniqueIndex))
|
||||
if err != nil {
|
||||
return validation.Errors{
|
||||
"passwordAuth": validation.Errors{
|
||||
"identityFields": err,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type EmailTemplate struct {
|
||||
Subject string `form:"subject" json:"subject"`
|
||||
Body string `form:"body" json:"body"`
|
||||
}
|
||||
|
||||
// Validate makes EmailTemplate validatable by implementing [validation.Validatable] interface.
|
||||
func (t EmailTemplate) Validate() error {
|
||||
return validation.ValidateStruct(&t,
|
||||
validation.Field(&t.Subject, validation.Required),
|
||||
validation.Field(&t.Body, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
// Resolve replaces the placeholder parameters in the current email
|
||||
// template and returns its components as ready-to-use strings.
|
||||
func (t EmailTemplate) Resolve(placeholders map[string]any) (subject, body string) {
|
||||
body = t.Body
|
||||
subject = t.Subject
|
||||
|
||||
for k, v := range placeholders {
|
||||
vStr := cast.ToString(v)
|
||||
|
||||
// replace subject placeholder params (if any)
|
||||
subject = strings.ReplaceAll(subject, k, vStr)
|
||||
|
||||
// replace body placeholder params (if any)
|
||||
body = strings.ReplaceAll(body, k, vStr)
|
||||
}
|
||||
|
||||
return subject, body
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type AuthAlertConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
|
||||
}
|
||||
|
||||
// Validate makes AuthAlertConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c AuthAlertConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
// note: for now always run the email template validations even
|
||||
// if not enabled since it could be used separately
|
||||
validation.Field(&c.EmailTemplate),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type TokenConfig struct {
|
||||
Secret string `form:"secret" json:"secret,omitempty"`
|
||||
|
||||
// Duration specifies how long an issued token to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
}
|
||||
|
||||
// Validate makes TokenConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c TokenConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Secret, validation.Required, validation.Length(30, 255)),
|
||||
validation.Field(&c.Duration, validation.Required, validation.Min(10), validation.Max(94670856)), // ~3y max
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c TokenConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type OTPConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// Duration specifies how long the OTP to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
|
||||
// Length specifies the auto generated password length.
|
||||
Length int `form:"length" json:"length"`
|
||||
|
||||
// EmailTemplate is the default OTP email template that will be send to the auth record.
|
||||
//
|
||||
// In addition to the system placeholders you can also make use of
|
||||
// [core.EmailPlaceholderOTPId] and [core.EmailPlaceholderOTP].
|
||||
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
|
||||
}
|
||||
|
||||
// Validate makes OTPConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c OTPConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
|
||||
validation.Field(&c.Length, validation.When(c.Enabled, validation.Required, validation.Min(4))),
|
||||
// note: for now always run the email template validations even
|
||||
// if not enabled since it could be used separately
|
||||
validation.Field(&c.EmailTemplate),
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c OTPConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type MFAConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// Duration specifies how long an issued MFA to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
|
||||
// Rule is an optional field to restrict MFA only for the records that satisfy the rule.
|
||||
//
|
||||
// Leave it empty to enable MFA for everyone.
|
||||
Rule string `form:"rule" json:"rule"`
|
||||
}
|
||||
|
||||
// Validate makes MFAConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c MFAConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c MFAConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// IdentityFields is a list of field names that could be used as
|
||||
// identity during password authentication.
|
||||
//
|
||||
// Usually only fields that has single column UNIQUE index are accepted as values.
|
||||
IdentityFields []string `form:"identityFields" json:"identityFields"`
|
||||
}
|
||||
|
||||
// Validate makes PasswordAuthConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c PasswordAuthConfig) Validate() error {
|
||||
// strip duplicated values
|
||||
c.IdentityFields = list.ToUniqueStringSlice(c.IdentityFields)
|
||||
|
||||
if !c.Enabled {
|
||||
return nil // no need to validate
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.IdentityFields, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type OAuth2KnownFields struct {
|
||||
Id string `form:"id" json:"id"`
|
||||
Name string `form:"name" json:"name"`
|
||||
Username string `form:"username" json:"username"`
|
||||
AvatarURL string `form:"avatarURL" json:"avatarURL"`
|
||||
}
|
||||
|
||||
type OAuth2Config struct {
|
||||
Providers []OAuth2ProviderConfig `form:"providers" json:"providers"`
|
||||
|
||||
MappedFields OAuth2KnownFields `form:"mappedFields" json:"mappedFields"`
|
||||
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// GetProviderConfig returns the first OAuth2ProviderConfig that matches the specified name.
|
||||
//
|
||||
// Returns false and zero config if no such provider is available in c.Providers.
|
||||
func (c OAuth2Config) GetProviderConfig(name string) (config OAuth2ProviderConfig, exists bool) {
|
||||
for _, p := range c.Providers {
|
||||
if p.Name == name {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Validate makes OAuth2Config validatable by implementing [validation.Validatable] interface.
|
||||
func (c OAuth2Config) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil // no need to validate
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(&c,
|
||||
// note: don't require providers for now as they could be externally registered/removed
|
||||
validation.Field(&c.Providers, validation.By(checkForDuplicatedProviders)),
|
||||
)
|
||||
}
|
||||
|
||||
func checkForDuplicatedProviders(value any) error {
|
||||
configs, _ := value.([]OAuth2ProviderConfig)
|
||||
|
||||
existing := map[string]struct{}{}
|
||||
|
||||
for i, c := range configs {
|
||||
if c.Name == "" {
|
||||
continue // the name nonempty state is validated separately
|
||||
}
|
||||
if _, ok := existing[c.Name]; ok {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.Errors{
|
||||
"name": validation.NewError("validation_duplicated_provider", "The provider "+c.Name+" is already registered.").
|
||||
SetParams(map[string]any{"name": c.Name}),
|
||||
},
|
||||
}
|
||||
}
|
||||
existing[c.Name] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type OAuth2ProviderConfig struct {
|
||||
// PKCE overwrites the default provider PKCE config option.
|
||||
//
|
||||
// This usually shouldn't be needed but some OAuth2 vendors, like the LinkedIn OIDC,
|
||||
// may require manual adjustment due to returning error if extra parameters are added to the request
|
||||
// (https://github.com/pocketbase/pocketbase/discussions/3799#discussioncomment-7640312)
|
||||
PKCE *bool `form:"pkce" json:"pkce"`
|
||||
|
||||
Name string `form:"name" json:"name"`
|
||||
ClientId string `form:"clientId" json:"clientId"`
|
||||
ClientSecret string `form:"clientSecret" json:"clientSecret,omitempty"`
|
||||
AuthURL string `form:"authURL" json:"authURL"`
|
||||
TokenURL string `form:"tokenURL" json:"tokenURL"`
|
||||
UserInfoURL string `form:"userInfoURL" json:"userInfoURL"`
|
||||
DisplayName string `form:"displayName" json:"displayName"`
|
||||
}
|
||||
|
||||
// Validate makes OAuth2ProviderConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c OAuth2ProviderConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Name, validation.Required, validation.By(checkProviderName)),
|
||||
validation.Field(&c.ClientId, validation.Required),
|
||||
validation.Field(&c.ClientSecret, validation.Required),
|
||||
validation.Field(&c.AuthURL, is.URL),
|
||||
validation.Field(&c.TokenURL, is.URL),
|
||||
validation.Field(&c.UserInfoURL, is.URL),
|
||||
)
|
||||
}
|
||||
|
||||
func checkProviderName(value any) error {
|
||||
name, _ := value.(string)
|
||||
if name == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if _, err := auth.NewProviderByName(name); err != nil {
|
||||
return validation.NewError("validation_missing_provider", "Invalid or missing provider with name "+name+".").
|
||||
SetParams(map[string]any{"name": name})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitProvider returns a new auth.Provider instance loaded with the current OAuth2ProviderConfig options.
|
||||
func (c OAuth2ProviderConfig) InitProvider() (auth.Provider, error) {
|
||||
provider, err := auth.NewProviderByName(c.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.ClientId != "" {
|
||||
provider.SetClientId(c.ClientId)
|
||||
}
|
||||
|
||||
if c.ClientSecret != "" {
|
||||
provider.SetClientSecret(c.ClientSecret)
|
||||
}
|
||||
|
||||
if c.AuthURL != "" {
|
||||
provider.SetAuthURL(c.AuthURL)
|
||||
}
|
||||
|
||||
if c.UserInfoURL != "" {
|
||||
provider.SetUserInfoURL(c.UserInfoURL)
|
||||
}
|
||||
|
||||
if c.TokenURL != "" {
|
||||
provider.SetTokenURL(c.TokenURL)
|
||||
}
|
||||
|
||||
if c.DisplayName != "" {
|
||||
provider.SetDisplayName(c.DisplayName)
|
||||
}
|
||||
|
||||
if c.PKCE != nil {
|
||||
provider.SetPKCE(*c.PKCE)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue