diff --git a/CHANGELOG.md b/CHANGELOG.md index a9623234..f8f4790b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ - Added `RateLimitRule.Audience` optional field for restricting a rate limit rule for `"@guest"`-only, `"@auth"`-only, `""`-any (default). +- Added default max limits for the expressions count and length of the search filter and sort params. + _This is just an extra measure mostly for the case when the filter and sort parameters are resolved outside of the request context since the request size limits won't apply._ + +- Other minor improvements (better error in case of duplicated rate limit rule, fixed typos, resolved lint warnings, etc.). + ## v0.23.0-rc12 diff --git a/tools/search/filter.go b/tools/search/filter.go index cbbd04b2..162dac7b 100644 --- a/tools/search/filter.go +++ b/tools/search/filter.go @@ -34,9 +34,24 @@ var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50)) // // The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"), // that will be safely replaced and properly quoted inplace with the placeholderReplacements values. +// +// The parsed expressions are limited up to DefaultFilterExprLimit. +// Use [FilterData.BuildExprWithLimit] if you want to set a custom limit. func (f FilterData) BuildExpr( fieldResolver FieldResolver, placeholderReplacements ...dbx.Params, +) (dbx.Expression, error) { + return f.BuildExprWithLimit(fieldResolver, DefaultFilterExprLimit, placeholderReplacements...) +} + +// BuildExprWithLimit parses the current filter data and returns a new db WHERE expression. +// +// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"), +// that will be safely replaced and properly quoted inplace with the placeholderReplacements values. 
+func (f FilterData) BuildExprWithLimit( + fieldResolver FieldResolver, + maxExpressions int, + placeholderReplacements ...dbx.Params, ) (dbx.Expression, error) { raw := string(f) @@ -64,8 +79,10 @@ func (f FilterData) BuildExpr( } } - if data, ok := parsedFilterData.GetOk(raw); ok { - return buildParsedFilterExpr(data, fieldResolver) + cacheKey := raw + "/" + strconv.Itoa(maxExpressions) + + if data, ok := parsedFilterData.GetOk(cacheKey); ok { + return buildParsedFilterExpr(data, fieldResolver, &maxExpressions) } data, err := fexpr.Parse(raw) @@ -82,14 +99,14 @@ func (f FilterData) BuildExpr( // store in cache // (the limit size is arbitrary and it is there to prevent the cache growing too big) - parsedFilterData.SetIfLessThanLimit(raw, data, 500) + parsedFilterData.SetIfLessThanLimit(cacheKey, data, 500) - return buildParsedFilterExpr(data, fieldResolver) + return buildParsedFilterExpr(data, fieldResolver, &maxExpressions) } -func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) { +func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver, maxExpressions *int) (dbx.Expression, error) { if len(data) == 0 { - return nil, errors.New("empty filter expression") + return nil, fexpr.ErrEmpty } result := &concatExpr{separator: " "} @@ -100,11 +117,17 @@ func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver) switch item := group.Item.(type) { case fexpr.Expr: + if *maxExpressions <= 0 { + return nil, ErrFilterExprLimit + } + + *maxExpressions-- + expr, exprErr = resolveTokenizedExpr(item, fieldResolver) case fexpr.ExprGroup: - expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver) + expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver, maxExpressions) case []fexpr.ExprGroup: - expr, exprErr = buildParsedFilterExpr(item, fieldResolver) + expr, exprErr = buildParsedFilterExpr(item, fieldResolver, maxExpressions) default: exprErr = 
errors.New("unsupported expression item") } diff --git a/tools/search/filter_test.go b/tools/search/filter_test.go index 59e37e09..ad6a672a 100644 --- a/tools/search/filter_test.go +++ b/tools/search/filter_test.go @@ -3,6 +3,7 @@ package search_test import ( "context" "database/sql" + "fmt" "regexp" "strings" "testing" @@ -239,6 +240,35 @@ func TestFilterDataBuildExprWithParams(t *testing.T) { } } +func TestFilterDataBuildExprWithLimit(t *testing.T) { + resolver := search.NewSimpleFieldResolver(`^\w+$`) + + scenarios := []struct { + limit int + filter search.FilterData + expectError bool + }{ + {1, "1 = 1", false}, + {0, "1 = 1", true}, // new cache entry should be created + {2, "1 = 1 || 1 = 1", false}, + {1, "1 = 1 || 1 = 1", true}, + {3, "1 = 1 || 1 = 1", false}, + {6, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", false}, + {5, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", true}, + } + + for i, s := range scenarios { + t.Run(fmt.Sprintf("limit_%d:%d", i, s.limit), func(t *testing.T) { + _, err := s.filter.BuildExprWithLimit(resolver, s.limit) + + hasErr := err != nil + if hasErr != s.expectError { + t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr) + } + }) + } +} + func TestLikeParamsWrapping(t *testing.T) { // create a dummy db sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared") diff --git a/tools/search/provider.go b/tools/search/provider.go index 17ae72b3..940493b7 100644 --- a/tools/search/provider.go +++ b/tools/search/provider.go @@ -12,15 +12,36 @@ import ( "golang.org/x/sync/errgroup" ) -// DefaultPerPage specifies the default returned search result items. -const DefaultPerPage int = 30 +const ( + // DefaultPerPage specifies the default number of returned search result items. + DefaultPerPage int = 30 -// @todo consider making it configurable -// -// MaxPerPage specifies the maximum allowed search result items returned in a single page. 
-const MaxPerPage int = 1000 + // DefaultFilterExprLimit specifies the default filter expressions limit. + DefaultFilterExprLimit int = 200 -// url search query params + // DefaultSortExprLimit specifies the default sort expressions limit. + DefaultSortExprLimit int = 8 + + // MaxPerPage specifies the max allowed search result items returned in a single page. + MaxPerPage int = 1000 + + // MaxFilterLength specifies the max allowed individual search filter parsable length. + MaxFilterLength int = 3500 + + // MaxSortFieldLength specifies the max allowed individual sort field parsable length. + MaxSortFieldLength int = 255 +) + +// Common search errors. +var ( + ErrEmptyQuery = errors.New("search query is not set") + ErrSortExprLimit = errors.New("max sort expressions limit reached") + ErrFilterExprLimit = errors.New("max filter expressions limit reached") + ErrFilterLengthLimit = errors.New("max filter length limit reached") + ErrSortFieldLengthLimit = errors.New("max sort field length limit reached") +) + +// URL search query params const ( PageQueryParam string = "page" PerPageQueryParam string = "perPage" @@ -40,17 +61,19 @@ type Result struct { // Provider represents a single configured search provider instance. type Provider struct { - fieldResolver FieldResolver - query *dbx.SelectQuery - countCol string - sort []SortField - filter []FilterData - page int - perPage int - skipTotal bool + fieldResolver FieldResolver + query *dbx.SelectQuery + countCol string + sort []SortField + filter []FilterData + page int + perPage int + skipTotal bool + maxFilterExprLimit int + maxSortExprLimit int } -// NewProvider creates and returns a new search provider. +// NewProvider initializes and returns a new search provider. 
// // Example: // //	 models := []*YourDataStruct{} // //	 result, err := search.NewProvider(fieldResolver). //	 Query(baseQuery). //	 ParseAndExec("page=2&filter=id>0&sort=-email", &models) func NewProvider(fieldResolver FieldResolver) *Provider { return &Provider{ - fieldResolver: fieldResolver, - countCol: "id", - page: 1, - perPage: DefaultPerPage, - sort: []SortField{}, - filter: []FilterData{}, + fieldResolver: fieldResolver, + countCol: "id", + page: 1, + perPage: DefaultPerPage, + sort: []SortField{}, + filter: []FilterData{}, + maxFilterExprLimit: DefaultFilterExprLimit, + maxSortExprLimit: DefaultSortExprLimit, } } +// MaxFilterExprLimit changes the default max allowed filter expressions. +// +// Note that currently the limit is applied individually for each separate filter. +func (s *Provider) MaxFilterExprLimit(max int) *Provider { + s.maxFilterExprLimit = max + return s +} + +// MaxSortExprLimit changes the default max allowed sort expressions. +func (s *Provider) MaxSortExprLimit(max int) *Provider { + s.maxSortExprLimit = max + return s +} + // Query sets the base query that will be used to fetch the search items. func (s *Provider) Query(query *dbx.SelectQuery) *Provider { s.query = query @@ -188,7 +227,7 @@ func (s *Provider) Parse(urlQuery string) error { // the provided `items` slice with the found models. 
func (s *Provider) Exec(items any) (*Result, error) { if s.query == nil { - return nil, errors.New("query is not set") + return nil, ErrEmptyQuery } // shallow clone the provider's query @@ -196,7 +235,10 @@ func (s *Provider) Exec(items any) (*Result, error) { // build filters for _, f := range s.filter { - expr, err := f.BuildExpr(s.fieldResolver) + if len(f) > MaxFilterLength { + return nil, ErrFilterLengthLimit + } + expr, err := f.BuildExprWithLimit(s.fieldResolver, s.maxFilterExprLimit) if err != nil { return nil, err } @@ -206,7 +248,13 @@ func (s *Provider) Exec(items any) (*Result, error) { } // apply sorting + if len(s.sort) > s.maxSortExprLimit { + return nil, ErrSortExprLimit + } for _, sortField := range s.sort { + if len(sortField.Name) > MaxSortFieldLength { + return nil, ErrSortFieldLengthLimit + } expr, err := sortField.BuildExpr(s.fieldResolver) if err != nil { return nil, err diff --git a/tools/search/provider_test.go b/tools/search/provider_test.go index e4756ff1..0a523983 100644 --- a/tools/search/provider_test.go +++ b/tools/search/provider_test.go @@ -6,6 +6,8 @@ import ( "encoding/json" "errors" "fmt" + "strconv" + "strings" "testing" "time" @@ -25,6 +27,46 @@ func TestNewProvider(t *testing.T) { if p.perPage != DefaultPerPage { t.Fatalf("Expected perPage %d, got %d", DefaultPerPage, p.perPage) } + + if p.maxFilterExprLimit != DefaultFilterExprLimit { + t.Fatalf("Expected maxFilterExprLimit %d, got %d", DefaultFilterExprLimit, p.maxFilterExprLimit) + } + + if p.maxSortExprLimit != DefaultSortExprLimit { + t.Fatalf("Expected maxSortExprLimit %d, got %d", DefaultSortExprLimit, p.maxSortExprLimit) + } +} + +func TestMaxFilterExprLimit(t *testing.T) { + p := NewProvider(&testFieldResolver{}) + + testVals := []int{0, -10, 10} + + for _, val := range testVals { + t.Run("max_"+strconv.Itoa(val), func(t *testing.T) { + p.MaxFilterExprLimit(val) + + if p.maxFilterExprLimit != val { + t.Fatalf("Expected maxFilterExprLimit to change to %d, got %d", 
val, p.maxFilterExprLimit) + } + }) + } +} + +func TestMaxSortExprLimit(t *testing.T) { + p := NewProvider(&testFieldResolver{}) + + testVals := []int{0, -10, 10} + + for _, val := range testVals { + t.Run("max_"+strconv.Itoa(val), func(t *testing.T) { + p.MaxSortExprLimit(val) + + if p.maxSortExprLimit != val { + t.Fatalf("Expected maxSortExprLimit to change to %d, got %d", val, p.maxSortExprLimit) + } + }) + } } func TestProviderQuery(t *testing.T) { @@ -428,6 +470,141 @@ func TestProviderExecNonEmptyQuery(t *testing.T) { } } +func TestProviderFilterAndSortLimits(t *testing.T) { + testDB, err := createTestDB() + if err != nil { + t.Fatal(err) + } + defer testDB.Close() + + query := testDB.Select("*"). + From("test"). + Where(dbx.Not(dbx.HashExp{"test1": nil})). + OrderBy("test1 ASC") + + scenarios := []struct { + name string + filter []FilterData + sort []SortField + maxFilterExprLimit int + maxSortExprLimit int + expectError bool + }{ + // filter + { + "<= max filter length", + []FilterData{ + "1=2", + FilterData("1='" + strings.Repeat("a", MaxFilterLength-4) + "'"), + }, + []SortField{}, + 1, + 0, + false, + }, + { + "> max filter length", + []FilterData{ + "1=2", + FilterData("1='" + strings.Repeat("a", MaxFilterLength-3) + "'"), + }, + []SortField{}, + 1, + 0, + true, + }, + { + "<= max filter exprs", + []FilterData{ + "1=2", + "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", + }, + []SortField{}, + 6, + 0, + false, + }, + { + "> max filter exprs", + []FilterData{ + "1=2", + "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", + }, + []SortField{}, + 5, + 0, + true, + }, + + // sort + { + "<= max sort field length", + []FilterData{}, + []SortField{ + {"id", SortAsc}, + {"test1", SortDesc}, + {strings.Repeat("a", MaxSortFieldLength), SortDesc}, + }, + 0, + 10, + false, + }, + { + "> max sort field length", + []FilterData{}, + []SortField{ + {"id", SortAsc}, + {"test1", SortDesc}, + {strings.Repeat("b", MaxSortFieldLength+1), SortDesc}, + }, + 0, + 10, + true, + 
}, + { + "<= max sort exprs", + []FilterData{}, + []SortField{ + {"id", SortAsc}, + {"test1", SortDesc}, + }, + 0, + 2, + false, + }, + { + "> max sort exprs", + []FilterData{}, + []SortField{ + {"id", SortAsc}, + {"test1", SortDesc}, + }, + 0, + 1, + true, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + testResolver := &testFieldResolver{} + p := NewProvider(testResolver). + Query(query). + Sort(s.sort). + Filter(s.filter). + MaxFilterExprLimit(s.maxFilterExprLimit). + MaxSortExprLimit(s.maxSortExprLimit) + + _, err := p.Exec(&[]testTableStruct{}) + + hasErr := err != nil + if hasErr != s.expectError { + t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr) + } + }) + } +} + func TestProviderParseAndExec(t *testing.T) { testDB, err := createTestDB() if err != nil { @@ -577,7 +754,14 @@ func createTestDB() (*testDB, error) { } db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")} - db.CreateTable("test", map[string]string{"id": "int default 0", "test1": "int default 0", "test2": "text default ''", "test3": "text default ''"}).Execute() + db.CreateTable("test", map[string]string{ + "id": "int default 0", + "test1": "int default 0", + "test2": "text default ''", + "test3": "text default ''", + strings.Repeat("a", MaxSortFieldLength): "text default ''", + strings.Repeat("b", MaxSortFieldLength+1): "text default ''", + }).Execute() db.Insert("test", dbx.Params{"id": 1, "test1": 1, "test2": "test2.1"}).Execute() db.Insert("test", dbx.Params{"id": 2, "test1": 2, "test2": "test2.2"}).Execute() db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {