diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e7a3824..c12769ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ ```go app.OnBeforeBootstrap() app.OnAfterBootstrap() + app.OnRealtimeDisconnectRequest() + app.OnRealtimeBeforeMessageSend() + app.OnRealtimeAfterMessageSend() ``` - Refactored the `migrate` command to support **external JavaScript migration files** using an embedded JS interpreter ([goja](https://github.com/dop251/goja)). diff --git a/apis/realtime.go b/apis/realtime.go index 91191292..6b235427 100644 --- a/apis/realtime.go +++ b/apis/realtime.go @@ -42,7 +42,14 @@ func (api *realtimeApi) connect(c echo.Context) error { // register new subscription client client := subscriptions.NewDefaultClient() api.app.SubscriptionsBroker().Register(client) - defer api.app.SubscriptionsBroker().Unregister(client.Id()) + defer func() { + api.app.OnRealtimeDisconnectRequest().Trigger(&core.RealtimeDisconnectEvent{ + HttpContext: c, + Client: client, + }) + + api.app.SubscriptionsBroker().Unregister(client.Id()) + }() c.Response().Header().Set("Content-Type", "text/event-stream; charset=UTF-8") c.Response().Header().Set("Cache-Control", "no-store") @@ -51,12 +58,12 @@ func (api *realtimeApi) connect(c echo.Context) error { // https://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_buffering c.Response().Header().Set("X-Accel-Buffering", "no") - event := &core.RealtimeConnectEvent{ + connectEvent := &core.RealtimeConnectEvent{ HttpContext: c, Client: client, } - if err := api.app.OnRealtimeConnectRequest().Trigger(event); err != nil { + if err := api.app.OnRealtimeConnectRequest().Trigger(connectEvent); err != nil { return err } @@ -65,10 +72,31 @@ func (api *realtimeApi) connect(c echo.Context) error { } // signalize established connection (aka. fire "connect" message) - fmt.Fprint(c.Response(), "id:"+client.Id()+"\n") - fmt.Fprint(c.Response(), "event:PB_CONNECT\n") - fmt.Fprint(c.Response(), "data:{\"clientId\":\""+client.Id()+"\"}\n\n") - c.Response().Flush() + connectMsgEvent := &core.RealtimeMessageEvent{ + HttpContext: c, + Client: client, + Message: &subscriptions.Message{ + Name: "PB_CONNECT", + Data: `{"clientId":"` + client.Id() + `"}`, + }, + } + connectMsgErr := api.app.OnRealtimeBeforeMessageSend().Trigger(connectMsgEvent, func(e *core.RealtimeMessageEvent) error { + w := e.HttpContext.Response() + fmt.Fprint(w, "id:"+client.Id()+"\n") + fmt.Fprint(w, "event:"+e.Message.Name+"\n") + fmt.Fprint(w, "data:"+e.Message.Data+"\n\n") + w.Flush() + return nil + }) + if connectMsgErr != nil { + if api.app.IsDebug() { + log.Println("Realtime connection closed (failed to deliver PB_CONNECT):", client.Id(), connectMsgErr) + } + return nil + } + if err := api.app.OnRealtimeAfterMessageSend().Trigger(connectMsgEvent); err != nil && api.app.IsDebug() { + log.Println("OnRealtimeAfterMessageSend PB_CONNECT error:", err) + } // start an idle timer to keep track of inactive/forgotten connections idleDuration := 5 * time.Minute @@ -88,11 +116,29 @@ func (api *realtimeApi) connect(c echo.Context) error { return nil } - w := c.Response() - fmt.Fprint(w, "id:"+client.Id()+"\n") - fmt.Fprint(w, "event:"+msg.Name+"\n") - fmt.Fprint(w, "data:"+msg.Data+"\n\n") - w.Flush() + msgEvent := &core.RealtimeMessageEvent{ + HttpContext: c, + Client: client, + Message: &msg, + } + msgErr := api.app.OnRealtimeBeforeMessageSend().Trigger(msgEvent, func(e *core.RealtimeMessageEvent) error { + w := e.HttpContext.Response() + fmt.Fprint(w, "id:"+e.Client.Id()+"\n") + fmt.Fprint(w, "event:"+e.Message.Name+"\n") + fmt.Fprint(w, "data:"+e.Message.Data+"\n\n") + w.Flush() + return nil + }) + if msgErr != nil { + if api.app.IsDebug() { + log.Println("Realtime connection closed (failed to deliver message):", client.Id(), msgErr) + } + return nil + } + + if err := api.app.OnRealtimeAfterMessageSend().Trigger(msgEvent); err != nil && api.app.IsDebug() { + log.Println("OnRealtimeAfterMessageSend error:", err) + } idleTimer.Stop() idleTimer.Reset(idleDuration) diff --git a/apis/realtime_test.go b/apis/realtime_test.go index 6810ad22..037fc293 100644 --- a/apis/realtime_test.go +++ b/apis/realtime_test.go @@ -1,6 +1,7 @@ package apis_test import ( + "errors" "net/http" "strings" "testing" @@ -10,6 +11,7 @@ import ( "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tests" + "github.com/pocketbase/pocketbase/tools/hook" "github.com/pocketbase/pocketbase/tools/subscriptions" ) @@ -25,7 +27,56 @@ func TestRealtimeConnect(t *testing.T) { `data:{"clientId":`, }, ExpectedEvents: map[string]int{ - "OnRealtimeConnectRequest": 1, + "OnRealtimeConnectRequest": 1, + "OnRealtimeBeforeMessageSend": 1, + "OnRealtimeAfterMessageSend": 1, + "OnRealtimeDisconnectRequest": 1, + }, + AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { + if len(app.SubscriptionsBroker().Clients()) != 0 { + t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) + } + }, + }, + { + Name: "PB_CONNECT interrupt", + Method: http.MethodGet, + Url: "/api/realtime", + ExpectedStatus: 200, + ExpectedEvents: map[string]int{ + "OnRealtimeConnectRequest": 1, + "OnRealtimeBeforeMessageSend": 1, + "OnRealtimeDisconnectRequest": 1, + }, + BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { + app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error { + if e.Message.Name == "PB_CONNECT" { + return errors.New("PB_CONNECT error") + } + return nil + }) + }, + AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { + if len(app.SubscriptionsBroker().Clients()) != 0 { + t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) + } + }, + }, + { + Name: "Skipping/ignoring messages", + Method: http.MethodGet, + Url: "/api/realtime", + ExpectedStatus: 200, + ExpectedEvents: map[string]int{ + "OnRealtimeConnectRequest": 1, + "OnRealtimeBeforeMessageSend": 1, + "OnRealtimeAfterMessageSend": 1, + "OnRealtimeDisconnectRequest": 1, + }, + BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { + app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error { + return hook.StopPropagation + }) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { if len(app.SubscriptionsBroker().Clients()) != 0 { diff --git a/core/app.go b/core/app.go index 77d7fdf3..c962bf84 100644 --- a/core/app.go +++ b/core/app.go @@ -176,6 +176,21 @@ type App interface { // the SSE client connection. OnRealtimeConnectRequest() *hook.Hook[*RealtimeConnectEvent] + // OnRealtimeDisconnectRequest hook is triggered on disconnected/interrupted + // SSE client connection. + OnRealtimeDisconnectRequest() *hook.Hook[*RealtimeDisconnectEvent] + + // OnRealtimeBeforeMessage hook is triggered right before sending + // an SSE message to a client. + // + // Returning [hook.StopPropagation] will prevent sending the message. + // Returning any other non-nil error will close the realtime connection. + OnRealtimeBeforeMessageSend() *hook.Hook[*RealtimeMessageEvent] + + // OnRealtimeBeforeMessage hook is triggered right after sending + // an SSE message to a client. + OnRealtimeAfterMessageSend() *hook.Hook[*RealtimeMessageEvent] + // OnRealtimeBeforeSubscribeRequest hook is triggered before changing // the client subscriptions, allowing you to further validate and // modify the submitted change. diff --git a/core/base.go b/core/base.go index af49e36d..7d5801fa 100644 --- a/core/base.go +++ b/core/base.go @@ -64,6 +64,9 @@ type BaseApp struct { // realtime api event hooks onRealtimeConnectRequest *hook.Hook[*RealtimeConnectEvent] + onRealtimeDisconnectRequest *hook.Hook[*RealtimeDisconnectEvent] + onRealtimeBeforeMessageSend *hook.Hook[*RealtimeMessageEvent] + onRealtimeAfterMessageSend *hook.Hook[*RealtimeMessageEvent] onRealtimeBeforeSubscribeRequest *hook.Hook[*RealtimeSubscribeEvent] onRealtimeAfterSubscribeRequest *hook.Hook[*RealtimeSubscribeEvent] @@ -153,6 +156,9 @@ func NewBaseApp(dataDir string, encryptionEnv string, isDebug bool) *BaseApp { // realtime API event hooks onRealtimeConnectRequest: &hook.Hook[*RealtimeConnectEvent]{}, + onRealtimeDisconnectRequest: &hook.Hook[*RealtimeDisconnectEvent]{}, + onRealtimeBeforeMessageSend: &hook.Hook[*RealtimeMessageEvent]{}, + onRealtimeAfterMessageSend: &hook.Hook[*RealtimeMessageEvent]{}, onRealtimeBeforeSubscribeRequest: &hook.Hook[*RealtimeSubscribeEvent]{}, onRealtimeAfterSubscribeRequest: &hook.Hook[*RealtimeSubscribeEvent]{}, @@ -471,6 +477,18 @@ func (app *BaseApp) OnRealtimeConnectRequest() *hook.Hook[*RealtimeConnectEvent] return app.onRealtimeConnectRequest } +func (app *BaseApp) OnRealtimeDisconnectRequest() *hook.Hook[*RealtimeDisconnectEvent] { + return app.onRealtimeDisconnectRequest +} + +func (app *BaseApp) OnRealtimeBeforeMessageSend() *hook.Hook[*RealtimeMessageEvent] { + return app.onRealtimeBeforeMessageSend +} + +func (app *BaseApp) OnRealtimeAfterMessageSend() *hook.Hook[*RealtimeMessageEvent] { + return app.onRealtimeAfterMessageSend +} + func (app *BaseApp) OnRealtimeBeforeSubscribeRequest() *hook.Hook[*RealtimeSubscribeEvent] { return app.onRealtimeBeforeSubscribeRequest } diff --git a/core/events.go b/core/events.go index c88d80f8..0a811f81 100644 --- a/core/events.go +++ b/core/events.go @@ -61,6 +61,17 @@ type RealtimeConnectEvent struct { Client subscriptions.Client } +type RealtimeDisconnectEvent struct { + HttpContext echo.Context + Client subscriptions.Client +} + +type RealtimeMessageEvent struct { + HttpContext echo.Context + Client subscriptions.Client + Message *subscriptions.Message +} + type RealtimeSubscribeEvent struct { HttpContext echo.Context Client subscriptions.Client diff --git a/tests/app.go b/tests/app.go index 8ff287e9..253080a7 100644 --- a/tests/app.go +++ b/tests/app.go @@ -222,6 +222,21 @@ func NewTestApp(optTestDataDir ...string) (*TestApp, error) { return nil }) + t.OnRealtimeDisconnectRequest().Add(func(e *core.RealtimeDisconnectEvent) error { + t.EventCalls["OnRealtimeDisconnectRequest"]++ + return nil + }) + + t.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error { + t.EventCalls["OnRealtimeBeforeMessageSend"]++ + return nil + }) + + t.OnRealtimeAfterMessageSend().Add(func(e *core.RealtimeMessageEvent) error { + t.EventCalls["OnRealtimeAfterMessageSend"]++ + return nil + }) + t.OnRealtimeBeforeSubscribeRequest().Add(func(e *core.RealtimeSubscribeEvent) error { t.EventCalls["OnRealtimeBeforeSubscribeRequest"]++ return nil