From 75f58a28ac4c1b6f2941e7f70a2bb08a03b5bd8d Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Fri, 18 Aug 2023 06:31:14 +0300 Subject: [PATCH] added placeholder params support for Dao.FindRecordsByFilter and Dao.FindFirstRecordByFilter --- CHANGELOG.md | 37 ++++++++++--- daos/record.go | 35 ++++++++---- daos/record_test.go | 99 ++++++++++++++++++++++++---------- tools/search/filter.go | 67 ++++++++++++++++------- tools/search/filter_test.go | 103 ++++++++++++++++++++++++++++++------ 5 files changed, 261 insertions(+), 80 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 01193565..4f25d3e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ ## v0.18.0 - WIP - Added new `SmtpConfig.LocalName` option to specify a custom domain name (or IP address) for the initial EHLO/HELO exchange ([#3097](https://github.com/pocketbase/pocketbase/discussions/3097)). - _This is usually required for verification purposes only by some SMTP providers, such as Gmail SMTP-relay._ + _This is usually required for verification purposes only by some SMTP providers, such as on-premise [Gmail SMTP-relay](https://support.google.com/a/answer/2956491)._ - Added cron expression macros ([#3132](https://github.com/pocketbase/pocketbase/issues/3132)): ``` @@ -14,9 +14,30 @@ "@hourly": "0 * * * *" ``` -- Added JSVM `$mails.*` binds for sending. +- To minimize the footguns with `Dao.FindFirstRecordByFilter()` and `Dao.FindRecordsByFilter()`, the functions now supports an optional placeholder params argument that is safe to be populated with untrusted user input. + The placeholders are in the same format as when binding regular SQL parameters. + ```go + // unsanitized and untrusted filter variables + status := "..." + author := "..." -- Fill the `LastVerificationSentAt` and `LastResetSentAt` fields only after a successfull email send. + app.Dao().FindFirstRecordByFilter("articles", "status={:status} && author={:author}", dbx.Params{ + "status": status, + "author": author, + }) + + app.Dao().FindRecordsByFilter("articles", "status={:status} && author={:author}", "-created", 10, 0, dbx.Params{ + "status": status, + "author": author, + }) + ``` + +- ⚠️ Added offset argument `Dao.FindRecordsByFilter(collection, filter, sort, limit, offset, [params...])`. + _If you don't need an offset, you can set it to `0`._ + +- Added JSVM `$mails.*` binds for the corresponding Go [mails package](https://pkg.go.dev/github.com/pocketbase/pocketbase/mails) functions. + +- Fill the `LastVerificationSentAt` and `LastResetSentAt` fields only after a successfull email send ([#3121](https://github.com/pocketbase/pocketbase/issues/3121)). - Reflected the latest JS SDK changes in the Admin UI. @@ -749,7 +770,7 @@ _Note2: The old index (`"field.0":null`) and filename (`"field.filename.png":null`) based suffixed syntax for deleting files is still supported._ -- ! Added support for multi-match/match-all request data and collection multi-valued fields (`select`, `relation`) conditions. +- ⚠️ Added support for multi-match/match-all request data and collection multi-valued fields (`select`, `relation`) conditions. If you want a "at least one of" type of condition, you can prefix the operator with `?`. ```js // for each someRelA.someRelB record require the "status" field to be "active" @@ -829,7 +850,7 @@ ## v0.10.3 -- ! Renamed the metadata key `original_filename` to `original-filename` due to an S3 file upload error caused by the underscore character ([#1343](https://github.com/pocketbase/pocketbase/pull/1343); thanks @yuxiang-gao). +- ⚠️ Renamed the metadata key `original_filename` to `original-filename` due to an S3 file upload error caused by the underscore character ([#1343](https://github.com/pocketbase/pocketbase/pull/1343); thanks @yuxiang-gao). - Fixed request verification docs api url ([#1332](https://github.com/pocketbase/pocketbase/pull/1332); thanks @JoyMajumdar2001) @@ -865,7 +886,7 @@ - Refactored the `core.app.Bootstrap()` to be called before starting the cobra commands ([#1267](https://github.com/pocketbase/pocketbase/discussions/1267)). -- ! Changed `pocketbase.NewWithConfig(config Config)` to `pocketbase.NewWithConfig(config *Config)` and added 4 new config settings: +- ⚠️ Changed `pocketbase.NewWithConfig(config Config)` to `pocketbase.NewWithConfig(config *Config)` and added 4 new config settings: ```go DataMaxOpenConns int // default to core.DefaultDataMaxOpenConns DataMaxIdleConns int // default to core.DefaultDataMaxIdleConns @@ -875,9 +896,9 @@ - Added new helper method `core.App.IsBootstrapped()` to check the current app bootstrap state. -- ! Changed `core.NewBaseApp(dir, encryptionEnv, isDebug)` to `NewBaseApp(config *BaseAppConfig)`. +- ⚠️ Changed `core.NewBaseApp(dir, encryptionEnv, isDebug)` to `NewBaseApp(config *BaseAppConfig)`. -- ! Removed `rest.UploadedFile` struct (see below `filesystem.File`). +- ⚠️ Removed `rest.UploadedFile` struct (see below `filesystem.File`). - Added generic file resource struct that allows loading and uploading file content from different sources (at the moment multipart/form-data requests and from the local filesystem). diff --git a/daos/record.go b/daos/record.go index 4e8d54a1..4f4a67b8 100644 --- a/daos/record.go +++ b/daos/record.go @@ -252,23 +252,31 @@ func (dao *Dao) FindFirstRecordByData( // FindRecordsByFilter returns limit number of records matching the // provided string filter. // +// NB! Use the last "params" argument to bind untrusted user variables! +// // The sort argument is optional and can be empty string OR the same format // used in the web APIs, eg. "-created,title". // // If the limit argument is <= 0, no limit is applied to the query and // all matching records are returned. // -// NB! Don't put untrusted user input in the filter string as it -// practically would allow the users to inject their own custom filter. -// // Example: // -// dao.FindRecordsByFilter("posts", "title ~ 'lorem ipsum' && visible = true", "-created", 10) +// dao.FindRecordsByFilter( +// "posts", +// "title ~ {:title} && visible = {:visible}", +// "-created", +// 10, +// 0, +// dbx.Params{"title": "lorem ipsum", "visible": true} +// ) func (dao *Dao) FindRecordsByFilter( collectionNameOrId string, filter string, sort string, limit int, + offset int, + params ...dbx.Params, ) ([]*models.Record, error) { collection, err := dao.FindCollectionByNameOrId(collectionNameOrId) if err != nil { @@ -286,7 +294,7 @@ func (dao *Dao) FindRecordsByFilter( true, // allow searching hidden/protected fields like "email" ) - expr, err := search.FilterData(filter).BuildExpr(resolver) + expr, err := search.FilterData(filter).BuildExpr(resolver, params...) if err != nil || expr == nil { return nil, errors.New("invalid or empty filter expression") } @@ -307,6 +315,10 @@ func (dao *Dao) FindRecordsByFilter( resolver.UpdateQuery(q) // attaches any adhoc joins and aliases // --- + if offset > 0 { + q.Offset(int64(offset)) + } + if limit > 0 { q.Limit(int64(limit)) } @@ -322,14 +334,17 @@ func (dao *Dao) FindRecordsByFilter( // FindFirstRecordByFilter returns the first available record matching the provided filter. // -// NB! Don't put untrusted user input in the filter string as it -// practically would allow the users to inject their own custom filter. +// NB! Use the last params argument to bind untrusted user variables! // // Example: // -// dao.FindFirstRecordByFilter("posts", "slug='test'") -func (dao *Dao) FindFirstRecordByFilter(collectionNameOrId string, filter string) (*models.Record, error) { - result, err := dao.FindRecordsByFilter(collectionNameOrId, filter, "", 1) +// dao.FindFirstRecordByFilter("posts", "slug={:slug} && status='public'", dbx.Params{"slug": "test"}) +func (dao *Dao) FindFirstRecordByFilter( + collectionNameOrId string, + filter string, + params ...dbx.Params, +) (*models.Record, error) { + result, err := dao.FindRecordsByFilter(collectionNameOrId, filter, "", 1, 0, params...) if err != nil { return nil, err } diff --git a/daos/record_test.go b/daos/record_test.go index ebe2fbeb..3f5e8cc4 100644 --- a/daos/record_test.go +++ b/daos/record_test.go @@ -436,6 +436,8 @@ func TestFindRecordsByFilter(t *testing.T) { filter string sort string limit int + offset int + params []dbx.Params expectError bool expectRecordIds []string }{ @@ -445,6 +447,8 @@ func TestFindRecordsByFilter(t *testing.T) { "id != ''", "", 0, + 0, + nil, true, nil, }, @@ -454,6 +458,8 @@ func TestFindRecordsByFilter(t *testing.T) { "", "", 0, + 0, + nil, true, nil, }, @@ -463,6 +469,8 @@ func TestFindRecordsByFilter(t *testing.T) { "someMissingField > 1", "", 0, + 0, + nil, true, nil, }, @@ -472,6 +480,8 @@ func TestFindRecordsByFilter(t *testing.T) { "id != ''", "", 0, + 0, + nil, false, []string{ "llvuca81nly1qls", @@ -485,6 +495,8 @@ func TestFindRecordsByFilter(t *testing.T) { "id != '' && active=true", "-created,title", -1, // should behave the same as 0 + 0, + nil, false, []string{ "0yxhwia2amd8gec", @@ -492,47 +504,64 @@ func TestFindRecordsByFilter(t *testing.T) { }, }, { - "with limit", + "with limit and offset", "demo2", "id != ''", "title", 2, + 1, + nil, + false, + []string{ + "achvryl401bhse3", + "0yxhwia2amd8gec", + }, + }, + { + "with placeholder params", + "demo2", + "active = {:active}", + "", + 10, + 0, + []dbx.Params{{"active": false}}, false, []string{ "llvuca81nly1qls", - "achvryl401bhse3", }, }, } for _, s := range scenarios { - records, err := app.Dao().FindRecordsByFilter( - s.collectionIdOrName, - s.filter, - s.sort, - s.limit, - ) + t.Run(s.name, func(t *testing.T) { + records, err := app.Dao().FindRecordsByFilter( + s.collectionIdOrName, + s.filter, + s.sort, + s.limit, + s.offset, + s.params..., + ) - hasErr := err != nil - if hasErr != s.expectError { - t.Errorf("[%s] Expected hasErr to be %v, got %v (%v)", s.name, s.expectError, hasErr, err) - continue - } - - if hasErr { - continue - } - - if len(records) != len(s.expectRecordIds) { - t.Errorf("[%s] Expected %d records, got %d", s.name, len(s.expectRecordIds), len(records)) - continue - } - - for i, id := range s.expectRecordIds { - if id != records[i].Id { - t.Errorf("[%s] Expected record with id %q, got %q at index %d", s.name, id, records[i].Id, i) + hasErr := err != nil + if hasErr != s.expectError { + t.Fatalf("[%s] Expected hasErr to be %v, got %v (%v)", s.name, s.expectError, hasErr, err) } - } + + if hasErr { + return + } + + if len(records) != len(s.expectRecordIds) { + t.Fatalf("[%s] Expected %d records, got %d", s.name, len(s.expectRecordIds), len(records)) + } + + for i, id := range s.expectRecordIds { + if id != records[i].Id { + t.Fatalf("[%s] Expected record with id %q, got %q at index %d", s.name, id, records[i].Id, i) + } + } + }) } } @@ -544,6 +573,7 @@ func TestFindFirstRecordByFilter(t *testing.T) { name string collectionIdOrName string filter string + params []dbx.Params expectError bool expectRecordId string }{ @@ -551,6 +581,7 @@ func TestFindFirstRecordByFilter(t *testing.T) { "missing collection", "missing", "id != ''", + nil, true, "", }, @@ -558,6 +589,7 @@ func TestFindFirstRecordByFilter(t *testing.T) { "missing filter", "demo2", "", + nil, true, "", }, @@ -565,6 +597,7 @@ func TestFindFirstRecordByFilter(t *testing.T) { "invalid filter", "demo2", "someMissingField > 1", + nil, true, "", }, @@ -572,6 +605,7 @@ func TestFindFirstRecordByFilter(t *testing.T) { "valid filter but no matches", "demo2", "id = 'test'", + nil, true, "", }, @@ -579,13 +613,22 @@ func TestFindFirstRecordByFilter(t *testing.T) { "valid filter and multiple matches", "demo2", "id != ''", + nil, + false, + "llvuca81nly1qls", + }, + { + "with placeholder params", + "demo2", + "active = {:active}", + []dbx.Params{{"active": false}}, false, "llvuca81nly1qls", }, } for _, s := range scenarios { - record, err := app.Dao().FindFirstRecordByFilter(s.collectionIdOrName, s.filter) + record, err := app.Dao().FindFirstRecordByFilter(s.collectionIdOrName, s.filter, s.params...) hasErr := err != nil if hasErr != s.expectError { diff --git a/tools/search/filter.go b/tools/search/filter.go index cf07ce98..d2ab6074 100644 --- a/tools/search/filter.go +++ b/tools/search/filter.go @@ -3,6 +3,7 @@ package search import ( "errors" "fmt" + "strconv" "strings" "github.com/ganigeorgiev/fexpr" @@ -15,11 +16,14 @@ import ( // FilterData is a filter expression string following the `fexpr` package grammar. // +// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"), +// that will be safely replaced and properly quoted inplace with the placeholderReplacements values. +// // Example: // -// var filter FilterData = "id = null || (name = 'test' && status = true)" +// var filter FilterData = "id = null || (name = 'test' && status = true) || (total >= {:min} && total <= {:max})" // resolver := search.NewSimpleFieldResolver("id", "name", "status") -// expr, err := filter.BuildExpr(resolver) +// expr, err := filter.BuildExpr(resolver, dbx.Params{"min": 100, "max": 200}) type FilterData string // parsedFilterData holds a cache with previously parsed filter data expressions @@ -27,10 +31,33 @@ type FilterData string var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50)) // BuildExpr parses the current filter data and returns a new db WHERE expression. -func (f FilterData) BuildExpr(fieldResolver FieldResolver) (dbx.Expression, error) { +// +// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"), +// that will be safely replaced and properly quoted inplace with the placeholderReplacements values. +func (f FilterData) BuildExpr( + fieldResolver FieldResolver, + placeholderReplacements ...dbx.Params, +) (dbx.Expression, error) { raw := string(f) + + // replace the placeholder params in the raw string filter + for _, p := range placeholderReplacements { + for key, value := range p { + var replacement string + switch v := value.(type) { + case nil: + replacement = "null" + case bool, float64, float32, int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8: + replacement = cast.ToString(v) + default: + replacement = strconv.Quote(cast.ToString(v)) + } + raw = strings.ReplaceAll(raw, "{:"+key+"}", replacement) + } + } + if parsedFilterData.Has(raw) { - return f.build(parsedFilterData.Get(raw), fieldResolver) + return buildParsedFilterExpr(parsedFilterData.Get(raw), fieldResolver) } data, err := fexpr.Parse(raw) if err != nil { @@ -39,10 +66,10 @@ func (f FilterData) BuildExpr(fieldResolver FieldResolver) (dbx.Expression, erro // store in cache // (the limit size is arbitrary and it is there to prevent the cache growing too big) parsedFilterData.SetIfLessThanLimit(raw, data, 500) - return f.build(data, fieldResolver) + return buildParsedFilterExpr(data, fieldResolver) } -func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) { +func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) { if len(data) == 0 { return nil, errors.New("empty filter expression") } @@ -55,11 +82,11 @@ func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) ( switch item := group.Item.(type) { case fexpr.Expr: - expr, exprErr = f.resolveTokenizedExpr(item, fieldResolver) + expr, exprErr = resolveTokenizedExpr(item, fieldResolver) case fexpr.ExprGroup: - expr, exprErr = f.build([]fexpr.ExprGroup{item}, fieldResolver) + expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver) case []fexpr.ExprGroup: - expr, exprErr = f.build(item, fieldResolver) + expr, exprErr = buildParsedFilterExpr(item, fieldResolver) default: exprErr = errors.New("unsupported expression item") } @@ -84,7 +111,7 @@ func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) ( return result, nil } -func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) { +func resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) { lResult, lErr := resolveToken(expr.Left, fieldResolver) if lErr != nil || lResult.Identifier == "" { return nil, fmt.Errorf("invalid left operand %q - %v", expr.Left.Literal, lErr) @@ -95,10 +122,10 @@ func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldRes return nil, fmt.Errorf("invalid right operand %q - %v", expr.Right.Literal, rErr) } - return buildExpr(lResult, expr.Op, rResult) + return buildResolversExpr(lResult, expr.Op, rResult) } -func buildExpr( +func buildResolversExpr( left *ResolverResult, op fexpr.SignOp, right *ResolverResult, @@ -179,17 +206,21 @@ func buildExpr( return expr, nil } +var identifierMacros = map[string]func() string{ + "@now": func() string { return types.NowDateTime().String() }, +} + func resolveToken(token fexpr.Token, fieldResolver FieldResolver) (*ResolverResult, error) { switch token.Type { case fexpr.TokenIdentifier: - // current datetime constant + // check for macros // --- - if token.Literal == "@now" { + if f, ok := identifierMacros[token.Literal]; ok { placeholder := "t" + security.PseudorandomString(5) return &ResolverResult{ Identifier: "{:" + placeholder + "}", - Params: dbx.Params{placeholder: types.NowDateTime().String()}, + Params: dbx.Params{placeholder: f()}, }, nil } @@ -469,7 +500,7 @@ func (e *manyVsManyExpr) Build(db *dbx.DB, params dbx.Params) string { lAlias := "__ml" + security.PseudorandomString(5) rAlias := "__mr" + security.PseudorandomString(5) - whereExpr, buildErr := buildExpr( + whereExpr, buildErr := buildResolversExpr( &ResolverResult{ Identifier: "[[" + lAlias + ".multiMatchValue]]", }, @@ -536,9 +567,9 @@ func (e *manyVsOneExpr) Build(db *dbx.DB, params dbx.Params) string { var buildErr error if e.inverse { - whereExpr, buildErr = buildExpr(r2, e.op, r1) + whereExpr, buildErr = buildResolversExpr(r2, e.op, r1) } else { - whereExpr, buildErr = buildExpr(r1, e.op, r2) + whereExpr, buildErr = buildResolversExpr(r1, e.op, r2) } if buildErr != nil { diff --git a/tools/search/filter_test.go b/tools/search/filter_test.go index 2490ead7..e6e9b74b 100644 --- a/tools/search/filter_test.go +++ b/tools/search/filter_test.go @@ -1,8 +1,11 @@ package search_test import ( + "context" + "database/sql" "regexp" "testing" + "time" "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/tools/search" @@ -25,7 +28,10 @@ func TestFilterDataBuildExpr(t *testing.T) { }, { "invalid format", - "(test1 > 1", true, ""}, + "(test1 > 1", + true, + "", + }, { "invalid operator", "test1 + 123", @@ -169,24 +175,89 @@ func TestFilterDataBuildExpr(t *testing.T) { } for _, s := range scenarios { - expr, err := s.filterData.BuildExpr(resolver) + t.Run(s.name, func(t *testing.T) { + expr, err := s.filterData.BuildExpr(resolver) - hasErr := err != nil - if hasErr != s.expectError { - t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err) - continue - } + hasErr := err != nil + if hasErr != s.expectError { + t.Fatalf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err) + } - if hasErr { - continue - } + if hasErr { + return + } - dummyDB := &dbx.DB{} - rawSql := expr.Build(dummyDB, map[string]any{}) + dummyDB := &dbx.DB{} - pattern := regexp.MustCompile(s.expectPattern) - if !pattern.MatchString(rawSql) { - t.Errorf("[%s] Pattern %v don't match with expression: \n%v", s.name, s.expectPattern, rawSql) - } + rawSql := expr.Build(dummyDB, dbx.Params{}) + + pattern := regexp.MustCompile(s.expectPattern) + if !pattern.MatchString(rawSql) { + t.Fatalf("[%s] Pattern %v don't match with expression: \n%v", s.name, s.expectPattern, rawSql) + } + }) + } +} + +func TestFilterDataBuildExprWithParams(t *testing.T) { + // create a dummy db + sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared") + if err != nil { + t.Fatal(err) + } + db := dbx.NewFromDB(sqlDB, "sqlite") + + calledQueries := []string{} + db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) { + calledQueries = append(calledQueries, sql) + } + db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) { + calledQueries = append(calledQueries, sql) + } + + date, err := time.Parse("2006-01-02", "2023-01-01") + if err != nil { + t.Fatal(err) + } + + resolver := search.NewSimpleFieldResolver(`^test\w+$`) + + filter := search.FilterData(` + test1 = {:test1} || + test2 = {:test2} || + test3a = {:test3} || + test3b = {:test3} || + test4 = {:test4} || + test5 = {:test5} || + test6 = {:test6} || + test7 = {:test7} || + test8 = {:test8} || + test9 = {:test9} || + test10 = {:test10} + `) + + replacements := []dbx.Params{ + {"test1": true}, + {"test2": false}, + {"test3": 123.456}, + {"test4": nil}, + {"test5": "", "test6": "simple", "test7": `'single_quotes'`, "test8": `"double_quotes"`, "test9": `escape\"quote`}, + {"test10": date}, + } + + expr, err := filter.BuildExpr(resolver, replacements...) + if err != nil { + t.Fatal(err) + } + + db.Select().Where(expr).Build().Execute() + + if len(calledQueries) != 1 { + t.Fatalf("Expected 1 query, got %d", len(calledQueries)) + } + + expectedQuery := `SELECT * WHERE ([[test1]] = 1 OR [[test2]] = 0 OR [[test3a]] = 123.456 OR [[test3b]] = 123.456 OR ([[test4]] = '' OR [[test4]] IS NULL) OR ([[test5]] = '' OR [[test5]] IS NULL) OR [[test6]] = 'simple' OR [[test7]] = '''single_quotes''' OR [[test8]] = '"double_quotes"' OR [[test9]] = 'escape\\"quote' OR [[test10]] = '2023-01-01 00:00:00 +0000 UTC')` + if expectedQuery != calledQueries[0] { + t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0]) } }