delay default response body write for *Request hooks wrapped in a transaction

This commit is contained in:
Gani Georgiev 2025-04-27 16:25:51 +03:00
parent 1a3efe96ac
commit dc350f0a3e
38 changed files with 759 additions and 149 deletions

View File

@ -1,3 +1,8 @@
## v0.28.0 (WIP)
- Write the default response body of `*Request` hooks that are wrapped in a transaction after the related transaction completes to allow propagating errors ([#6462](https://github.com/pocketbase/pocketbase/discussions/6462#discussioncomment-12207818)).
## v0.27.1
- Updated example `geoPoint` API preview body data.

View File

@ -49,7 +49,7 @@ var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
params["id"] = id // required for the path value
ir.Method = "PATCH"
ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"]
return recordUpdate(next)
return recordUpdate(false, next)
}
}
@ -57,16 +57,16 @@ var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
// ---
ir.Method = "POST"
ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"]
return recordCreate(next)
return recordCreate(false, next)
},
regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
return recordCreate(next)
return recordCreate(false, next)
},
regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
return recordUpdate(next)
return recordUpdate(false, next)
},
regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
return recordDelete(next)
return recordDelete(false, next)
},
}

View File

@ -45,7 +45,9 @@ func collectionsList(e *core.RequestEvent) error {
event.Result = result
return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error {
return e.JSON(http.StatusOK, e.Result)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Result)
})
})
}
@ -60,7 +62,9 @@ func collectionView(e *core.RequestEvent) error {
event.Collection = collection
return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
return e.JSON(http.StatusOK, e.Collection)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Collection)
})
})
}
@ -98,7 +102,9 @@ func collectionCreate(e *core.RequestEvent) error {
return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil)
}
return e.JSON(http.StatusOK, e.Collection)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Collection)
})
})
}
@ -128,7 +134,9 @@ func collectionUpdate(e *core.RequestEvent) error {
return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil)
}
return e.JSON(http.StatusOK, e.Collection)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Collection)
})
})
}
@ -159,7 +167,9 @@ func collectionDelete(e *core.RequestEvent) error {
return e.BadRequestError(msg, err)
}
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -29,7 +29,9 @@ func collectionsImport(e *core.RequestEvent) error {
return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error {
importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing)
if importErr == nil {
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
}
// validation failure

View File

@ -316,6 +316,51 @@ func TestCollectionsImport(t *testing.T) {
}
},
},
{
Name: "OnCollectionsImportRequest tx body write check",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"deleteMissing": true,
"collections":[
{"name": "test123"},
{
"id":"wsmn24bux7wo113",
"name":"demo1",
"fields":[
{
"id":"_2hlxbmp",
"name":"title",
"type":"text",
"required":true
}
],
"indexes": []
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionsImportRequest().BindFunc(func(e *core.CollectionsImportRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnCollectionsImportRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {

View File

@ -1,7 +1,6 @@
package apis_test
import (
"errors"
"net/http"
"os"
"path/filepath"
@ -130,6 +129,32 @@ func TestCollectionsList(t *testing.T) {
"OnCollectionsListRequest": 1,
},
},
{
Name: "OnCollectionsListRequest tx body write check",
Method: http.MethodGet,
URL: "/api/collections",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionsListRequest().BindFunc(func(e *core.CollectionsListRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnCollectionsListRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {
@ -205,6 +230,32 @@ func TestCollectionView(t *testing.T) {
"OnCollectionViewRequest": 1,
},
},
{
Name: "OnCollectionViewRequest tx body write check",
Method: http.MethodGet,
URL: "/api/collections/wsmn24bux7wo113",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionViewRequest().BindFunc(func(e *core.CollectionRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnCollectionViewRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {
@ -361,7 +412,7 @@ func TestCollectionDelete(t *testing.T) {
},
},
{
Name: "OnCollectionAfterDeleteSuccessRequest error response",
Name: "OnCollectionDeleteRequest tx body write check",
Method: http.MethodDelete,
URL: "/api/collections/view2",
Headers: map[string]string{
@ -369,15 +420,22 @@ func TestCollectionDelete(t *testing.T) {
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionDeleteRequest().BindFunc(func(e *core.CollectionRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionDeleteRequest": 1,
},
ExpectedEvents: map[string]int{"OnCollectionDeleteRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
@ -656,7 +714,7 @@ func TestCollectionCreate(t *testing.T) {
},
},
{
Name: "OnCollectionCreateRequest error response",
Name: "OnCollectionCreateRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections",
Body: strings.NewReader(`{"name":"new","type":"base"}`),
@ -665,15 +723,22 @@ func TestCollectionCreate(t *testing.T) {
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionCreateRequest().BindFunc(func(e *core.CollectionRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionCreateRequest": 1,
},
ExpectedEvents: map[string]int{"OnCollectionCreateRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// view
@ -978,7 +1043,7 @@ func TestCollectionUpdate(t *testing.T) {
},
},
{
Name: "OnCollectionAfterUpdateSuccessRequest error response",
Name: "OnCollectionUpdateRequest tx body write check",
Method: http.MethodPatch,
URL: "/api/collections/demo1",
Body: strings.NewReader(`{}`),
@ -987,15 +1052,22 @@ func TestCollectionUpdate(t *testing.T) {
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionUpdateRequest().BindFunc(func(e *core.CollectionRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionUpdateRequest": 1,
},
ExpectedEvents: map[string]int{"OnCollectionUpdateRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
{
Name: "authorized as superuser + invalid data (eg. existing name)",

View File

@ -75,8 +75,8 @@ func (api *fileApi) fileToken(e *core.RequestEvent) error {
event.Record = e.Auth
return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error {
return e.JSON(http.StatusOK, map[string]string{
"token": e.Token,
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, map[string]string{"token": e.Token})
})
})
}
@ -192,7 +192,10 @@ func (api *fileApi) download(e *core.RequestEvent) error {
e.Response.Header().Del("X-Frame-Options")
return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error {
if err := fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName); err != nil {
err = execAfterSuccessTx(true, e.App, func() error {
return fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName)
})
if err != nil {
return e.NotFoundError("", err)
}

View File

@ -365,11 +365,25 @@ func logRequest(event *core.RequestEvent, err error) {
// parse the request error
if err != nil {
if apiErr, ok := err.(*router.ApiError); ok {
status = apiErr.Status
apiErr, isPlainApiError := err.(*router.ApiError)
if isPlainApiError || errors.As(err, &apiErr) {
// the status header wasn't written yet
if status == 0 {
status = apiErr.Status
}
var errMsg string
if isPlainApiError {
errMsg = apiErr.Message
} else {
// wrapped ApiError -> add the full serialized version
// of the original error since it could contain more information
errMsg = err.Error()
}
attrs = append(
attrs,
slog.String("error", apiErr.Message),
slog.String("error", errMsg),
slog.Any("details", apiErr.RawData()),
)
} else {

View File

@ -213,7 +213,9 @@ func realtimeSetSubscriptions(e *core.RequestEvent) error {
slog.Any("subscriptions", e.Subscriptions),
)
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -45,7 +45,9 @@ func recordConfirmEmailChange(e *core.RequestEvent) error {
return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err))
}
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -1,7 +1,6 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
@ -136,7 +135,7 @@ func TestRecordConfirmEmailChange(t *testing.T) {
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordAfterConfirmEmailChangeRequest error response",
Name: "OnRecordConfirmEmailChangeRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
@ -145,15 +144,22 @@ func TestRecordConfirmEmailChange(t *testing.T) {
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmEmailChangeRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordConfirmEmailChangeRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks

View File

@ -43,7 +43,9 @@ func recordRequestEmailChange(e *core.RequestEvent) error {
return firstApiError(err, e.BadRequestError("Failed to request email change.", err))
}
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -118,6 +118,33 @@ func TestRecordRequestEmailChange(t *testing.T) {
}
},
},
{
Name: "OnRecordRequestEmailChangeRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestEmailChangeRequest().BindFunc(func(e *core.RecordRequestEmailChangeRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestEmailChangeRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------

View File

@ -26,10 +26,10 @@ func recordAuthImpersonate(e *core.RequestEvent) error {
form := &impersonateForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second)

View File

@ -108,8 +108,8 @@ func recordRequestOTP(e *core.RequestEvent) error {
})
}
return e.JSON(http.StatusOK, map[string]string{
"otpId": otp.Id,
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, map[string]string{"otpId": otp.Id})
})
})
}

View File

@ -247,6 +247,31 @@ func TestRecordRequestOTP(t *testing.T) {
}
},
},
{
Name: "OnRecordRequestOTPRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestOTPRequest().BindFunc(func(e *core.RecordCreateOTPRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestOTPRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------

View File

@ -54,7 +54,9 @@ func recordConfirmPasswordReset(e *core.RequestEvent) error {
e.App.Store().Remove(getPasswordResetResendKey(authRecord))
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -1,7 +1,6 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
@ -282,7 +281,7 @@ func TestRecordConfirmPasswordReset(t *testing.T) {
},
},
{
Name: "OnRecordAfterConfirmPasswordResetRequest error response",
Name: "OnRecordConfirmPasswordResetRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
@ -292,15 +291,22 @@ func TestRecordConfirmPasswordReset(t *testing.T) {
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordConfirmPasswordResetRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks

View File

@ -65,7 +65,9 @@ func recordRequestPasswordReset(e *core.RequestEvent) error {
})
})
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -101,6 +101,30 @@ func TestRecordRequestPasswordReset(t *testing.T) {
app.Store().Set(resendKey, struct{}{})
},
},
{
Name: "OnRecordRequestPasswordResetRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestPasswordResetRequest().BindFunc(func(e *core.RecordRequestPasswordResetRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestPasswordResetRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------

View File

@ -1,7 +1,6 @@
package apis_test
import (
"errors"
"net/http"
"testing"
@ -130,23 +129,30 @@ func TestRecordAuthRefresh(t *testing.T) {
},
},
{
Name: "OnRecordAfterAuthRefreshRequest error response",
Name: "OnRecordAuthRefreshRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordAuthRefreshRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks

View File

@ -42,19 +42,19 @@ func recordConfirmVerification(e *core.RequestEvent) error {
event.Record = record
return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error {
if wasVerified {
return e.NoContent(http.StatusNoContent)
}
if !wasVerified {
e.Record.SetVerified(true)
e.Record.SetVerified(true)
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
}
}
e.App.Store().Remove(getVerificationResendKey(e.Record))
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -1,7 +1,6 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
@ -144,7 +143,7 @@ func TestRecordConfirmVerification(t *testing.T) {
},
},
{
Name: "OnRecordAfterConfirmVerificationRequest error response",
Name: "OnRecordConfirmVerificationRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
@ -152,15 +151,22 @@ func TestRecordConfirmVerification(t *testing.T) {
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordConfirmVerificationRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks

View File

@ -68,7 +68,9 @@ func recordRequestVerification(e *core.RequestEvent) error {
})
})
return e.NoContent(http.StatusNoContent)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}

View File

@ -118,6 +118,30 @@ func TestRecordRequestVerification(t *testing.T) {
}
},
},
{
Name: "OnRecordRequestVerificationRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestVerificationRequest().BindFunc(func(e *core.RecordRequestVerificationRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestVerificationRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------

View File

@ -1577,6 +1577,69 @@ func TestRecordAuthWithOAuth2(t *testing.T) {
"OnRecordValidate": 4,
},
},
{
Name: "OnRecordAuthWithOAuth2Request tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-oauth2",
Body: strings.NewReader(`{
"provider": "test",
"code":"123",
"redirectURL": "https://example.com"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// register the test provider
auth.Providers["test"] = func() auth.Provider {
return &oauth2MockProvider{
AuthUser: &auth.AuthUser{Id: "test_id"},
Token: &oauth2.Token{AccessToken: "abc"},
}
}
// add the test provider in the collection
user.Collection().MFA.Enabled = false
user.Collection().OAuth2.Enabled = true
user.Collection().OAuth2.Providers = []core.OAuth2ProviderConfig{{
Name: "test",
ClientId: "123",
ClientSecret: "456",
}}
if err := app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
// stub linked provider
ea := core.NewExternalAuth(app)
ea.SetCollectionRef(user.Collection().Id)
ea.SetRecordRef(user.Id)
ea.SetProvider("test")
ea.SetProviderId("test_id")
if err := app.Save(ea); err != nil {
t.Fatal(err)
}
app.OnRecordAuthWithOAuth2Request().BindFunc(func(e *core.RecordAuthWithOAuth2RequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordAuthWithOAuth2Request": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------

View File

@ -419,6 +419,53 @@ func TestRecordAuthWithOTP(t *testing.T) {
}
},
},
{
Name: "OnRecordAuthWithOTPRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// disable MFA
user.Collection().MFA.Enabled = false
if err = app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
app.OnRecordAuthWithOTPRequest().BindFunc(func(e *core.RecordAuthWithOTPRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordAuthWithOTPRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------

View File

@ -1,7 +1,6 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
@ -82,7 +81,7 @@ func TestRecordAuthWithPassword(t *testing.T) {
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordAuthWithPasswordRequest error response",
Name: "OnRecordAuthWithPasswordRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
@ -91,15 +90,22 @@ func TestRecordAuthWithPassword(t *testing.T) {
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordAuthWithPasswordRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
{
Name: "valid identity field and invalid password",

View File

@ -28,9 +28,9 @@ func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent])
subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId)
subGroup.GET("", recordsList)
subGroup.GET("/{id}", recordView)
subGroup.POST("", recordCreate(nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.PATCH("/{id}", recordUpdate(nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.DELETE("/{id}", recordDelete(nil))
subGroup.POST("", recordCreate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.PATCH("/{id}", recordUpdate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.DELETE("/{id}", recordDelete(true, nil))
}
func recordsList(e *core.RequestEvent) error {
@ -121,7 +121,9 @@ func recordsList(e *core.RequestEvent) error {
randomizedThrottle(150)
}
return e.JSON(http.StatusOK, e.Result)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Result)
})
})
}
@ -192,11 +194,13 @@ func recordView(e *core.RequestEvent) error {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
return e.JSON(http.StatusOK, e.Record)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Record)
})
})
}
func recordCreate(optFinalizer func(data any) error) func(e *core.RequestEvent) error {
func recordCreate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
@ -344,7 +348,9 @@ func recordCreate(optFinalizer func(data any) error) func(e *core.RequestEvent)
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
err = e.JSON(http.StatusOK, e.Record)
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
return e.JSON(http.StatusOK, e.Record)
})
if err != nil {
return err
}
@ -374,7 +380,7 @@ func recordCreate(optFinalizer func(data any) error) func(e *core.RequestEvent)
}
}
func recordUpdate(optFinalizer func(data any) error) func(e *core.RequestEvent) error {
func recordUpdate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
@ -475,7 +481,9 @@ func recordUpdate(optFinalizer func(data any) error) func(e *core.RequestEvent)
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
err = e.JSON(http.StatusOK, e.Record)
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
return e.JSON(http.StatusOK, e.Record)
})
if err != nil {
return err
}
@ -505,7 +513,7 @@ func recordUpdate(optFinalizer func(data any) error) func(e *core.RequestEvent)
}
}
func recordDelete(optFinalizer func(data any) error) func(e *core.RequestEvent) error {
func recordDelete(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
@ -565,7 +573,9 @@ func recordDelete(optFinalizer func(data any) error) func(e *core.RequestEvent)
return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err))
}
err = e.NoContent(http.StatusNoContent)
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
if err != nil {
return err
}

View File

@ -2,7 +2,6 @@ package apis_test
import (
"bytes"
"errors"
"net/http"
"net/url"
"os"
@ -418,6 +417,32 @@ func TestRecordCrudList(t *testing.T) {
"OnRecordsListRequest": 1,
},
},
{
Name: "OnRecordsListRequest tx body write check",
Method: http.MethodGet,
URL: "/api/collections/demo4/records",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordsListRequest().BindFunc(func(e *core.RecordsListRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordsListRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// auth collection
// -----------------------------------------------------------
@ -862,6 +887,32 @@ func TestRecordCrudView(t *testing.T) {
"OnRecordEnrich": 7,
},
},
{
Name: "OnRecordViewRequest tx body write check",
Method: http.MethodGet,
URL: "/api/collections/demo1/records/al1h9ijdeojtsjy",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordViewRequest().BindFunc(func(e *core.RecordRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordViewRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// auth collection
// -----------------------------------------------------------
@ -1209,7 +1260,7 @@ func TestRecordCrudDelete(t *testing.T) {
},
},
{
Name: "OnRecordAfterDeleteSuccessRequest error response",
Name: "OnRecordDeleteRequest tx body write check",
Method: http.MethodDelete,
URL: "/api/collections/clients/records/o1y0dd0spd786md",
Headers: map[string]string{
@ -1217,15 +1268,22 @@ func TestRecordCrudDelete(t *testing.T) {
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordDeleteRequest().BindFunc(func(e *core.RecordRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordDeleteRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
{
Name: "authenticated record that match the collection delete rule",
@ -1792,21 +1850,31 @@ func TestRecordCrudCreate(t *testing.T) {
},
},
{
Name: "OnRecordAfterCreateSuccessRequest error response",
Name: "OnRecordCreateRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/demo2/records",
Body: strings.NewReader(`{"title":"new"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordCreateRequest().BindFunc(func(e *core.RecordRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordCreateRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// ID checks
@ -2799,21 +2867,31 @@ func TestRecordCrudUpdate(t *testing.T) {
},
},
{
Name: "OnRecordAfterUpdateSuccessRequest error response",
Name: "OnRecordUpdateRequest tx body write check",
Method: http.MethodPatch,
URL: "/api/collections/demo2/records/0yxhwia2amd8gec",
Body: strings.NewReader(`{"title":"new"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordUpdateRequest().BindFunc(func(e *core.RecordRequestEvent) error {
return errors.New("error")
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
},
ExpectedEvents: map[string]int{"OnRecordUpdateRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
{
Name: "try to change the id of an existing record",

View File

@ -129,7 +129,9 @@ func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token str
result.Meta = e.Meta
}
return e.JSON(http.StatusOK, result)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, result)
})
})
}
@ -535,6 +537,27 @@ func firstApiError(errs ...error) *router.ApiError {
return router.NewInternalServerError("", errors.Join(errs...))
}
// execAfterSuccessTx ensures that fn is executed only after a succesul transaction.
//
// If the current app instance is not a transactional or checkTx is false,
// then fn is directly executed.
//
// It could be usually used to allow propagating an error or writing
// custom response from within the wrapped transaction block.
func execAfterSuccessTx(checkTx bool, app core.App, fn func() error) error {
if txInfo := app.TxInfo(); txInfo != nil && checkTx {
txInfo.OnComplete(func(txErr error) error {
if txErr == nil {
return fn()
}
return nil
})
return nil
}
return fn()
}
// -------------------------------------------------------------------
const maxAuthOrigins = 5

View File

@ -30,7 +30,9 @@ func settingsList(e *core.RequestEvent) error {
event.Settings = clone
return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error {
return e.JSON(http.StatusOK, e.Settings)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Settings)
})
})
}
@ -65,7 +67,9 @@ func settingsSet(e *core.RequestEvent) error {
return e.InternalServerError("Failed to clone app settings.", err)
}
return e.JSON(http.StatusOK, appSettings)
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, appSettings)
})
})
}

View File

@ -11,6 +11,7 @@ import (
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
@ -58,6 +59,32 @@ func TestSettingsList(t *testing.T) {
"OnSettingsListRequest": 1,
},
},
{
Name: "OnSettingsListRequest tx body write check",
Method: http.MethodGet,
URL: "/api/settings",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnSettingsListRequest().BindFunc(func(e *core.SettingsListRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnSettingsListRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {
@ -176,6 +203,33 @@ func TestSettingsSet(t *testing.T) {
"OnSettingsReload": 1,
},
},
{
Name: "OnSettingsUpdateRequest tx body write check",
Method: http.MethodPatch,
URL: "/api/settings",
Body: strings.NewReader(validData),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnSettingsUpdateRequest().BindFunc(func(e *core.SettingsUpdateRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnSettingsUpdateRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {

View File

@ -45,6 +45,12 @@ type App interface {
// IsTransactional checks if the current app instance is part of a transaction.
IsTransactional() bool
// TxInfo returns the transaction associated with the current app instance (if any).
//
// Could be used if you want to execute indirectly a function after
// the related app transaction completes using `app.TxInfo().OnAfterFunc(callback)`.
TxInfo() *TxAppInfo
// Bootstrap initializes the application
// (aka. create data dir, open db connections, load settings, etc.).
//

View File

@ -69,7 +69,7 @@ var _ App = (*BaseApp)(nil)
// BaseApp implements core.App and defines the base PocketBase app structure.
type BaseApp struct {
config *BaseAppConfig
txInfo *txAppInfo
txInfo *TxAppInfo
store *store.Store[string, any]
cron *cron.Cron
settings *Settings
@ -360,9 +360,17 @@ func (app *BaseApp) Logger() *slog.Logger {
return app.logger
}
// TxInfo returns the transaction associated with the current app instance (if any).
//
// Could be used if you want to execute indirectly a function after
// the related app transaction completes using `app.TxInfo().OnAfterFunc(callback)`.
func (app *BaseApp) TxInfo() *TxAppInfo {
return app.txInfo
}
// IsTransactional checks if the current app instance is part of a transaction.
func (app *BaseApp) IsTransactional() bool {
return app.txInfo != nil
return app.TxInfo() != nil
}
// IsBootstrapped checks if the application was initialized

View File

@ -128,7 +128,7 @@ func TestBaseAppBootstrap(t *testing.T) {
runNilChecks(nilChecksAfterReset)
}
func TestNewBaseAppIsTransactional(t *testing.T) {
func TestNewBaseAppTx(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
@ -141,17 +141,34 @@ func TestNewBaseAppIsTransactional(t *testing.T) {
t.Fatal(err)
}
if app.IsTransactional() {
t.Fatalf("Didn't expect the app to be transactional")
mustNotHaveTx := func(app core.App) {
if app.IsTransactional() {
t.Fatalf("Didn't expect the app to be transactional")
}
if app.TxInfo() != nil {
t.Fatalf("Didn't expect the app.txInfo to be loaded")
}
}
app.RunInTransaction(func(txApp core.App) error {
if !txApp.IsTransactional() {
mustHaveTx := func(app core.App) {
if !app.IsTransactional() {
t.Fatalf("Expected the app to be transactional")
}
if app.TxInfo() == nil {
t.Fatalf("Expected the app.txInfo to be loaded")
}
}
mustNotHaveTx(app)
app.RunInTransaction(func(txApp core.App) error {
mustHaveTx(txApp)
return nil
})
mustNotHaveTx(app)
}
func TestBaseAppNewMailClient(t *testing.T) {

View File

@ -151,7 +151,7 @@ func (app *BaseApp) delete(ctx context.Context, model Model, isForAuxDB bool) er
if app.txInfo != nil {
// execute later after the transaction has completed
app.txInfo.onAfterFunc(func(txErr error) error {
app.txInfo.OnComplete(func(txErr error) error {
if app.txInfo != nil && app.txInfo.parent != nil {
event.App = app.txInfo.parent
}
@ -342,7 +342,7 @@ func (app *BaseApp) create(ctx context.Context, model Model, withValidations boo
if app.txInfo != nil {
// execute later after the transaction has completed
app.txInfo.onAfterFunc(func(txErr error) error {
app.txInfo.OnComplete(func(txErr error) error {
if app.txInfo != nil && app.txInfo.parent != nil {
event.App = app.txInfo.parent
}
@ -426,7 +426,7 @@ func (app *BaseApp) update(ctx context.Context, model Model, withValidations boo
if app.txInfo != nil {
// execute later after the transaction has completed
app.txInfo.onAfterFunc(func(txErr error) error {
app.txInfo.OnComplete(func(txErr error) error {
if app.txInfo != nil && app.txInfo.parent != nil {
event.App = app.txInfo.parent
}

View File

@ -60,7 +60,7 @@ func (app *BaseApp) createTxApp(tx *dbx.Tx, isForAuxDB bool) *BaseApp {
clone.nonconcurrentDB = tx
}
clone.txInfo = &txAppInfo{
clone.txInfo = &TxAppInfo{
parent: app,
isForAuxDB: isForAuxDB,
}
@ -68,22 +68,29 @@ func (app *BaseApp) createTxApp(tx *dbx.Tx, isForAuxDB bool) *BaseApp {
return &clone
}
type txAppInfo struct {
// TxAppInfo represents an active transaction context associated to an existing app instance.
type TxAppInfo struct {
parent *BaseApp
afterFuncs []func(txErr error) error
mu sync.Mutex
isForAuxDB bool
}
func (a *txAppInfo) onAfterFunc(fn func(txErr error) error) {
// OnComplete registers the provided callback that will be invoked
// once the related transaction ends (either completes successfully or rollbacked with an error).
//
// The callback receives the transaction error (if any) as its argument.
// Any additional errors returned by the OnComplete callbacks will be
// joined together with txErr when returning the final transaction result.
func (a *TxAppInfo) OnComplete(fn func(txErr error) error) {
a.mu.Lock()
defer a.mu.Unlock()
a.afterFuncs = append(a.afterFuncs, fn)
}
// note: can be called only once because txAppInfo is cleared
func (a *txAppInfo) runAfterFuncs(txErr error) error {
// note: can be called only once because TxAppInfo is cleared
func (a *TxAppInfo) runAfterFuncs(txErr error) error {
a.mu.Lock()
defer a.mu.Unlock()