diff --git a/CHANGELOG.md b/CHANGELOG.md index 199d2e69..35a18737 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,9 @@ - Minor Admin UI fixes (typos, grammar fixes, removed unnecessary 404 error check, etc.). +- (@todo docs) For consistency and convenience it is now possible to call `Dao.RecordQuery(collectionModelOrIdentifier)` with just the collection id or name. + In case an invalid collection id/name string is passed the query will be resolved with cancelled context error. + ## v0.16.8 diff --git a/daos/record.go b/daos/record.go index e2909cd7..d6df5392 100644 --- a/daos/record.go +++ b/daos/record.go @@ -1,6 +1,7 @@ package daos import ( + "context" "database/sql" "errors" "fmt" @@ -18,79 +19,114 @@ import ( "github.com/spf13/cast" ) -// RecordQuery returns a new Record select query. -func (dao *Dao) RecordQuery(collection *models.Collection) *dbx.SelectQuery { - tableName := collection.Name +// RecordQuery returns a new Record select query from a collection model, id or name. +// +// In case a collection id or name is provided and that collection doesn't +// actually exists, the generated query will be created with a cancelled context +// and will fail once an executor (Row(), One(), All(), etc.) is called. +func (dao *Dao) RecordQuery(collectionModelOrIdentifier any) *dbx.SelectQuery { + var tableName string + var collection *models.Collection + var collectionErr error + switch c := collectionModelOrIdentifier.(type) { + case *models.Collection: + collection = c + tableName = collection.Name + case models.Collection: + collection = &c + tableName = collection.Name + case string: + collection, collectionErr = dao.FindCollectionByNameOrId(c) + if collection != nil { + tableName = collection.Name + } else { + // update with some fake table name for easier debugging + tableName = "@@__missing_" + c + } + default: + // update with some fake table name for easier debugging + tableName = "@@__invalidCollectionModelOrIdentifier" + collectionErr = errors.New("unsupported collection identifier, must be collection model, id or name") + } + selectCols := fmt.Sprintf("%s.*", dao.DB().QuoteSimpleColumnName(tableName)) - return dao.DB(). - Select(selectCols). - From(tableName). - WithBuildHook(func(query *dbx.Query) { - query.WithExecHook(execLockRetry(dao.ModelQueryTimeout, dao.MaxLockRetries)). - WithOneHook(func(q *dbx.Query, a any, op func(b any) error) error { - switch v := a.(type) { - case *models.Record: - if v == nil { - return op(a) - } + query := dao.DB().Select(selectCols).From(tableName) - row := dbx.NullStringMap{} - if err := op(&row); err != nil { - return err - } + // in case of an error attach a new context and cancel it immediately with the error + if collectionErr != nil { + // @todo consider changing to WithCancelCause when upgrading + // the min Go requirement to 1.20, so that we can pass the error + ctx, cancelFunc := context.WithCancel(context.Background()) + query.WithContext(ctx) + cancelFunc() + } - record := models.NewRecordFromNullStringMap(collection, row) - - *v = *record - - return nil - default: + return query.WithBuildHook(func(q *dbx.Query) { + q.WithExecHook(execLockRetry(dao.ModelQueryTimeout, dao.MaxLockRetries)). + WithOneHook(func(q *dbx.Query, a any, op func(b any) error) error { + switch v := a.(type) { + case *models.Record: + if v == nil { return op(a) } - }). - WithAllHook(func(q *dbx.Query, sliceA any, op func(sliceB any) error) error { - switch v := sliceA.(type) { - case *[]*models.Record: - if v == nil { - return op(sliceA) - } - rows := []dbx.NullStringMap{} - if err := op(&rows); err != nil { - return err - } + row := dbx.NullStringMap{} + if err := op(&row); err != nil { + return err + } - records := models.NewRecordsFromNullStringMaps(collection, rows) + record := models.NewRecordFromNullStringMap(collection, row) - *v = records + *v = *record - return nil - case *[]models.Record: - if v == nil { - return op(sliceA) - } - - rows := []dbx.NullStringMap{} - if err := op(&rows); err != nil { - return err - } - - records := models.NewRecordsFromNullStringMaps(collection, rows) - - nonPointers := make([]models.Record, len(records)) - for i, r := range records { - nonPointers[i] = *r - } - - *v = nonPointers - - return nil - default: + return nil + default: + return op(a) + } + }). + WithAllHook(func(q *dbx.Query, sliceA any, op func(sliceB any) error) error { + switch v := sliceA.(type) { + case *[]*models.Record: + if v == nil { return op(sliceA) } - }) - }) + + rows := []dbx.NullStringMap{} + if err := op(&rows); err != nil { + return err + } + + records := models.NewRecordsFromNullStringMaps(collection, rows) + + *v = records + + return nil + case *[]models.Record: + if v == nil { + return op(sliceA) + } + + rows := []dbx.NullStringMap{} + if err := op(&rows); err != nil { + return err + } + + records := models.NewRecordsFromNullStringMaps(collection, rows) + + nonPointers := make([]models.Record, len(records)) + for i, r := range records { + nonPointers[i] = *r + } + + *v = nonPointers + + return nil + default: + return op(sliceA) + } + }) + }) } // FindRecordById finds the Record model by its id. diff --git a/daos/record_test.go b/daos/record_test.go index a9b8f8e3..6578ec1a 100644 --- a/daos/record_test.go +++ b/daos/record_test.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "errors" - "fmt" "regexp" "strings" "testing" @@ -19,7 +18,7 @@ import ( "github.com/pocketbase/pocketbase/tools/types" ) -func TestRecordQuery(t *testing.T) { +func TestRecordQueryWithDifferentCollectionValues(t *testing.T) { app, _ := tests.NewTestApp() defer app.Cleanup() @@ -28,11 +27,33 @@ func TestRecordQuery(t *testing.T) { t.Fatal(err) } - expected := fmt.Sprintf("SELECT `%s`.* FROM `%s`", collection.Name, collection.Name) + scenarios := []struct { + name any + collection any + expectedTotal int + expectError bool + }{ + {"with nil value", nil, 0, true}, + {"with invalid or missing collection id/name", "missing", 0, true}, + {"with pointer model", collection, 3, false}, + {"with value model", *collection, 3, false}, + {"with name", "demo1", 3, false}, + {"with id", "wsmn24bux7wo113", 3, false}, + } - sql := app.Dao().RecordQuery(collection).Build().SQL() - if sql != expected { - t.Errorf("Expected sql %s, got %s", expected, sql) + for _, s := range scenarios { + var records []*models.Record + err := app.Dao().RecordQuery(s.collection).All(&records) + + hasErr := err != nil + if hasErr != s.expectError { + t.Errorf("[%s] Expected hasError %v, got %v", s.name, s.expectError, hasErr) + continue + } + + if total := len(records); total != s.expectedTotal { + t.Errorf("[%s] Expected %d records, got %d", s.name, s.expectedTotal, total) + } } } diff --git a/tools/list/list.go b/tools/list/list.go index 9b0e40b5..049fbdde 100644 --- a/tools/list/list.go +++ b/tools/list/list.go @@ -27,7 +27,6 @@ func SubtractSlice[T comparable](base []T, subtract []T) []T { // ExistInSlice checks whether a comparable element exists in a slice of the same type. func ExistInSlice[T comparable](item T, list []T) bool { - for _, v := range list { if v == item { return true