diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d2024c6..8747b804 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## v0.28.0 (WIP) + +- Write the default response body of `*Request` hooks that are wrapped in a transaction after the related transaction completes to allow propagating errors ([#6462](https://github.com/pocketbase/pocketbase/discussions/6462#discussioncomment-12207818)). + + ## v0.27.1 - Updated example `geoPoint` API preview body data. diff --git a/apis/batch.go b/apis/batch.go index 85e990ac..39f10d91 100644 --- a/apis/batch.go +++ b/apis/batch.go @@ -49,7 +49,7 @@ var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{ params["id"] = id // required for the path value ir.Method = "PATCH" ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"] - return recordUpdate(next) + return recordUpdate(false, next) } } @@ -57,16 +57,16 @@ var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{ // --- ir.Method = "POST" ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"] - return recordCreate(next) + return recordCreate(false, next) }, regexp.MustCompile(`^POST /api/collections/(?P[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc { - return recordCreate(next) + return recordCreate(false, next) }, regexp.MustCompile(`^PATCH /api/collections/(?P[^\/\?]+)/records/(?P[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc { - return recordUpdate(next) + return recordUpdate(false, next) }, regexp.MustCompile(`^DELETE /api/collections/(?P[^\/\?]+)/records/(?P[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc { - return recordDelete(next) + return recordDelete(false, next) }, } diff --git a/apis/collection.go b/apis/collection.go index ea6e3bc5..3adf1db7 100644 --- a/apis/collection.go +++ b/apis/collection.go @@ -45,7 +45,9 @@ func collectionsList(e *core.RequestEvent) error { event.Result = result return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error { - return e.JSON(http.StatusOK, e.Result) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, e.Result) + }) }) } @@ -60,7 +62,9 @@ func collectionView(e *core.RequestEvent) error { event.Collection = collection return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error { - return e.JSON(http.StatusOK, e.Collection) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, e.Collection) + }) }) } @@ -98,7 +102,9 @@ func collectionCreate(e *core.RequestEvent) error { return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil) } - return e.JSON(http.StatusOK, e.Collection) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, e.Collection) + }) }) } @@ -128,7 +134,9 @@ func collectionUpdate(e *core.RequestEvent) error { return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil) } - return e.JSON(http.StatusOK, e.Collection) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, e.Collection) + }) }) } @@ -159,7 +167,9 @@ func collectionDelete(e *core.RequestEvent) error { return e.BadRequestError(msg, err) } - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/collection_import.go b/apis/collection_import.go index b77a9ba6..1a2f2ada 100644 --- a/apis/collection_import.go +++ b/apis/collection_import.go @@ -29,7 +29,9 @@ func collectionsImport(e *core.RequestEvent) error { return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error { importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing) if importErr == nil { - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) } // validation failure diff --git a/apis/collection_import_test.go b/apis/collection_import_test.go index f20b0772..aa0bf655 100644 --- a/apis/collection_import_test.go +++ b/apis/collection_import_test.go @@ -316,6 +316,51 @@ func TestCollectionsImport(t *testing.T) { } }, }, + { + Name: "OnCollectionsImportRequest tx body write check", + Method: http.MethodPut, + URL: "/api/collections/import", + Body: strings.NewReader(`{ + "deleteMissing": true, + "collections":[ + {"name": "test123"}, + { + "id":"wsmn24bux7wo113", + "name":"demo1", + "fields":[ + { + "id":"_2hlxbmp", + "name":"title", + "type":"text", + "required":true + } + ], + "indexes": [] + } + ] + }`), + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnCollectionsImportRequest().BindFunc(func(e *core.CollectionsImportRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnCollectionsImportRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, } for _, scenario := range scenarios { diff --git a/apis/collection_test.go b/apis/collection_test.go index 35a1b778..8d74f750 100644 --- a/apis/collection_test.go +++ b/apis/collection_test.go @@ -1,7 +1,6 @@ package apis_test import ( - "errors" "net/http" "os" "path/filepath" @@ -130,6 +129,32 @@ func TestCollectionsList(t *testing.T) { "OnCollectionsListRequest": 1, }, }, + { + Name: "OnCollectionsListRequest tx body write check", + Method: http.MethodGet, + URL: "/api/collections", + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnCollectionsListRequest().BindFunc(func(e *core.CollectionsListRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnCollectionsListRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, } for _, scenario := range scenarios { @@ -205,6 +230,32 @@ func TestCollectionView(t *testing.T) { "OnCollectionViewRequest": 1, }, }, + { + Name: "OnCollectionViewRequest tx body write check", + Method: http.MethodGet, + URL: "/api/collections/wsmn24bux7wo113", + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnCollectionViewRequest().BindFunc(func(e *core.CollectionRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnCollectionViewRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, } for _, scenario := range scenarios { @@ -361,7 +412,7 @@ func TestCollectionDelete(t *testing.T) { }, }, { - Name: "OnCollectionAfterDeleteSuccessRequest error response", + Name: "OnCollectionDeleteRequest tx body write check", Method: http.MethodDelete, URL: "/api/collections/view2", Headers: map[string]string{ @@ -369,15 +420,22 @@ func TestCollectionDelete(t *testing.T) { }, BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnCollectionDeleteRequest().BindFunc(func(e *core.CollectionRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnCollectionDeleteRequest": 1, - }, + ExpectedEvents: map[string]int{"OnCollectionDeleteRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, } @@ -656,7 +714,7 @@ func TestCollectionCreate(t *testing.T) { }, }, { - Name: "OnCollectionCreateRequest error response", + Name: "OnCollectionCreateRequest tx body write check", Method: http.MethodPost, URL: "/api/collections", Body: strings.NewReader(`{"name":"new","type":"base"}`), @@ -665,15 +723,22 @@ func TestCollectionCreate(t *testing.T) { }, BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnCollectionCreateRequest().BindFunc(func(e *core.CollectionRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnCollectionCreateRequest": 1, - }, + ExpectedEvents: map[string]int{"OnCollectionCreateRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, // view @@ -978,7 +1043,7 @@ func TestCollectionUpdate(t *testing.T) { }, }, { - Name: "OnCollectionAfterUpdateSuccessRequest error response", + Name: "OnCollectionUpdateRequest tx body write check", Method: http.MethodPatch, URL: "/api/collections/demo1", Body: strings.NewReader(`{}`), @@ -987,15 +1052,22 @@ func TestCollectionUpdate(t *testing.T) { }, BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnCollectionUpdateRequest().BindFunc(func(e *core.CollectionRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnCollectionUpdateRequest": 1, - }, + ExpectedEvents: map[string]int{"OnCollectionUpdateRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, { Name: "authorized as superuser + invalid data (eg. existing name)", diff --git a/apis/file.go b/apis/file.go index 997e74ee..5c6ec211 100644 --- a/apis/file.go +++ b/apis/file.go @@ -75,8 +75,8 @@ func (api *fileApi) fileToken(e *core.RequestEvent) error { event.Record = e.Auth return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error { - return e.JSON(http.StatusOK, map[string]string{ - "token": e.Token, + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, map[string]string{"token": e.Token}) }) }) } @@ -192,7 +192,10 @@ func (api *fileApi) download(e *core.RequestEvent) error { e.Response.Header().Del("X-Frame-Options") return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error { - if err := fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName); err != nil { + err = execAfterSuccessTx(true, e.App, func() error { + return fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName) + }) + if err != nil { return e.NotFoundError("", err) } diff --git a/apis/middlewares.go b/apis/middlewares.go index f5bee73c..6f4424d7 100644 --- a/apis/middlewares.go +++ b/apis/middlewares.go @@ -365,11 +365,25 @@ func logRequest(event *core.RequestEvent, err error) { // parse the request error if err != nil { - if apiErr, ok := err.(*router.ApiError); ok { - status = apiErr.Status + apiErr, isPlainApiError := err.(*router.ApiError) + if isPlainApiError || errors.As(err, &apiErr) { + // the status header wasn't written yet + if status == 0 { + status = apiErr.Status + } + + var errMsg string + if isPlainApiError { + errMsg = apiErr.Message + } else { + // wrapped ApiError -> add the full serialized version + // of the original error since it could contain more information + errMsg = err.Error() + } + attrs = append( attrs, - slog.String("error", apiErr.Message), + slog.String("error", errMsg), slog.Any("details", apiErr.RawData()), ) } else { diff --git a/apis/realtime.go b/apis/realtime.go index 76d2b390..a70093b2 100644 --- a/apis/realtime.go +++ b/apis/realtime.go @@ -213,7 +213,9 @@ func realtimeSetSubscriptions(e *core.RequestEvent) error { slog.Any("subscriptions", e.Subscriptions), ) - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/record_auth_email_change_confirm.go b/apis/record_auth_email_change_confirm.go index ad09ec1d..799e082e 100644 --- a/apis/record_auth_email_change_confirm.go +++ b/apis/record_auth_email_change_confirm.go @@ -45,7 +45,9 @@ func recordConfirmEmailChange(e *core.RequestEvent) error { return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err)) } - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/record_auth_email_change_confirm_test.go b/apis/record_auth_email_change_confirm_test.go index 8eb95838..bf56be76 100644 --- a/apis/record_auth_email_change_confirm_test.go +++ b/apis/record_auth_email_change_confirm_test.go @@ -1,7 +1,6 @@ package apis_test import ( - "errors" "net/http" "strings" "testing" @@ -136,7 +135,7 @@ func TestRecordConfirmEmailChange(t *testing.T) { ExpectedEvents: map[string]int{"*": 0}, }, { - Name: "OnRecordAfterConfirmEmailChangeRequest error response", + Name: "OnRecordConfirmEmailChangeRequest tx body write check", Method: http.MethodPost, URL: "/api/collections/users/confirm-email-change", Body: strings.NewReader(`{ @@ -145,15 +144,22 @@ func TestRecordConfirmEmailChange(t *testing.T) { }`), BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordConfirmEmailChangeRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordConfirmEmailChangeRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, // rate limit checks diff --git a/apis/record_auth_email_change_request.go b/apis/record_auth_email_change_request.go index 686db572..079b0903 100644 --- a/apis/record_auth_email_change_request.go +++ b/apis/record_auth_email_change_request.go @@ -43,7 +43,9 @@ func recordRequestEmailChange(e *core.RequestEvent) error { return firstApiError(err, e.BadRequestError("Failed to request email change.", err)) } - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/record_auth_email_change_request_test.go b/apis/record_auth_email_change_request_test.go index c1282f9d..ef39e0cd 100644 --- a/apis/record_auth_email_change_request_test.go +++ b/apis/record_auth_email_change_request_test.go @@ -118,6 +118,33 @@ func TestRecordRequestEmailChange(t *testing.T) { } }, }, + { + Name: "OnRecordRequestEmailChangeRequest tx body write check", + Method: http.MethodPost, + URL: "/api/collections/users/request-email-change", + Body: strings.NewReader(`{"newEmail":"change@example.com"}`), + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnRecordRequestEmailChangeRequest().BindFunc(func(e *core.RecordRequestEmailChangeRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordRequestEmailChangeRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // rate limit checks // ----------------------------------------------------------- diff --git a/apis/record_auth_impersonate.go b/apis/record_auth_impersonate.go index 75ac2a96..c1982054 100644 --- a/apis/record_auth_impersonate.go +++ b/apis/record_auth_impersonate.go @@ -26,10 +26,10 @@ func recordAuthImpersonate(e *core.RequestEvent) error { form := &impersonateForm{} if err = e.BindBody(form); err != nil { - return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err)) + return e.BadRequestError("An error occurred while loading the submitted data.", err) } if err = form.validate(); err != nil { - return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err)) + return e.BadRequestError("An error occurred while validating the submitted data.", err) } token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second) diff --git a/apis/record_auth_otp_request.go b/apis/record_auth_otp_request.go index 29183d40..6cc5ff74 100644 --- a/apis/record_auth_otp_request.go +++ b/apis/record_auth_otp_request.go @@ -108,8 +108,8 @@ func recordRequestOTP(e *core.RequestEvent) error { }) } - return e.JSON(http.StatusOK, map[string]string{ - "otpId": otp.Id, + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, map[string]string{"otpId": otp.Id}) }) }) } diff --git a/apis/record_auth_otp_request_test.go b/apis/record_auth_otp_request_test.go index a0a607b7..5a262d72 100644 --- a/apis/record_auth_otp_request_test.go +++ b/apis/record_auth_otp_request_test.go @@ -247,6 +247,31 @@ func TestRecordRequestOTP(t *testing.T) { } }, }, + { + Name: "OnRecordRequestOTPRequest tx body write check", + Method: http.MethodPost, + URL: "/api/collections/users/request-otp", + Body: strings.NewReader(`{"email":"test@example.com"}`), + Delay: 100 * time.Millisecond, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnRecordRequestOTPRequest().BindFunc(func(e *core.RecordCreateOTPRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordRequestOTPRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // rate limit checks // ----------------------------------------------------------- diff --git a/apis/record_auth_password_reset_confirm.go b/apis/record_auth_password_reset_confirm.go index c73b0d76..2d3e4706 100644 --- a/apis/record_auth_password_reset_confirm.go +++ b/apis/record_auth_password_reset_confirm.go @@ -54,7 +54,9 @@ func recordConfirmPasswordReset(e *core.RequestEvent) error { e.App.Store().Remove(getPasswordResetResendKey(authRecord)) - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/record_auth_password_reset_confirm_test.go b/apis/record_auth_password_reset_confirm_test.go index d22ef80c..fff8f2b3 100644 --- a/apis/record_auth_password_reset_confirm_test.go +++ b/apis/record_auth_password_reset_confirm_test.go @@ -1,7 +1,6 @@ package apis_test import ( - "errors" "net/http" "strings" "testing" @@ -282,7 +281,7 @@ func TestRecordConfirmPasswordReset(t *testing.T) { }, }, { - Name: "OnRecordAfterConfirmPasswordResetRequest error response", + Name: "OnRecordConfirmPasswordResetRequest tx body write check", Method: http.MethodPost, URL: "/api/collections/users/confirm-password-reset", Body: strings.NewReader(`{ @@ -292,15 +291,22 @@ func TestRecordConfirmPasswordReset(t *testing.T) { }`), BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordConfirmPasswordResetRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordConfirmPasswordResetRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, // rate limit checks diff --git a/apis/record_auth_password_reset_request.go b/apis/record_auth_password_reset_request.go index 3c0592d4..16d7b843 100644 --- a/apis/record_auth_password_reset_request.go +++ b/apis/record_auth_password_reset_request.go @@ -65,7 +65,9 @@ func recordRequestPasswordReset(e *core.RequestEvent) error { }) }) - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/record_auth_password_reset_request_test.go b/apis/record_auth_password_reset_request_test.go index cb5ec956..0b4a3485 100644 --- a/apis/record_auth_password_reset_request_test.go +++ b/apis/record_auth_password_reset_request_test.go @@ -101,6 +101,30 @@ func TestRecordRequestPasswordReset(t *testing.T) { app.Store().Set(resendKey, struct{}{}) }, }, + { + Name: "OnRecordRequestPasswordResetRequest tx body write check", + Method: http.MethodPost, + URL: "/api/collections/users/request-password-reset", + Body: strings.NewReader(`{"email":"test@example.com"}`), + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnRecordRequestPasswordResetRequest().BindFunc(func(e *core.RecordRequestPasswordResetRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordRequestPasswordResetRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // rate limit checks // ----------------------------------------------------------- diff --git a/apis/record_auth_refresh_test.go b/apis/record_auth_refresh_test.go index 5f3722cc..d8e4b835 100644 --- a/apis/record_auth_refresh_test.go +++ b/apis/record_auth_refresh_test.go @@ -1,7 +1,6 @@ package apis_test import ( - "errors" "net/http" "testing" @@ -130,23 +129,30 @@ func TestRecordAuthRefresh(t *testing.T) { }, }, { - Name: "OnRecordAfterAuthRefreshRequest error response", + Name: "OnRecordAuthRefreshRequest tx body write check", Method: http.MethodPost, - URL: "/api/collections/users/auth-refresh?expand=rel,missing", + URL: "/api/collections/users/auth-refresh", Headers: map[string]string{ "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo", }, BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordAuthRefreshRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordAuthRefreshRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, // rate limit checks diff --git a/apis/record_auth_verification_confirm.go b/apis/record_auth_verification_confirm.go index 509ebc42..dcf86fb2 100644 --- a/apis/record_auth_verification_confirm.go +++ b/apis/record_auth_verification_confirm.go @@ -42,19 +42,19 @@ func recordConfirmVerification(e *core.RequestEvent) error { event.Record = record return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error { - if wasVerified { - return e.NoContent(http.StatusNoContent) - } + if !wasVerified { + e.Record.SetVerified(true) - e.Record.SetVerified(true) - - if err := e.App.Save(e.Record); err != nil { - return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err)) + if err := e.App.Save(e.Record); err != nil { + return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err)) + } } e.App.Store().Remove(getVerificationResendKey(e.Record)) - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/record_auth_verification_confirm_test.go b/apis/record_auth_verification_confirm_test.go index ea70fd29..ef48c7c6 100644 --- a/apis/record_auth_verification_confirm_test.go +++ b/apis/record_auth_verification_confirm_test.go @@ -1,7 +1,6 @@ package apis_test import ( - "errors" "net/http" "strings" "testing" @@ -144,7 +143,7 @@ func TestRecordConfirmVerification(t *testing.T) { }, }, { - Name: "OnRecordAfterConfirmVerificationRequest error response", + Name: "OnRecordConfirmVerificationRequest tx body write check", Method: http.MethodPost, URL: "/api/collections/users/confirm-verification", Body: strings.NewReader(`{ @@ -152,15 +151,22 @@ func TestRecordConfirmVerification(t *testing.T) { }`), BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordConfirmVerificationRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordConfirmVerificationRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, // rate limit checks diff --git a/apis/record_auth_verification_request.go b/apis/record_auth_verification_request.go index fc980e41..dbabe2df 100644 --- a/apis/record_auth_verification_request.go +++ b/apis/record_auth_verification_request.go @@ -68,7 +68,9 @@ func recordRequestVerification(e *core.RequestEvent) error { }) }) - return e.NoContent(http.StatusNoContent) + return execAfterSuccessTx(true, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) }) } diff --git a/apis/record_auth_verification_request_test.go b/apis/record_auth_verification_request_test.go index a15ab1bd..edbc7ca7 100644 --- a/apis/record_auth_verification_request_test.go +++ b/apis/record_auth_verification_request_test.go @@ -118,6 +118,30 @@ func TestRecordRequestVerification(t *testing.T) { } }, }, + { + Name: "OnRecordRequestVerificationRequest tx body write check", + Method: http.MethodPost, + URL: "/api/collections/users/request-verification", + Body: strings.NewReader(`{"email":"test@example.com"}`), + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnRecordRequestVerificationRequest().BindFunc(func(e *core.RecordRequestVerificationRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordRequestVerificationRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // rate limit checks // ----------------------------------------------------------- diff --git a/apis/record_auth_with_oauth2_test.go b/apis/record_auth_with_oauth2_test.go index 1893bb51..0e2a385c 100644 --- a/apis/record_auth_with_oauth2_test.go +++ b/apis/record_auth_with_oauth2_test.go @@ -1577,6 +1577,69 @@ func TestRecordAuthWithOAuth2(t *testing.T) { "OnRecordValidate": 4, }, }, + { + Name: "OnRecordAuthWithOAuth2Request tx body write check", + Method: http.MethodPost, + URL: "/api/collections/users/auth-with-oauth2", + Body: strings.NewReader(`{ + "provider": "test", + "code":"123", + "redirectURL": "https://example.com" + }`), + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + user, err := app.FindAuthRecordByEmail("users", "test@example.com") + if err != nil { + t.Fatal(err) + } + + // register the test provider + auth.Providers["test"] = func() auth.Provider { + return &oauth2MockProvider{ + AuthUser: &auth.AuthUser{Id: "test_id"}, + Token: &oauth2.Token{AccessToken: "abc"}, + } + } + + // add the test provider in the collection + user.Collection().MFA.Enabled = false + user.Collection().OAuth2.Enabled = true + user.Collection().OAuth2.Providers = []core.OAuth2ProviderConfig{{ + Name: "test", + ClientId: "123", + ClientSecret: "456", + }} + if err := app.Save(user.Collection()); err != nil { + t.Fatal(err) + } + + // stub linked provider + ea := core.NewExternalAuth(app) + ea.SetCollectionRef(user.Collection().Id) + ea.SetRecordRef(user.Id) + ea.SetProvider("test") + ea.SetProviderId("test_id") + if err := app.Save(ea); err != nil { + t.Fatal(err) + } + + app.OnRecordAuthWithOAuth2Request().BindFunc(func(e *core.RecordAuthWithOAuth2RequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordAuthWithOAuth2Request": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // rate limit checks // ----------------------------------------------------------- diff --git a/apis/record_auth_with_otp_test.go b/apis/record_auth_with_otp_test.go index c8757d74..bc514422 100644 --- a/apis/record_auth_with_otp_test.go +++ b/apis/record_auth_with_otp_test.go @@ -419,6 +419,53 @@ func TestRecordAuthWithOTP(t *testing.T) { } }, }, + { + Name: "OnRecordAuthWithOTPRequest tx body write check", + Method: http.MethodPost, + URL: "/api/collections/users/auth-with-otp", + Body: strings.NewReader(`{ + "otpId":"` + strings.Repeat("a", 15) + `", + "password":"123456" + }`), + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + user, err := app.FindAuthRecordByEmail("users", "test@example.com") + if err != nil { + t.Fatal(err) + } + + // disable MFA + user.Collection().MFA.Enabled = false + if err = app.Save(user.Collection()); err != nil { + t.Fatal(err) + } + + otp := core.NewOTP(app) + otp.Id = strings.Repeat("a", 15) + otp.SetCollectionRef(user.Collection().Id) + otp.SetRecordRef(user.Id) + otp.SetPassword("123456") + if err := app.Save(otp); err != nil { + t.Fatal(err) + } + + app.OnRecordAuthWithOTPRequest().BindFunc(func(e *core.RecordAuthWithOTPRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordAuthWithOTPRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // rate limit checks // ----------------------------------------------------------- diff --git a/apis/record_auth_with_password_test.go b/apis/record_auth_with_password_test.go index a2c21b02..8ed3aea6 100644 --- a/apis/record_auth_with_password_test.go +++ b/apis/record_auth_with_password_test.go @@ -1,7 +1,6 @@ package apis_test import ( - "errors" "net/http" "strings" "testing" @@ -82,7 +81,7 @@ func TestRecordAuthWithPassword(t *testing.T) { ExpectedEvents: map[string]int{"*": 0}, }, { - Name: "OnRecordAuthWithPasswordRequest error response", + Name: "OnRecordAuthWithPasswordRequest tx body write check", Method: http.MethodPost, URL: "/api/collections/clients/auth-with-password", Body: strings.NewReader(`{ @@ -91,15 +90,22 @@ func TestRecordAuthWithPassword(t *testing.T) { }`), BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordAuthWithPasswordRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordAuthWithPasswordRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, { Name: "valid identity field and invalid password", diff --git a/apis/record_crud.go b/apis/record_crud.go index 18aa4d4e..de7cd8d1 100644 --- a/apis/record_crud.go +++ b/apis/record_crud.go @@ -28,9 +28,9 @@ func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId) subGroup.GET("", recordsList) subGroup.GET("/{id}", recordView) - subGroup.POST("", recordCreate(nil)).Bind(dynamicCollectionBodyLimit("")) - subGroup.PATCH("/{id}", recordUpdate(nil)).Bind(dynamicCollectionBodyLimit("")) - subGroup.DELETE("/{id}", recordDelete(nil)) + subGroup.POST("", recordCreate(true, nil)).Bind(dynamicCollectionBodyLimit("")) + subGroup.PATCH("/{id}", recordUpdate(true, nil)).Bind(dynamicCollectionBodyLimit("")) + subGroup.DELETE("/{id}", recordDelete(true, nil)) } func recordsList(e *core.RequestEvent) error { @@ -121,7 +121,9 @@ func recordsList(e *core.RequestEvent) error { randomizedThrottle(150) } - return e.JSON(http.StatusOK, e.Result) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, e.Result) + }) }) } @@ -192,11 +194,13 @@ func recordView(e *core.RequestEvent) error { return firstApiError(err, e.InternalServerError("Failed to enrich record", err)) } - return e.JSON(http.StatusOK, e.Record) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, e.Record) + }) }) } -func recordCreate(optFinalizer func(data any) error) func(e *core.RequestEvent) error { +func recordCreate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error { return func(e *core.RequestEvent) error { collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection")) if err != nil || collection == nil { @@ -344,7 +348,9 @@ func recordCreate(optFinalizer func(data any) error) func(e *core.RequestEvent) return firstApiError(err, e.InternalServerError("Failed to enrich record", err)) } - err = e.JSON(http.StatusOK, e.Record) + err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error { + return e.JSON(http.StatusOK, e.Record) + }) if err != nil { return err } @@ -374,7 +380,7 @@ func recordCreate(optFinalizer func(data any) error) func(e *core.RequestEvent) } } -func recordUpdate(optFinalizer func(data any) error) func(e *core.RequestEvent) error { +func recordUpdate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error { return func(e *core.RequestEvent) error { collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection")) if err != nil || collection == nil { @@ -475,7 +481,9 @@ func recordUpdate(optFinalizer func(data any) error) func(e *core.RequestEvent) return firstApiError(err, e.InternalServerError("Failed to enrich record", err)) } - err = e.JSON(http.StatusOK, e.Record) + err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error { + return e.JSON(http.StatusOK, e.Record) + }) if err != nil { return err } @@ -505,7 +513,7 @@ func recordUpdate(optFinalizer func(data any) error) func(e *core.RequestEvent) } } -func recordDelete(optFinalizer func(data any) error) func(e *core.RequestEvent) error { +func recordDelete(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error { return func(e *core.RequestEvent) error { collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection")) if err != nil || collection == nil { @@ -565,7 +573,9 @@ func recordDelete(optFinalizer func(data any) error) func(e *core.RequestEvent) return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err)) } - err = e.NoContent(http.StatusNoContent) + err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error { + return e.NoContent(http.StatusNoContent) + }) if err != nil { return err } diff --git a/apis/record_crud_test.go b/apis/record_crud_test.go index 0d7b3cc1..2db19fca 100644 --- a/apis/record_crud_test.go +++ b/apis/record_crud_test.go @@ -2,7 +2,6 @@ package apis_test import ( "bytes" - "errors" "net/http" "net/url" "os" @@ -418,6 +417,32 @@ func TestRecordCrudList(t *testing.T) { "OnRecordsListRequest": 1, }, }, + { + Name: "OnRecordsListRequest tx body write check", + Method: http.MethodGet, + URL: "/api/collections/demo4/records", + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnRecordsListRequest().BindFunc(func(e *core.RecordsListRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordsListRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // auth collection // ----------------------------------------------------------- @@ -862,6 +887,32 @@ func TestRecordCrudView(t *testing.T) { "OnRecordEnrich": 7, }, }, + { + Name: "OnRecordViewRequest tx body write check", + Method: http.MethodGet, + URL: "/api/collections/demo1/records/al1h9ijdeojtsjy", + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnRecordViewRequest().BindFunc(func(e *core.RecordRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnRecordViewRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, // auth collection // ----------------------------------------------------------- @@ -1209,7 +1260,7 @@ func TestRecordCrudDelete(t *testing.T) { }, }, { - Name: "OnRecordAfterDeleteSuccessRequest error response", + Name: "OnRecordDeleteRequest tx body write check", Method: http.MethodDelete, URL: "/api/collections/clients/records/o1y0dd0spd786md", Headers: map[string]string{ @@ -1217,15 +1268,22 @@ func TestRecordCrudDelete(t *testing.T) { }, BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordDeleteRequest().BindFunc(func(e *core.RecordRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordDeleteRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordDeleteRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, { Name: "authenticated record that match the collection delete rule", @@ -1792,21 +1850,31 @@ func TestRecordCrudCreate(t *testing.T) { }, }, { - Name: "OnRecordAfterCreateSuccessRequest error response", + Name: "OnRecordCreateRequest tx body write check", Method: http.MethodPost, URL: "/api/collections/demo2/records", Body: strings.NewReader(`{"title":"new"}`), + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordCreateRequest().BindFunc(func(e *core.RecordRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordCreateRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordCreateRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, // ID checks @@ -2799,21 +2867,31 @@ func TestRecordCrudUpdate(t *testing.T) { }, }, { - Name: "OnRecordAfterUpdateSuccessRequest error response", + Name: "OnRecordUpdateRequest tx body write check", Method: http.MethodPatch, URL: "/api/collections/demo2/records/0yxhwia2amd8gec", Body: strings.NewReader(`{"title":"new"}`), + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { app.OnRecordUpdateRequest().BindFunc(func(e *core.RecordRequestEvent) error { - return errors.New("error") + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) }) }, ExpectedStatus: 400, - ExpectedContent: []string{`"data":{}`}, - ExpectedEvents: map[string]int{ - "*": 0, - "OnRecordUpdateRequest": 1, - }, + ExpectedEvents: map[string]int{"OnRecordUpdateRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, }, { Name: "try to change the id of an existing record", diff --git a/apis/record_helpers.go b/apis/record_helpers.go index de5ff5b5..3ba5e341 100644 --- a/apis/record_helpers.go +++ b/apis/record_helpers.go @@ -129,7 +129,9 @@ func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token str result.Meta = e.Meta } - return e.JSON(http.StatusOK, result) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, result) + }) }) } @@ -535,6 +537,27 @@ func firstApiError(errs ...error) *router.ApiError { return router.NewInternalServerError("", errors.Join(errs...)) } +// execAfterSuccessTx ensures that fn is executed only after a succesul transaction. +// +// If the current app instance is not a transactional or checkTx is false, +// then fn is directly executed. +// +// It could be usually used to allow propagating an error or writing +// custom response from within the wrapped transaction block. +func execAfterSuccessTx(checkTx bool, app core.App, fn func() error) error { + if txInfo := app.TxInfo(); txInfo != nil && checkTx { + txInfo.OnComplete(func(txErr error) error { + if txErr == nil { + return fn() + } + return nil + }) + return nil + } + + return fn() +} + // ------------------------------------------------------------------- const maxAuthOrigins = 5 diff --git a/apis/settings.go b/apis/settings.go index 6ef31acd..d6d609cd 100644 --- a/apis/settings.go +++ b/apis/settings.go @@ -30,7 +30,9 @@ func settingsList(e *core.RequestEvent) error { event.Settings = clone return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error { - return e.JSON(http.StatusOK, e.Settings) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, e.Settings) + }) }) } @@ -65,7 +67,9 @@ func settingsSet(e *core.RequestEvent) error { return e.InternalServerError("Failed to clone app settings.", err) } - return e.JSON(http.StatusOK, appSettings) + return execAfterSuccessTx(true, e.App, func() error { + return e.JSON(http.StatusOK, appSettings) + }) }) } diff --git a/apis/settings_test.go b/apis/settings_test.go index 02deedf9..3c20a750 100644 --- a/apis/settings_test.go +++ b/apis/settings_test.go @@ -11,6 +11,7 @@ import ( "strings" "testing" + "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tests" ) @@ -58,6 +59,32 @@ func TestSettingsList(t *testing.T) { "OnSettingsListRequest": 1, }, }, + { + Name: "OnSettingsListRequest tx body write check", + Method: http.MethodGet, + URL: "/api/settings", + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnSettingsListRequest().BindFunc(func(e *core.SettingsListRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnSettingsListRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, } for _, scenario := range scenarios { @@ -176,6 +203,33 @@ func TestSettingsSet(t *testing.T) { "OnSettingsReload": 1, }, }, + { + Name: "OnSettingsUpdateRequest tx body write check", + Method: http.MethodPatch, + URL: "/api/settings", + Body: strings.NewReader(validData), + Headers: map[string]string{ + "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY", + }, + BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) { + app.OnSettingsUpdateRequest().BindFunc(func(e *core.SettingsUpdateRequestEvent) error { + original := e.App + return e.App.RunInTransaction(func(txApp core.App) error { + e.App = txApp + defer func() { e.App = original }() + + if err := e.Next(); err != nil { + return err + } + + return e.BadRequestError("TX_ERROR", nil) + }) + }) + }, + ExpectedStatus: 400, + ExpectedEvents: map[string]int{"OnSettingsUpdateRequest": 1}, + ExpectedContent: []string{"TX_ERROR"}, + }, } for _, scenario := range scenarios { diff --git a/core/app.go b/core/app.go index 7a1b9c06..4c6247a0 100644 --- a/core/app.go +++ b/core/app.go @@ -45,6 +45,12 @@ type App interface { // IsTransactional checks if the current app instance is part of a transaction. IsTransactional() bool + // TxInfo returns the transaction associated with the current app instance (if any). + // + // Could be used if you want to execute indirectly a function after + // the related app transaction completes using `app.TxInfo().OnAfterFunc(callback)`. + TxInfo() *TxAppInfo + // Bootstrap initializes the application // (aka. create data dir, open db connections, load settings, etc.). // diff --git a/core/base.go b/core/base.go index 89a8609b..15a6fe88 100644 --- a/core/base.go +++ b/core/base.go @@ -69,7 +69,7 @@ var _ App = (*BaseApp)(nil) // BaseApp implements core.App and defines the base PocketBase app structure. type BaseApp struct { config *BaseAppConfig - txInfo *txAppInfo + txInfo *TxAppInfo store *store.Store[string, any] cron *cron.Cron settings *Settings @@ -360,9 +360,17 @@ func (app *BaseApp) Logger() *slog.Logger { return app.logger } +// TxInfo returns the transaction associated with the current app instance (if any). +// +// Could be used if you want to execute indirectly a function after +// the related app transaction completes using `app.TxInfo().OnAfterFunc(callback)`. +func (app *BaseApp) TxInfo() *TxAppInfo { + return app.txInfo +} + // IsTransactional checks if the current app instance is part of a transaction. func (app *BaseApp) IsTransactional() bool { - return app.txInfo != nil + return app.TxInfo() != nil } // IsBootstrapped checks if the application was initialized diff --git a/core/base_test.go b/core/base_test.go index 96d4b448..794dd2e9 100644 --- a/core/base_test.go +++ b/core/base_test.go @@ -128,7 +128,7 @@ func TestBaseAppBootstrap(t *testing.T) { runNilChecks(nilChecksAfterReset) } -func TestNewBaseAppIsTransactional(t *testing.T) { +func TestNewBaseAppTx(t *testing.T) { const testDataDir = "./pb_base_app_test_data_dir/" defer os.RemoveAll(testDataDir) @@ -141,17 +141,34 @@ func TestNewBaseAppIsTransactional(t *testing.T) { t.Fatal(err) } - if app.IsTransactional() { - t.Fatalf("Didn't expect the app to be transactional") + mustNotHaveTx := func(app core.App) { + if app.IsTransactional() { + t.Fatalf("Didn't expect the app to be transactional") + } + + if app.TxInfo() != nil { + t.Fatalf("Didn't expect the app.txInfo to be loaded") + } } - app.RunInTransaction(func(txApp core.App) error { - if !txApp.IsTransactional() { + mustHaveTx := func(app core.App) { + if !app.IsTransactional() { t.Fatalf("Expected the app to be transactional") } + if app.TxInfo() == nil { + t.Fatalf("Expected the app.txInfo to be loaded") + } + } + + mustNotHaveTx(app) + + app.RunInTransaction(func(txApp core.App) error { + mustHaveTx(txApp) return nil }) + + mustNotHaveTx(app) } func TestBaseAppNewMailClient(t *testing.T) { diff --git a/core/db.go b/core/db.go index e799d832..9c3312d2 100644 --- a/core/db.go +++ b/core/db.go @@ -151,7 +151,7 @@ func (app *BaseApp) delete(ctx context.Context, model Model, isForAuxDB bool) er if app.txInfo != nil { // execute later after the transaction has completed - app.txInfo.onAfterFunc(func(txErr error) error { + app.txInfo.OnComplete(func(txErr error) error { if app.txInfo != nil && app.txInfo.parent != nil { event.App = app.txInfo.parent } @@ -342,7 +342,7 @@ func (app *BaseApp) create(ctx context.Context, model Model, withValidations boo if app.txInfo != nil { // execute later after the transaction has completed - app.txInfo.onAfterFunc(func(txErr error) error { + app.txInfo.OnComplete(func(txErr error) error { if app.txInfo != nil && app.txInfo.parent != nil { event.App = app.txInfo.parent } @@ -426,7 +426,7 @@ func (app *BaseApp) update(ctx context.Context, model Model, withValidations boo if app.txInfo != nil { // execute later after the transaction has completed - app.txInfo.onAfterFunc(func(txErr error) error { + app.txInfo.OnComplete(func(txErr error) error { if app.txInfo != nil && app.txInfo.parent != nil { event.App = app.txInfo.parent } diff --git a/core/db_tx.go b/core/db_tx.go index 53ef4f2b..3f08228a 100644 --- a/core/db_tx.go +++ b/core/db_tx.go @@ -60,7 +60,7 @@ func (app *BaseApp) createTxApp(tx *dbx.Tx, isForAuxDB bool) *BaseApp { clone.nonconcurrentDB = tx } - clone.txInfo = &txAppInfo{ + clone.txInfo = &TxAppInfo{ parent: app, isForAuxDB: isForAuxDB, } @@ -68,22 +68,29 @@ func (app *BaseApp) createTxApp(tx *dbx.Tx, isForAuxDB bool) *BaseApp { return &clone } -type txAppInfo struct { +// TxAppInfo represents an active transaction context associated to an existing app instance. +type TxAppInfo struct { parent *BaseApp afterFuncs []func(txErr error) error mu sync.Mutex isForAuxDB bool } -func (a *txAppInfo) onAfterFunc(fn func(txErr error) error) { +// OnComplete registers the provided callback that will be invoked +// once the related transaction ends (either completes successfully or rollbacked with an error). +// +// The callback receives the transaction error (if any) as its argument. +// Any additional errors returned by the OnComplete callbacks will be +// joined together with txErr when returning the final transaction result. +func (a *TxAppInfo) OnComplete(fn func(txErr error) error) { a.mu.Lock() defer a.mu.Unlock() a.afterFuncs = append(a.afterFuncs, fn) } -// note: can be called only once because txAppInfo is cleared -func (a *txAppInfo) runAfterFuncs(txErr error) error { +// note: can be called only once because TxAppInfo is cleared +func (a *TxAppInfo) runAfterFuncs(txErr error) error { a.mu.Lock() defer a.mu.Unlock()