[#2325] trigger the related record realtime events on custom record model change
This commit is contained in:
		
							parent
							
								
									fdf4f6d3bd
								
							
						
					
					
						commit
						818857dea2
					
				| 
						 | 
				
			
			@ -7,6 +7,7 @@ import (
 | 
			
		|||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/labstack/echo/v5"
 | 
			
		||||
| 
						 | 
				
			
			@ -215,7 +216,9 @@ func (api *realtimeApi) setSubscriptions(c echo.Context) error {
 | 
			
		|||
func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel models.Model) error {
 | 
			
		||||
	for _, client := range api.app.SubscriptionsBroker().Clients() {
 | 
			
		||||
		clientModel, _ := client.Get(contextKey).(models.Model)
 | 
			
		||||
		if clientModel != nil && clientModel.GetId() == newModel.GetId() {
 | 
			
		||||
		if clientModel != nil &&
 | 
			
		||||
			clientModel.TableName() == newModel.TableName() &&
 | 
			
		||||
			clientModel.GetId() == newModel.GetId() {
 | 
			
		||||
			client.Set(contextKey, newModel)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -227,7 +230,9 @@ func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel model
 | 
			
		|||
func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model models.Model) error {
 | 
			
		||||
	for _, client := range api.app.SubscriptionsBroker().Clients() {
 | 
			
		||||
		clientModel, _ := client.Get(contextKey).(models.Model)
 | 
			
		||||
		if clientModel != nil && clientModel.GetId() == model.GetId() {
 | 
			
		||||
		if clientModel != nil &&
 | 
			
		||||
			clientModel.TableName() == model.TableName() &&
 | 
			
		||||
			clientModel.GetId() == model.GetId() {
 | 
			
		||||
			api.app.SubscriptionsBroker().Unregister(client.Id())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -238,7 +243,7 @@ func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model mo
 | 
			
		|||
func (api *realtimeApi) bindEvents() {
 | 
			
		||||
	// update the clients that has admin or auth record association
 | 
			
		||||
	api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
 | 
			
		||||
		if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() {
 | 
			
		||||
		if record := api.resolveRecord(e.Model); record != nil && record.Collection().IsAuth() {
 | 
			
		||||
			return api.updateClientsAuthModel(ContextAuthRecordKey, record)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -251,8 +256,8 @@ func (api *realtimeApi) bindEvents() {
 | 
			
		|||
 | 
			
		||||
	// remove the client(s) associated to the deleted admin or auth record
 | 
			
		||||
	api.app.OnModelAfterDelete().PreAdd(func(e *core.ModelEvent) error {
 | 
			
		||||
		if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() {
 | 
			
		||||
			return api.unregisterClientsByAuthModel(ContextAuthRecordKey, record)
 | 
			
		||||
		if collection := api.resolveRecordCollection(e.Model); collection != nil && collection.IsAuth() {
 | 
			
		||||
			return api.unregisterClientsByAuthModel(ContextAuthRecordKey, e.Model)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if admin, ok := e.Model.(*models.Admin); ok && admin != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -263,7 +268,7 @@ func (api *realtimeApi) bindEvents() {
 | 
			
		|||
	})
 | 
			
		||||
 | 
			
		||||
	api.app.OnModelAfterCreate().PreAdd(func(e *core.ModelEvent) error {
 | 
			
		||||
		if record, ok := e.Model.(*models.Record); ok {
 | 
			
		||||
		if record := api.resolveRecord(e.Model); record != nil {
 | 
			
		||||
			if err := api.broadcastRecord("create", record); err != nil && api.app.IsDebug() {
 | 
			
		||||
				log.Println(err)
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			@ -272,7 +277,7 @@ func (api *realtimeApi) bindEvents() {
 | 
			
		|||
	})
 | 
			
		||||
 | 
			
		||||
	api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
 | 
			
		||||
		if record, ok := e.Model.(*models.Record); ok {
 | 
			
		||||
		if record := api.resolveRecord(e.Model); record != nil {
 | 
			
		||||
			if err := api.broadcastRecord("update", record); err != nil && api.app.IsDebug() {
 | 
			
		||||
				log.Println(err)
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			@ -281,7 +286,7 @@ func (api *realtimeApi) bindEvents() {
 | 
			
		|||
	})
 | 
			
		||||
 | 
			
		||||
	api.app.OnModelBeforeDelete().Add(func(e *core.ModelEvent) error {
 | 
			
		||||
		if record, ok := e.Model.(*models.Record); ok {
 | 
			
		||||
		if record := api.resolveRecord(e.Model); record != nil {
 | 
			
		||||
			if err := api.broadcastRecord("delete", record); err != nil && api.app.IsDebug() {
 | 
			
		||||
				log.Println(err)
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			@ -290,6 +295,33 @@ func (api *realtimeApi) bindEvents() {
 | 
			
		|||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// resolveRecord converts *if possible* the provided model interface to a Record.
 | 
			
		||||
// This is usually helpful if the provided model is a custom Record model struct.
 | 
			
		||||
func (api *realtimeApi) resolveRecord(model models.Model) (record *models.Record) {
 | 
			
		||||
	record, _ = model.(*models.Record)
 | 
			
		||||
 | 
			
		||||
	// check if it is custom Record model struct (ignore "private" tables)
 | 
			
		||||
	if record == nil && !strings.HasPrefix(model.TableName(), "_") {
 | 
			
		||||
		record, _ = api.app.Dao().FindRecordById(model.TableName(), model.GetId())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return record
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// resolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
 | 
			
		||||
// This is usually helpful if the provided model is a custom Record model struct.
 | 
			
		||||
func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection *models.Collection) {
 | 
			
		||||
	if record, ok := model.(*models.Record); ok {
 | 
			
		||||
		collection = record.Collection()
 | 
			
		||||
	} else if !strings.HasPrefix(model.TableName(), "_") {
 | 
			
		||||
		// check if it is custom Record model struct (ignore "private" tables)
 | 
			
		||||
		collection, _ = api.app.Dao().FindCollectionByNameOrId(model.TableName())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return collection
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// canAccessRecord checks if the subscription client has access to the specified record model.
 | 
			
		||||
func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool {
 | 
			
		||||
	admin, _ := client.Get(ContextAdminKey).(*models.Admin)
 | 
			
		||||
	if admin != nil {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,8 +7,10 @@ import (
 | 
			
		|||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/labstack/echo/v5"
 | 
			
		||||
	"github.com/pocketbase/dbx"
 | 
			
		||||
	"github.com/pocketbase/pocketbase/apis"
 | 
			
		||||
	"github.com/pocketbase/pocketbase/core"
 | 
			
		||||
	"github.com/pocketbase/pocketbase/daos"
 | 
			
		||||
	"github.com/pocketbase/pocketbase/models"
 | 
			
		||||
	"github.com/pocketbase/pocketbase/tests"
 | 
			
		||||
	"github.com/pocketbase/pocketbase/tools/hook"
 | 
			
		||||
| 
						 | 
				
			
			@ -353,3 +355,96 @@ func TestRealtimeAdminUpdateEvent(t *testing.T) {
 | 
			
		|||
		t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Custom auth record model struct
 | 
			
		||||
// -------------------------------------------------------------------
 | 
			
		||||
var _ models.Model = (*CustomUser)(nil)
 | 
			
		||||
 | 
			
		||||
type CustomUser struct {
 | 
			
		||||
	models.BaseModel
 | 
			
		||||
 | 
			
		||||
	Email string `db:"email" json:"email"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *CustomUser) TableName() string {
 | 
			
		||||
	return "users" // the name of your collection
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func findCustomUserByEmail(dao *daos.Dao, email string) (*CustomUser, error) {
 | 
			
		||||
	model := &CustomUser{}
 | 
			
		||||
 | 
			
		||||
	err := dao.ModelQuery(model).
 | 
			
		||||
		AndWhere(dbx.HashExp{"email": email}).
 | 
			
		||||
		Limit(1).
 | 
			
		||||
		One(model)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return model, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
 | 
			
		||||
	testApp, _ := tests.NewTestApp()
 | 
			
		||||
	defer testApp.Cleanup()
 | 
			
		||||
 | 
			
		||||
	apis.InitApi(testApp)
 | 
			
		||||
 | 
			
		||||
	authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := subscriptions.NewDefaultClient()
 | 
			
		||||
	client.Set(apis.ContextAuthRecordKey, authRecord)
 | 
			
		||||
	testApp.SubscriptionsBroker().Register(client)
 | 
			
		||||
 | 
			
		||||
	// refetch the authRecord as CustomUser
 | 
			
		||||
	customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// delete the custom user (should unset the client auth record)
 | 
			
		||||
	if err := testApp.Dao().Delete(customUser); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(testApp.SubscriptionsBroker().Clients()) != 0 {
 | 
			
		||||
		t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
 | 
			
		||||
	testApp, _ := tests.NewTestApp()
 | 
			
		||||
	defer testApp.Cleanup()
 | 
			
		||||
 | 
			
		||||
	apis.InitApi(testApp)
 | 
			
		||||
 | 
			
		||||
	authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := subscriptions.NewDefaultClient()
 | 
			
		||||
	client.Set(apis.ContextAuthRecordKey, authRecord)
 | 
			
		||||
	testApp.SubscriptionsBroker().Register(client)
 | 
			
		||||
 | 
			
		||||
	// refetch the authRecord as CustomUser
 | 
			
		||||
	customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// change its email
 | 
			
		||||
	customUser.Email = "new@example.com"
 | 
			
		||||
	if err := testApp.Dao().Save(customUser); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
 | 
			
		||||
	if clientAuthRecord.Email() != customUser.Email {
 | 
			
		||||
		t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue