added search filter and sort limits
This commit is contained in:
		
							parent
							
								
									fc133d8665
								
							
						
					
					
						commit
						45628a919f
					
				| 
						 | 
					@ -7,6 +7,11 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- Added `RateLimitRule.Audience` optional field for restricting a rate limit rule for `"@guest"`-only, `"@auth"`-only, `""`-any (default).
 | 
					- Added `RateLimitRule.Audience` optional field for restricting a rate limit rule for `"@guest"`-only, `"@auth"`-only, `""`-any (default).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Added default max limits for the expressions count and length of the search filter and sort params.
 | 
				
			||||||
 | 
					    _This is just an extra measure mostly for the case when the filter and sort parameters are resolved outside of the request context since the request size limits won't apply._
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- Other minor improvements (better error in case of duplicated rate limit rule, fixed typos, resolved lint warnings, etc.).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## v0.23.0-rc12
 | 
					## v0.23.0-rc12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -34,9 +34,24 @@ var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50))
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
 | 
					// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
 | 
				
			||||||
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
 | 
					// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// The parsed expressions are limited up to DefaultFilterExprLimit.
 | 
				
			||||||
 | 
					// Use [FilterData.BuildExprWithLimit] if you want to set a custom limit.
 | 
				
			||||||
func (f FilterData) BuildExpr(
 | 
					func (f FilterData) BuildExpr(
 | 
				
			||||||
	fieldResolver FieldResolver,
 | 
						fieldResolver FieldResolver,
 | 
				
			||||||
	placeholderReplacements ...dbx.Params,
 | 
						placeholderReplacements ...dbx.Params,
 | 
				
			||||||
 | 
					) (dbx.Expression, error) {
 | 
				
			||||||
 | 
						return f.BuildExprWithLimit(fieldResolver, DefaultFilterExprLimit, placeholderReplacements...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// BuildExpr parses the current filter data and returns a new db WHERE expression.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
 | 
				
			||||||
 | 
					// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
 | 
				
			||||||
 | 
					func (f FilterData) BuildExprWithLimit(
 | 
				
			||||||
 | 
						fieldResolver FieldResolver,
 | 
				
			||||||
 | 
						maxExpressions int,
 | 
				
			||||||
 | 
						placeholderReplacements ...dbx.Params,
 | 
				
			||||||
) (dbx.Expression, error) {
 | 
					) (dbx.Expression, error) {
 | 
				
			||||||
	raw := string(f)
 | 
						raw := string(f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -64,8 +79,10 @@ func (f FilterData) BuildExpr(
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if data, ok := parsedFilterData.GetOk(raw); ok {
 | 
						cacheKey := raw + "/" + strconv.Itoa(maxExpressions)
 | 
				
			||||||
		return buildParsedFilterExpr(data, fieldResolver)
 | 
					
 | 
				
			||||||
 | 
						if data, ok := parsedFilterData.GetOk(cacheKey); ok {
 | 
				
			||||||
 | 
							return buildParsedFilterExpr(data, fieldResolver, &maxExpressions)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	data, err := fexpr.Parse(raw)
 | 
						data, err := fexpr.Parse(raw)
 | 
				
			||||||
| 
						 | 
					@ -82,14 +99,14 @@ func (f FilterData) BuildExpr(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// store in cache
 | 
						// store in cache
 | 
				
			||||||
	// (the limit size is arbitrary and it is there to prevent the cache growing too big)
 | 
						// (the limit size is arbitrary and it is there to prevent the cache growing too big)
 | 
				
			||||||
	parsedFilterData.SetIfLessThanLimit(raw, data, 500)
 | 
						parsedFilterData.SetIfLessThanLimit(cacheKey, data, 500)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return buildParsedFilterExpr(data, fieldResolver)
 | 
						return buildParsedFilterExpr(data, fieldResolver, &maxExpressions)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) {
 | 
					func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver, maxExpressions *int) (dbx.Expression, error) {
 | 
				
			||||||
	if len(data) == 0 {
 | 
						if len(data) == 0 {
 | 
				
			||||||
		return nil, errors.New("empty filter expression")
 | 
							return nil, fexpr.ErrEmpty
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	result := &concatExpr{separator: " "}
 | 
						result := &concatExpr{separator: " "}
 | 
				
			||||||
| 
						 | 
					@ -100,11 +117,17 @@ func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		switch item := group.Item.(type) {
 | 
							switch item := group.Item.(type) {
 | 
				
			||||||
		case fexpr.Expr:
 | 
							case fexpr.Expr:
 | 
				
			||||||
 | 
								if *maxExpressions <= 0 {
 | 
				
			||||||
 | 
									return nil, ErrFilterExprLimit
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								*maxExpressions--
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			expr, exprErr = resolveTokenizedExpr(item, fieldResolver)
 | 
								expr, exprErr = resolveTokenizedExpr(item, fieldResolver)
 | 
				
			||||||
		case fexpr.ExprGroup:
 | 
							case fexpr.ExprGroup:
 | 
				
			||||||
			expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver)
 | 
								expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver, maxExpressions)
 | 
				
			||||||
		case []fexpr.ExprGroup:
 | 
							case []fexpr.ExprGroup:
 | 
				
			||||||
			expr, exprErr = buildParsedFilterExpr(item, fieldResolver)
 | 
								expr, exprErr = buildParsedFilterExpr(item, fieldResolver, maxExpressions)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			exprErr = errors.New("unsupported expression item")
 | 
								exprErr = errors.New("unsupported expression item")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,6 +3,7 @@ package search_test
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"database/sql"
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
| 
						 | 
					@ -239,6 +240,35 @@ func TestFilterDataBuildExprWithParams(t *testing.T) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFilterDataBuildExprWithLimit(t *testing.T) {
 | 
				
			||||||
 | 
						resolver := search.NewSimpleFieldResolver(`^\w+$`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						scenarios := []struct {
 | 
				
			||||||
 | 
							limit       int
 | 
				
			||||||
 | 
							filter      search.FilterData
 | 
				
			||||||
 | 
							expectError bool
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{1, "1 = 1", false},
 | 
				
			||||||
 | 
							{0, "1 = 1", true}, // new cache entry should be created
 | 
				
			||||||
 | 
							{2, "1 = 1 || 1 = 1", false},
 | 
				
			||||||
 | 
							{1, "1 = 1 || 1 = 1", true},
 | 
				
			||||||
 | 
							{3, "1 = 1 || 1 = 1", false},
 | 
				
			||||||
 | 
							{6, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", false},
 | 
				
			||||||
 | 
							{5, "(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)", true},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i, s := range scenarios {
 | 
				
			||||||
 | 
							t.Run(fmt.Sprintf("limit_%d:%d", i, s.limit), func(t *testing.T) {
 | 
				
			||||||
 | 
								_, err := s.filter.BuildExprWithLimit(resolver, s.limit)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								hasErr := err != nil
 | 
				
			||||||
 | 
								if hasErr != s.expectError {
 | 
				
			||||||
 | 
									t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestLikeParamsWrapping(t *testing.T) {
 | 
					func TestLikeParamsWrapping(t *testing.T) {
 | 
				
			||||||
	// create a dummy db
 | 
						// create a dummy db
 | 
				
			||||||
	sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
 | 
						sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,15 +12,36 @@ import (
 | 
				
			||||||
	"golang.org/x/sync/errgroup"
 | 
						"golang.org/x/sync/errgroup"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// DefaultPerPage specifies the default returned search result items.
 | 
					const (
 | 
				
			||||||
const DefaultPerPage int = 30
 | 
						// DefaultPerPage specifies the default number of returned search result items.
 | 
				
			||||||
 | 
						DefaultPerPage int = 30
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// @todo consider making it configurable
 | 
						// DefaultFilterExprLimit specifies the default filter expressions limit.
 | 
				
			||||||
//
 | 
						DefaultFilterExprLimit int = 200
 | 
				
			||||||
// MaxPerPage specifies the maximum allowed search result items returned in a single page.
 | 
					 | 
				
			||||||
const MaxPerPage int = 1000
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
// url search query params
 | 
						// DefaultSortExprLimit specifies the default sort expressions limit.
 | 
				
			||||||
 | 
						DefaultSortExprLimit int = 8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// MaxPerPage specifies the max allowed search result items returned in a single page.
 | 
				
			||||||
 | 
						MaxPerPage int = 1000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// MaxFilterLength specifies the max allowed individual search filter parsable length.
 | 
				
			||||||
 | 
						MaxFilterLength int = 3500
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// MaxSortFieldLength specifies the max allowed individual sort field parsable length.
 | 
				
			||||||
 | 
						MaxSortFieldLength int = 255
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Common search errors.
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						ErrEmptyQuery           = errors.New("search query is not set")
 | 
				
			||||||
 | 
						ErrSortExprLimit        = errors.New("max sort expressions limit reached")
 | 
				
			||||||
 | 
						ErrFilterExprLimit      = errors.New("max filter expressions limit reached")
 | 
				
			||||||
 | 
						ErrFilterLengthLimit    = errors.New("max filter length limit reached")
 | 
				
			||||||
 | 
						ErrSortFieldLengthLimit = errors.New("max sort field length limit reached")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// URL search query params
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	PageQueryParam      string = "page"
 | 
						PageQueryParam      string = "page"
 | 
				
			||||||
	PerPageQueryParam   string = "perPage"
 | 
						PerPageQueryParam   string = "perPage"
 | 
				
			||||||
| 
						 | 
					@ -48,9 +69,11 @@ type Provider struct {
 | 
				
			||||||
	page               int
 | 
						page               int
 | 
				
			||||||
	perPage            int
 | 
						perPage            int
 | 
				
			||||||
	skipTotal          bool
 | 
						skipTotal          bool
 | 
				
			||||||
 | 
						maxFilterExprLimit int
 | 
				
			||||||
 | 
						maxSortExprLimit   int
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewProvider creates and returns a new search provider.
 | 
					// NewProvider initializes and returns a new search provider.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Example:
 | 
					// Example:
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -69,9 +92,25 @@ func NewProvider(fieldResolver FieldResolver) *Provider {
 | 
				
			||||||
		perPage:            DefaultPerPage,
 | 
							perPage:            DefaultPerPage,
 | 
				
			||||||
		sort:               []SortField{},
 | 
							sort:               []SortField{},
 | 
				
			||||||
		filter:             []FilterData{},
 | 
							filter:             []FilterData{},
 | 
				
			||||||
 | 
							maxFilterExprLimit: DefaultFilterExprLimit,
 | 
				
			||||||
 | 
							maxSortExprLimit:   DefaultSortExprLimit,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MaxFilterExpressions changes the default max allowed filter expressions.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Note that currently the limit is applied individually for each separate filter.
 | 
				
			||||||
 | 
					func (s *Provider) MaxFilterExprLimit(max int) *Provider {
 | 
				
			||||||
 | 
						s.maxFilterExprLimit = max
 | 
				
			||||||
 | 
						return s
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MaxSortExpressions changes the default max allowed sort expressions.
 | 
				
			||||||
 | 
					func (s *Provider) MaxSortExprLimit(max int) *Provider {
 | 
				
			||||||
 | 
						s.maxSortExprLimit = max
 | 
				
			||||||
 | 
						return s
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Query sets the base query that will be used to fetch the search items.
 | 
					// Query sets the base query that will be used to fetch the search items.
 | 
				
			||||||
func (s *Provider) Query(query *dbx.SelectQuery) *Provider {
 | 
					func (s *Provider) Query(query *dbx.SelectQuery) *Provider {
 | 
				
			||||||
	s.query = query
 | 
						s.query = query
 | 
				
			||||||
| 
						 | 
					@ -188,7 +227,7 @@ func (s *Provider) Parse(urlQuery string) error {
 | 
				
			||||||
// the provided `items` slice with the found models.
 | 
					// the provided `items` slice with the found models.
 | 
				
			||||||
func (s *Provider) Exec(items any) (*Result, error) {
 | 
					func (s *Provider) Exec(items any) (*Result, error) {
 | 
				
			||||||
	if s.query == nil {
 | 
						if s.query == nil {
 | 
				
			||||||
		return nil, errors.New("query is not set")
 | 
							return nil, ErrEmptyQuery
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// shallow clone the provider's query
 | 
						// shallow clone the provider's query
 | 
				
			||||||
| 
						 | 
					@ -196,7 +235,10 @@ func (s *Provider) Exec(items any) (*Result, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// build filters
 | 
						// build filters
 | 
				
			||||||
	for _, f := range s.filter {
 | 
						for _, f := range s.filter {
 | 
				
			||||||
		expr, err := f.BuildExpr(s.fieldResolver)
 | 
							if len(f) > MaxFilterLength {
 | 
				
			||||||
 | 
								return nil, ErrFilterLengthLimit
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							expr, err := f.BuildExprWithLimit(s.fieldResolver, s.maxFilterExprLimit)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					@ -206,7 +248,13 @@ func (s *Provider) Exec(items any) (*Result, error) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// apply sorting
 | 
						// apply sorting
 | 
				
			||||||
 | 
						if len(s.sort) > s.maxSortExprLimit {
 | 
				
			||||||
 | 
							return nil, ErrSortExprLimit
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	for _, sortField := range s.sort {
 | 
						for _, sortField := range s.sort {
 | 
				
			||||||
 | 
							if len(sortField.Name) > MaxSortFieldLength {
 | 
				
			||||||
 | 
								return nil, ErrSortFieldLengthLimit
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		expr, err := sortField.BuildExpr(s.fieldResolver)
 | 
							expr, err := sortField.BuildExpr(s.fieldResolver)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,6 +6,8 @@ import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -25,6 +27,46 @@ func TestNewProvider(t *testing.T) {
 | 
				
			||||||
	if p.perPage != DefaultPerPage {
 | 
						if p.perPage != DefaultPerPage {
 | 
				
			||||||
		t.Fatalf("Expected perPage %d, got %d", DefaultPerPage, p.perPage)
 | 
							t.Fatalf("Expected perPage %d, got %d", DefaultPerPage, p.perPage)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if p.maxFilterExprLimit != DefaultFilterExprLimit {
 | 
				
			||||||
 | 
							t.Fatalf("Expected maxFilterExprLimit %d, got %d", DefaultFilterExprLimit, p.maxFilterExprLimit)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if p.maxSortExprLimit != DefaultSortExprLimit {
 | 
				
			||||||
 | 
							t.Fatalf("Expected maxSortExprLimit %d, got %d", DefaultSortExprLimit, p.maxSortExprLimit)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMaxFilterExprLimit(t *testing.T) {
 | 
				
			||||||
 | 
						p := NewProvider(&testFieldResolver{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						testVals := []int{0, -10, 10}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, val := range testVals {
 | 
				
			||||||
 | 
							t.Run("max_"+strconv.Itoa(val), func(t *testing.T) {
 | 
				
			||||||
 | 
								p.MaxFilterExprLimit(val)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if p.maxFilterExprLimit != val {
 | 
				
			||||||
 | 
									t.Fatalf("Expected maxFilterExprLimit to change to %d, got %d", val, p.maxFilterExprLimit)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMaxSortExprLimit(t *testing.T) {
 | 
				
			||||||
 | 
						p := NewProvider(&testFieldResolver{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						testVals := []int{0, -10, 10}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, val := range testVals {
 | 
				
			||||||
 | 
							t.Run("max_"+strconv.Itoa(val), func(t *testing.T) {
 | 
				
			||||||
 | 
								p.MaxSortExprLimit(val)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if p.maxSortExprLimit != val {
 | 
				
			||||||
 | 
									t.Fatalf("Expected maxSortExprLimit to change to %d, got %d", val, p.maxSortExprLimit)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProviderQuery(t *testing.T) {
 | 
					func TestProviderQuery(t *testing.T) {
 | 
				
			||||||
| 
						 | 
					@ -428,6 +470,141 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestProviderFilterAndSortLimits(t *testing.T) {
 | 
				
			||||||
 | 
						testDB, err := createTestDB()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer testDB.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						query := testDB.Select("*").
 | 
				
			||||||
 | 
							From("test").
 | 
				
			||||||
 | 
							Where(dbx.Not(dbx.HashExp{"test1": nil})).
 | 
				
			||||||
 | 
							OrderBy("test1 ASC")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						scenarios := []struct {
 | 
				
			||||||
 | 
							name               string
 | 
				
			||||||
 | 
							filter             []FilterData
 | 
				
			||||||
 | 
							sort               []SortField
 | 
				
			||||||
 | 
							maxFilterExprLimit int
 | 
				
			||||||
 | 
							maxSortExprLimit   int
 | 
				
			||||||
 | 
							expectError        bool
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							// filter
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"<= max filter length",
 | 
				
			||||||
 | 
								[]FilterData{
 | 
				
			||||||
 | 
									"1=2",
 | 
				
			||||||
 | 
									FilterData("1='" + strings.Repeat("a", MaxFilterLength-4) + "'"),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								[]SortField{},
 | 
				
			||||||
 | 
								1,
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"> max filter length",
 | 
				
			||||||
 | 
								[]FilterData{
 | 
				
			||||||
 | 
									"1=2",
 | 
				
			||||||
 | 
									FilterData("1='" + strings.Repeat("a", MaxFilterLength-3) + "'"),
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								[]SortField{},
 | 
				
			||||||
 | 
								1,
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"<= max filter exprs",
 | 
				
			||||||
 | 
								[]FilterData{
 | 
				
			||||||
 | 
									"1=2",
 | 
				
			||||||
 | 
									"(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								[]SortField{},
 | 
				
			||||||
 | 
								6,
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"> max filter exprs",
 | 
				
			||||||
 | 
								[]FilterData{
 | 
				
			||||||
 | 
									"1=2",
 | 
				
			||||||
 | 
									"(1=1 || 1=1) && (1=1 || (1=1 || 1=1)) && (1=1)",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								[]SortField{},
 | 
				
			||||||
 | 
								5,
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// sort
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"<= max sort field length",
 | 
				
			||||||
 | 
								[]FilterData{},
 | 
				
			||||||
 | 
								[]SortField{
 | 
				
			||||||
 | 
									{"id", SortAsc},
 | 
				
			||||||
 | 
									{"test1", SortDesc},
 | 
				
			||||||
 | 
									{strings.Repeat("a", MaxSortFieldLength), SortDesc},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								10,
 | 
				
			||||||
 | 
								false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"> max sort field length",
 | 
				
			||||||
 | 
								[]FilterData{},
 | 
				
			||||||
 | 
								[]SortField{
 | 
				
			||||||
 | 
									{"id", SortAsc},
 | 
				
			||||||
 | 
									{"test1", SortDesc},
 | 
				
			||||||
 | 
									{strings.Repeat("b", MaxSortFieldLength+1), SortDesc},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								10,
 | 
				
			||||||
 | 
								true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"<= max sort exprs",
 | 
				
			||||||
 | 
								[]FilterData{},
 | 
				
			||||||
 | 
								[]SortField{
 | 
				
			||||||
 | 
									{"id", SortAsc},
 | 
				
			||||||
 | 
									{"test1", SortDesc},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								2,
 | 
				
			||||||
 | 
								false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								"> max sort exprs",
 | 
				
			||||||
 | 
								[]FilterData{},
 | 
				
			||||||
 | 
								[]SortField{
 | 
				
			||||||
 | 
									{"id", SortAsc},
 | 
				
			||||||
 | 
									{"test1", SortDesc},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								0,
 | 
				
			||||||
 | 
								1,
 | 
				
			||||||
 | 
								true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, s := range scenarios {
 | 
				
			||||||
 | 
							t.Run(s.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								testResolver := &testFieldResolver{}
 | 
				
			||||||
 | 
								p := NewProvider(testResolver).
 | 
				
			||||||
 | 
									Query(query).
 | 
				
			||||||
 | 
									Sort(s.sort).
 | 
				
			||||||
 | 
									Filter(s.filter).
 | 
				
			||||||
 | 
									MaxFilterExprLimit(s.maxFilterExprLimit).
 | 
				
			||||||
 | 
									MaxSortExprLimit(s.maxSortExprLimit)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								_, err := p.Exec(&[]testTableStruct{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								hasErr := err != nil
 | 
				
			||||||
 | 
								if hasErr != s.expectError {
 | 
				
			||||||
 | 
									t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestProviderParseAndExec(t *testing.T) {
 | 
					func TestProviderParseAndExec(t *testing.T) {
 | 
				
			||||||
	testDB, err := createTestDB()
 | 
						testDB, err := createTestDB()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -577,7 +754,14 @@ func createTestDB() (*testDB, error) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
 | 
						db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
 | 
				
			||||||
	db.CreateTable("test", map[string]string{"id": "int default 0", "test1": "int default 0", "test2": "text default ''", "test3": "text default ''"}).Execute()
 | 
						db.CreateTable("test", map[string]string{
 | 
				
			||||||
 | 
							"id":                                    "int default 0",
 | 
				
			||||||
 | 
							"test1":                                 "int default 0",
 | 
				
			||||||
 | 
							"test2":                                 "text default ''",
 | 
				
			||||||
 | 
							"test3":                                 "text default ''",
 | 
				
			||||||
 | 
							strings.Repeat("a", MaxSortFieldLength): "text default ''",
 | 
				
			||||||
 | 
							strings.Repeat("b", MaxSortFieldLength+1): "text default ''",
 | 
				
			||||||
 | 
						}).Execute()
 | 
				
			||||||
	db.Insert("test", dbx.Params{"id": 1, "test1": 1, "test2": "test2.1"}).Execute()
 | 
						db.Insert("test", dbx.Params{"id": 1, "test1": 1, "test2": "test2.1"}).Execute()
 | 
				
			||||||
	db.Insert("test", dbx.Params{"id": 2, "test1": 2, "test2": "test2.2"}).Execute()
 | 
						db.Insert("test", dbx.Params{"id": 2, "test1": 2, "test2": "test2.2"}).Execute()
 | 
				
			||||||
	db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
 | 
						db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue