initial public commit
This commit is contained in:
@@ -0,0 +1,198 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ganigeorgiev/fexpr"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// FilterData is a filter expession string following the `fexpr` package grammar.
|
||||
//
|
||||
// Example:
|
||||
// var filter FilterData = "id = null || (name = 'test' && status = true)"
|
||||
// resolver := search.NewSimpleFieldResolver("id", "name", "status")
|
||||
// expr, err := filter.BuildExpr(resolver)
|
||||
type FilterData string
|
||||
|
||||
// parsedFilterData holds a cache with previously parsed filter data expressions
|
||||
// (initialized with some prealocated empty data map)
|
||||
var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50))
|
||||
|
||||
// BuildExpr parses the current filter data and returns a new db WHERE expression.
|
||||
func (f FilterData) BuildExpr(fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
raw := string(f)
|
||||
var data []fexpr.ExprGroup
|
||||
|
||||
if parsedFilterData.Has(raw) {
|
||||
data = parsedFilterData.Get(raw)
|
||||
} else {
|
||||
var err error
|
||||
data, err = fexpr.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// store in cache
|
||||
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
|
||||
parsedFilterData.SetIfLessThanLimit(raw, data, 500)
|
||||
}
|
||||
|
||||
return f.build(data, fieldResolver)
|
||||
}
|
||||
|
||||
func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("Empty filter expression.")
|
||||
}
|
||||
|
||||
var result dbx.Expression
|
||||
|
||||
for _, group := range data {
|
||||
var expr dbx.Expression
|
||||
var exprErr error
|
||||
|
||||
switch item := group.Item.(type) {
|
||||
case fexpr.Expr:
|
||||
expr, exprErr = f.resolveTokenizedExpr(item, fieldResolver)
|
||||
case fexpr.ExprGroup:
|
||||
expr, exprErr = f.build([]fexpr.ExprGroup{item}, fieldResolver)
|
||||
case []fexpr.ExprGroup:
|
||||
expr, exprErr = f.build(item, fieldResolver)
|
||||
default:
|
||||
exprErr = errors.New("Unsupported expression item.")
|
||||
}
|
||||
|
||||
if exprErr != nil {
|
||||
return nil, exprErr
|
||||
}
|
||||
|
||||
if group.Join == fexpr.JoinAnd {
|
||||
result = dbx.And(result, expr)
|
||||
} else {
|
||||
result = dbx.Or(result, expr)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
lName, lParams, lErr := f.resolveToken(expr.Left, fieldResolver)
|
||||
if lName == "" || lErr != nil {
|
||||
return nil, fmt.Errorf("Invalid left operand %q - %v.", expr.Left.Literal, lErr)
|
||||
}
|
||||
|
||||
rName, rParams, rErr := f.resolveToken(expr.Right, fieldResolver)
|
||||
if rName == "" || rErr != nil {
|
||||
return nil, fmt.Errorf("Invalid right operand %q - %v.", expr.Right.Literal, rErr)
|
||||
}
|
||||
|
||||
// merge both operands parameters (if any)
|
||||
params := dbx.Params{}
|
||||
if len(lParams) > 0 {
|
||||
for k, v := range lParams {
|
||||
params[k] = v
|
||||
}
|
||||
}
|
||||
if len(rParams) > 0 {
|
||||
for k, v := range rParams {
|
||||
params[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
switch expr.Op {
|
||||
case fexpr.SignEq:
|
||||
op := "="
|
||||
if strings.ToLower(lName) == "null" || strings.ToLower(rName) == "null" {
|
||||
op = "IS"
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s %s %s", lName, op, rName), params), nil
|
||||
case fexpr.SignNeq:
|
||||
op := "!="
|
||||
if strings.ToLower(lName) == "null" || strings.ToLower(rName) == "null" {
|
||||
op = "IS NOT"
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s %s %s", lName, op, rName), params), nil
|
||||
case fexpr.SignLike:
|
||||
// normalize operands and switch sides if the left operand is a number or text
|
||||
if len(lParams) > 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s", rName, lName), f.normalizeLikeParams(params)), nil
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s", lName, rName), f.normalizeLikeParams(params)), nil
|
||||
case fexpr.SignNlike:
|
||||
// normalize operands and switch sides if the left operand is a number or text
|
||||
if len(lParams) > 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", rName, lName), f.normalizeLikeParams(params)), nil
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", lName, rName), f.normalizeLikeParams(params)), nil
|
||||
case fexpr.SignLt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s < %s", lName, rName), params), nil
|
||||
case fexpr.SignLte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s <= %s", lName, rName), params), nil
|
||||
case fexpr.SignGt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s > %s", lName, rName), params), nil
|
||||
case fexpr.SignGte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s >= %s", lName, rName), params), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown expression operator %q", expr.Op)
|
||||
}
|
||||
|
||||
func (f FilterData) resolveToken(token fexpr.Token, fieldResolver FieldResolver) (name string, params dbx.Params, err error) {
|
||||
if token.Type == fexpr.TokenIdentifier {
|
||||
name, params, err := fieldResolver.Resolve(token.Literal)
|
||||
|
||||
if name == "" || err != nil {
|
||||
// if `null` field is missing, treat `null` identifier as NULL token
|
||||
if strings.ToLower(token.Literal) == "null" {
|
||||
return "NULL", nil, nil
|
||||
}
|
||||
|
||||
// if `true` field is missing, treat `true` identifier as TRUE token
|
||||
if strings.ToLower(token.Literal) == "true" {
|
||||
return "1", nil, nil
|
||||
}
|
||||
|
||||
// if `false` field is missing, treat `false` identifier as FALSE token
|
||||
if strings.ToLower(token.Literal) == "false" {
|
||||
return "0", nil, nil
|
||||
}
|
||||
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return name, params, err
|
||||
}
|
||||
|
||||
if token.Type == fexpr.TokenNumber || token.Type == fexpr.TokenText {
|
||||
placeholder := "t" + security.RandomString(7)
|
||||
name := fmt.Sprintf("{:%s}", placeholder)
|
||||
params := dbx.Params{placeholder: token.Literal}
|
||||
return name, params, nil
|
||||
}
|
||||
|
||||
return "", nil, errors.New("Unresolvable token type.")
|
||||
}
|
||||
|
||||
func (f FilterData) normalizeLikeParams(params dbx.Params) dbx.Params {
|
||||
result := dbx.Params{}
|
||||
|
||||
if len(params) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
for k, v := range params {
|
||||
vStr := cast.ToString(v)
|
||||
if !strings.Contains(vStr, "%") {
|
||||
vStr = "%" + vStr + "%"
|
||||
}
|
||||
result[k] = vStr
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package search_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestFilterDataBuildExpr(t *testing.T) {
|
||||
resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", "test4.sub")
|
||||
|
||||
scenarios := []struct {
|
||||
filterData search.FilterData
|
||||
expectError bool
|
||||
expectPattern string
|
||||
}{
|
||||
// empty
|
||||
{"", true, ""},
|
||||
// invalid format
|
||||
{"(test1 > 1", true, ""},
|
||||
// invalid operator
|
||||
{"test1 + 123", true, ""},
|
||||
// unknown field
|
||||
{"test1 = 'example' && unknown > 1", true, ""},
|
||||
// simple expression
|
||||
{"test1 > 1", false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("[[test1]] > {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}") +
|
||||
"$",
|
||||
},
|
||||
// complex expression
|
||||
{
|
||||
"((test1 > 1) || (test2 != 2)) && test3 ~ '%%example' && test4.sub = null",
|
||||
false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("((([[test1]] > {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}) OR ([[test2]] != {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test4.sub]] IS NULL)") +
|
||||
"$",
|
||||
},
|
||||
// combination of special literals (null, true, false)
|
||||
{
|
||||
"test1=true && test2 != false && test3 = null || test4.sub != null",
|
||||
false,
|
||||
"^" + regexp.QuoteMeta("((([[test1]] = 1) AND ([[test2]] != 0)) AND ([[test3]] IS NULL)) OR ([[test4.sub]] IS NOT NULL)") + "$",
|
||||
},
|
||||
// all operators
|
||||
{
|
||||
"(test1 = test2 || test2 != test3) && (test2 ~ 'example' || test2 !~ '%%abc') && 'switch1%%' ~ test1 && 'switch2' !~ test2 && test3 > 1 && test3 >= 0 && test3 <= 4 && 2 < 5",
|
||||
false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("(((((((([[test1]] = [[test2]]) OR ([[test2]] != [[test3]])) AND (([[test2]] LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}) OR ([[test2]] NOT LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}))) AND ([[test1]] LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test2]] NOT LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] > {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] >= {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] <= {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ({:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("} < {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})") +
|
||||
"$",
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
expr, err := s.filterData.BuildExpr(resolver)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
dummyDB := &dbx.DB{}
|
||||
rawSql := expr.Build(dummyDB, map[string]any{})
|
||||
|
||||
pattern := regexp.MustCompile(s.expectPattern)
|
||||
if !pattern.MatchString(rawSql) {
|
||||
t.Errorf("(%d) Pattern %v don't match with expression: \n%v", i, s.expectPattern, rawSql)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// DefaultPerPage specifies the default returned search result items.
|
||||
const DefaultPerPage int = 30
|
||||
|
||||
// MaxPerPage specifies the maximum allowed search result items returned in a single page.
|
||||
const MaxPerPage int = 200
|
||||
|
||||
// url search query params
|
||||
const (
|
||||
PageQueryParam string = "page"
|
||||
PerPageQueryParam string = "perPage"
|
||||
SortQueryParam string = "sort"
|
||||
FilterQueryParam string = "filter"
|
||||
)
|
||||
|
||||
// Result defines the returned search result structure.
|
||||
type Result struct {
|
||||
Page int `json:"page"`
|
||||
PerPage int `json:"perPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
Items any `json:"items"`
|
||||
}
|
||||
|
||||
// Provider represents a single configured search provider instance.
|
||||
type Provider struct {
|
||||
fieldResolver FieldResolver
|
||||
query *dbx.SelectQuery
|
||||
page int
|
||||
perPage int
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
}
|
||||
|
||||
// NewProvider creates and returns a new search provider.
|
||||
//
|
||||
// Example:
|
||||
// baseQuery := db.Select("*").From("user")
|
||||
// fieldResolver := search.NewSimpleFieldResolver("id", "name")
|
||||
// models := []*YourDataStruct{}
|
||||
//
|
||||
// result, err := search.NewProvider(fieldResolver).
|
||||
// Query(baseQuery).
|
||||
// ParseAndExec("page=2&filter=id>0&sort=-name", &models)
|
||||
func NewProvider(fieldResolver FieldResolver) *Provider {
|
||||
return &Provider{
|
||||
fieldResolver: fieldResolver,
|
||||
page: 1,
|
||||
perPage: DefaultPerPage,
|
||||
sort: []SortField{},
|
||||
filter: []FilterData{},
|
||||
}
|
||||
}
|
||||
|
||||
// Query sets the base query that will be used to fetch the search items.
|
||||
func (s *Provider) Query(query *dbx.SelectQuery) *Provider {
|
||||
s.query = query
|
||||
return s
|
||||
}
|
||||
|
||||
// Page sets the `page` field of the current search provider.
|
||||
//
|
||||
// Normalization on the `page` value is done during `Exec()`.
|
||||
func (s *Provider) Page(page int) *Provider {
|
||||
s.page = page
|
||||
return s
|
||||
}
|
||||
|
||||
// PerPage sets the `perPage` field of the current search provider.
|
||||
//
|
||||
// Normalization on the `perPage` value is done during `Exec()`.
|
||||
func (s *Provider) PerPage(perPage int) *Provider {
|
||||
s.perPage = perPage
|
||||
return s
|
||||
}
|
||||
|
||||
// Sort sets the `sort` field of the current search provider.
|
||||
func (s *Provider) Sort(sort []SortField) *Provider {
|
||||
s.sort = sort
|
||||
return s
|
||||
}
|
||||
|
||||
// AddSort appends the provided SortField to the existing provider's sort field.
|
||||
func (s *Provider) AddSort(field SortField) *Provider {
|
||||
s.sort = append(s.sort, field)
|
||||
return s
|
||||
}
|
||||
|
||||
// Filter sets the `filter` field of the current search provider.
|
||||
func (s *Provider) Filter(filter []FilterData) *Provider {
|
||||
s.filter = filter
|
||||
return s
|
||||
}
|
||||
|
||||
// AddFilter appends the provided FilterData to the existing provider's filter field.
|
||||
func (s *Provider) AddFilter(filter FilterData) *Provider {
|
||||
if filter != "" {
|
||||
s.filter = append(s.filter, filter)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Parse parses the search query parameter from the provided query string
|
||||
// and assigns the found fields to the current search provider.
|
||||
//
|
||||
// The data from the "sort" and "filter" query parameters are appended
|
||||
// to the existing provider's `sort` and `filter` fields
|
||||
// (aka. using `AddSort` and `AddFilter`).
|
||||
func (s *Provider) Parse(urlQuery string) error {
|
||||
params, err := url.ParseQuery(urlQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawPage := params.Get(PageQueryParam)
|
||||
if rawPage != "" {
|
||||
page, err := strconv.Atoi(rawPage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Page(page)
|
||||
}
|
||||
|
||||
rawPerPage := params.Get(PerPageQueryParam)
|
||||
if rawPerPage != "" {
|
||||
perPage, err := strconv.Atoi(rawPerPage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.PerPage(perPage)
|
||||
}
|
||||
|
||||
rawSort := params.Get(SortQueryParam)
|
||||
if rawSort != "" {
|
||||
for _, sortField := range ParseSortFromString(rawSort) {
|
||||
s.AddSort(sortField)
|
||||
}
|
||||
}
|
||||
|
||||
rawFilter := params.Get(FilterQueryParam)
|
||||
if rawFilter != "" {
|
||||
s.AddFilter(FilterData(rawFilter))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec executes the search provider and fills/scans
|
||||
// the provided `items` slice with the found models.
|
||||
func (s *Provider) Exec(items any) (*Result, error) {
|
||||
if s.query == nil {
|
||||
return nil, errors.New("Query is not set.")
|
||||
}
|
||||
|
||||
// clone provider's query
|
||||
modelsQuery := *s.query
|
||||
|
||||
// apply filters
|
||||
if len(s.filter) > 0 {
|
||||
for _, f := range s.filter {
|
||||
expr, err := f.BuildExpr(s.fieldResolver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expr != nil {
|
||||
modelsQuery.AndWhere(expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply sorting
|
||||
if len(s.sort) > 0 {
|
||||
for _, sortField := range s.sort {
|
||||
expr, err := sortField.BuildExpr(s.fieldResolver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expr != "" {
|
||||
modelsQuery.AndOrderBy(expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply field resolver query modifications (if any)
|
||||
updateQueryErr := s.fieldResolver.UpdateQuery(&modelsQuery)
|
||||
if updateQueryErr != nil {
|
||||
return nil, updateQueryErr
|
||||
}
|
||||
|
||||
// count
|
||||
var totalCount int64
|
||||
countQuery := modelsQuery
|
||||
if err := countQuery.Select("count(*)").Row(&totalCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// normalize perPage
|
||||
if s.perPage <= 0 {
|
||||
s.perPage = DefaultPerPage
|
||||
} else if s.perPage > MaxPerPage {
|
||||
s.perPage = MaxPerPage
|
||||
}
|
||||
|
||||
// normalize page accoring to the total count
|
||||
if s.page <= 0 || totalCount == 0 {
|
||||
s.page = 1
|
||||
} else if totalPages := int(math.Ceil(float64(totalCount) / float64(s.perPage))); s.page > totalPages {
|
||||
s.page = totalPages
|
||||
}
|
||||
|
||||
// apply pagination
|
||||
modelsQuery.Limit(int64(s.perPage))
|
||||
modelsQuery.Offset(int64(s.perPage * (s.page - 1)))
|
||||
|
||||
// fetch models
|
||||
if err := modelsQuery.All(items); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Result{
|
||||
Page: s.page,
|
||||
PerPage: s.perPage,
|
||||
TotalItems: int(totalCount),
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseAndExec is a short conventient method to trigger both
|
||||
// `Parse()` and `Exec()` in a single call.
|
||||
func (s *Provider) ParseAndExec(urlQuery string, modelsSlice any) (*Result, error) {
|
||||
if err := s.Parse(urlQuery); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Exec(modelsSlice)
|
||||
}
|
||||
@@ -0,0 +1,505 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r)
|
||||
|
||||
if p.page != 1 {
|
||||
t.Fatalf("Expected page %d, got %d", 1, p.page)
|
||||
}
|
||||
|
||||
if p.perPage != DefaultPerPage {
|
||||
t.Fatalf("Expected perPage %d, got %d", DefaultPerPage, p.perPage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderQuery(t *testing.T) {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
querySql := query.Build().SQL()
|
||||
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).Query(query)
|
||||
|
||||
expected := p.query.Build().SQL()
|
||||
|
||||
if querySql != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, querySql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderPage(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).Page(10)
|
||||
|
||||
if p.page != 10 {
|
||||
t.Fatalf("Expected page %v, got %v", 10, p.page)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderPerPage(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).PerPage(456)
|
||||
|
||||
if p.perPage != 456 {
|
||||
t.Fatalf("Expected perPage %v, got %v", 456, p.perPage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderSort(t *testing.T) {
|
||||
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Sort(initialSort).
|
||||
AddSort(SortField{"test3", SortDesc})
|
||||
|
||||
encoded, _ := json.Marshal(p.sort)
|
||||
expected := `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"test3","direction":"DESC"}]`
|
||||
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected sort %v, got \n%v", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFilter(t *testing.T) {
|
||||
initialFilter := []FilterData{"test1", "test2"}
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Filter(initialFilter).
|
||||
AddFilter("test3")
|
||||
|
||||
encoded, _ := json.Marshal(p.filter)
|
||||
expected := `["test1","test2","test3"]`
|
||||
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected filter %v, got \n%v", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderParse(t *testing.T) {
|
||||
initialPage := 2
|
||||
initialPerPage := 123
|
||||
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
|
||||
initialFilter := []FilterData{"test1", "test2"}
|
||||
|
||||
scenarios := []struct {
|
||||
query string
|
||||
expectError bool
|
||||
expectPage int
|
||||
expectPerPage int
|
||||
expectSort string
|
||||
expectFilter string
|
||||
}{
|
||||
// empty
|
||||
{
|
||||
"",
|
||||
false,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid query
|
||||
{
|
||||
"invalid;",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid page
|
||||
{
|
||||
"page=a",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid perPage
|
||||
{
|
||||
"perPage=a",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// valid query parameters
|
||||
{
|
||||
"page=3&perPage=456&filter=test3&sort=-a,b,+c&other=123",
|
||||
false,
|
||||
3,
|
||||
456,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"a","direction":"DESC"},{"name":"b","direction":"ASC"},{"name":"c","direction":"ASC"}]`,
|
||||
`["test1","test2","test3"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Page(initialPage).
|
||||
PerPage(initialPerPage).
|
||||
Sort(initialSort).
|
||||
Filter(initialFilter)
|
||||
|
||||
err := p.Parse(s.query)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if p.page != s.expectPage {
|
||||
t.Errorf("(%d) Expected page %v, got %v", i, s.expectPage, p.page)
|
||||
}
|
||||
|
||||
if p.perPage != s.expectPerPage {
|
||||
t.Errorf("(%d) Expected perPage %v, got %v", i, s.expectPerPage, p.perPage)
|
||||
}
|
||||
|
||||
encodedSort, _ := json.Marshal(p.sort)
|
||||
if string(encodedSort) != s.expectSort {
|
||||
t.Errorf("(%d) Expected sort %v, got \n%v", i, s.expectSort, string(encodedSort))
|
||||
}
|
||||
|
||||
encodedFilter, _ := json.Marshal(p.filter)
|
||||
if string(encodedFilter) != s.expectFilter {
|
||||
t.Errorf("(%d) Expected filter %v, got \n%v", i, s.expectFilter, string(encodedFilter))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderExecEmptyQuery(t *testing.T) {
|
||||
p := NewProvider(&testFieldResolver{}).
|
||||
Query(nil)
|
||||
|
||||
_, err := p.Exec(&[]testTableStruct{})
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error with empty query, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
query := testDB.Select("*").
|
||||
From("test").
|
||||
Where(dbx.Not(dbx.HashExp{"test1": nil})).
|
||||
OrderBy("test1 ASC")
|
||||
|
||||
scenarios := []struct {
|
||||
page int
|
||||
perPage int
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
expectError bool
|
||||
expectResult string
|
||||
expectQueries []string
|
||||
}{
|
||||
// page normalization
|
||||
{
|
||||
-1,
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
// perPage normalization
|
||||
{
|
||||
10, // will be capped by total count
|
||||
0, // fallback to default
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
`{"page":1,"perPage":30,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30",
|
||||
},
|
||||
},
|
||||
// invalid sort field
|
||||
{
|
||||
1,
|
||||
10,
|
||||
[]SortField{{"unknown", SortAsc}},
|
||||
[]FilterData{},
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
// invalid filter
|
||||
{
|
||||
1,
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{"test2 = 'test2.1'", "invalid"},
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
// valid sort and filter fields
|
||||
{
|
||||
1,
|
||||
5555, // will be limited by MaxPerPage
|
||||
[]SortField{{"test2", SortDesc}},
|
||||
[]FilterData{"test2 != null", "test1 >= 2"},
|
||||
false,
|
||||
`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (test2 IS NOT null)) AND (test1 >= '2') ORDER BY `test1` ASC, `test2` DESC",
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (test2 IS NOT null)) AND (test1 >= '2') ORDER BY `test1` ASC, `test2` DESC LIMIT 200",
|
||||
},
|
||||
},
|
||||
// valid sort and filter fields (zero results)
|
||||
{
|
||||
1,
|
||||
10,
|
||||
[]SortField{{"test3", SortAsc}},
|
||||
[]FilterData{"test3 != ''"},
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":0,"items":[]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (test3 != '') ORDER BY `test1` ASC, `test3` ASC",
|
||||
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (test3 != '') ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
// pagination test
|
||||
{
|
||||
3,
|
||||
1,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
`{"page":2,"perPage":1,"totalItems":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
|
||||
testResolver := &testFieldResolver{}
|
||||
p := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(s.page).
|
||||
PerPage(s.perPage).
|
||||
Sort(s.sort).
|
||||
Filter(s.filter)
|
||||
|
||||
result, err := p.Exec(&[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Errorf("(%d) Expected resolver.Update to be called %d, got %d", i, 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Errorf("(%d) Expected result %v, got \n%v", i, s.expectResult, string(encoded))
|
||||
}
|
||||
|
||||
if len(s.expectQueries) != len(testDB.CalledQueries) {
|
||||
t.Errorf("(%d) Expected %d queries, got %d: \n%v", i, len(s.expectQueries), len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, q := range testDB.CalledQueries {
|
||||
if !list.ExistInSliceWithRegex(q, s.expectQueries) {
|
||||
t.Errorf("(%d) Didn't expect query %v", i, q)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderParseAndExec(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
query := testDB.Select("*").
|
||||
From("test").
|
||||
Where(dbx.Not(dbx.HashExp{"test1": nil})).
|
||||
OrderBy("test1 ASC")
|
||||
|
||||
scenarios := []struct {
|
||||
queryString string
|
||||
expectError bool
|
||||
expectResult string
|
||||
}{
|
||||
// empty
|
||||
{
|
||||
"",
|
||||
false,
|
||||
`{"page":1,"perPage":123,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
},
|
||||
// invalid query
|
||||
{
|
||||
"invalid;",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid page
|
||||
{
|
||||
"page=a",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid perPage
|
||||
{
|
||||
"perPage=a",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid sorting field
|
||||
{
|
||||
"sort=-unknown",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid filter field
|
||||
{
|
||||
"filter=unknown>1",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// valid query params
|
||||
{
|
||||
"page=3&perPage=555&filter=test1>1&sort=-test2,test3",
|
||||
false,
|
||||
`{"page":1,"perPage":200,"totalItems":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
|
||||
testResolver := &testFieldResolver{}
|
||||
provider := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(2).
|
||||
PerPage(123).
|
||||
Sort([]SortField{{"test2", SortAsc}}).
|
||||
Filter([]FilterData{"test1 > 0"})
|
||||
|
||||
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Errorf("(%d) Expected resolver.Update to be called %d, got %d", i, 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
|
||||
if len(testDB.CalledQueries) != 2 {
|
||||
t.Errorf("(%d) Expected %d db queries, got %d: \n%v", i, 2, len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Errorf("(%d) Expected result %v, got \n%v", i, s.expectResult, string(encoded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testTableStruct struct {
|
||||
Test1 int `db:"test1" json:"test1"`
|
||||
Test2 string `db:"test2" json:"test2"`
|
||||
Test3 string `db:"test3" json:"test3"`
|
||||
}
|
||||
|
||||
type testDB struct {
|
||||
*dbx.DB
|
||||
CalledQueries []string
|
||||
}
|
||||
|
||||
// NB! Don't forget to call `db.Close()` at the end of the test.
|
||||
func createTestDB() (*testDB, error) {
|
||||
sqlDB, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
|
||||
db.CreateTable("test", map[string]string{"test1": "int default 0", "test2": "text default ''", "test3": "text default ''"}).Execute()
|
||||
db.Insert("test", dbx.Params{"test1": 1, "test2": "test2.1"}).Execute()
|
||||
db.Insert("test", dbx.Params{"test1": 2, "test2": "test2.2"}).Execute()
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
|
||||
// ---
|
||||
|
||||
type testFieldResolver struct {
|
||||
UpdateQueryCalls int
|
||||
ResolveCalls int
|
||||
}
|
||||
|
||||
func (t *testFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
t.UpdateQueryCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testFieldResolver) Resolve(field string) (name string, placeholderParams dbx.Params, err error) {
|
||||
t.ResolveCalls++
|
||||
|
||||
if field == "unknown" {
|
||||
return "", nil, errors.New("test error")
|
||||
}
|
||||
|
||||
return field, nil, nil
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
// FieldResolver defines an interface for managing search fields.
|
||||
type FieldResolver interface {
|
||||
// UpdateQuery allows to updated the provided db query based on the
|
||||
// resolved search fields (eg. adding joins aliases, etc.).
|
||||
//
|
||||
// Called internally by `search.Provider` before executing the search request.
|
||||
UpdateQuery(query *dbx.SelectQuery) error
|
||||
|
||||
// Resolve parses the provided field and returns a properly
|
||||
// formatted db identifier (eg. NULL, quoted column, placeholder parameter, etc.).
|
||||
Resolve(field string) (name string, placeholderParams dbx.Params, err error)
|
||||
}
|
||||
|
||||
// NewSimpleFieldResolver creates a new `SimpleFieldResolver` with the
|
||||
// provided `allowedFields`.
|
||||
//
|
||||
// Each `allowedFields` could be a plain string (eg. "name")
|
||||
// or a regexp pattern (eg. `^\w+[\w\.]*$`).
|
||||
func NewSimpleFieldResolver(allowedFields ...string) *SimpleFieldResolver {
|
||||
return &SimpleFieldResolver{
|
||||
allowedFields: allowedFields,
|
||||
}
|
||||
}
|
||||
|
||||
// SimpleFieldResolver defines a generic search resolver that allows
|
||||
// only its listed fields to be resolved and take part in a search query.
|
||||
//
|
||||
// If `allowedFields` are empty no fields filtering is applied.
|
||||
type SimpleFieldResolver struct {
|
||||
allowedFields []string
|
||||
}
|
||||
|
||||
// UpdateQuery implements `search.UpdateQuery` interface.
|
||||
func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
// nothing to update...
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve implements `search.Resolve` interface.
|
||||
//
|
||||
// Returns error if `field` is not in `r.allowedFields`.
|
||||
func (r *SimpleFieldResolver) Resolve(field string) (resultName string, placeholderParams dbx.Params, err error) {
|
||||
if !list.ExistInSliceWithRegex(field, r.allowedFields) {
|
||||
return "", nil, fmt.Errorf("Failed to resolve field %q.", field)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[[%s]]", inflector.Columnify(field)), nil, nil
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package search_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestSimpleFieldResolverUpdateQuery(t *testing.T) {
|
||||
r := search.NewSimpleFieldResolver("test")
|
||||
|
||||
scenarios := []struct {
|
||||
fieldName string
|
||||
expectQuery string
|
||||
}{
|
||||
// missing field (the query shouldn't change)
|
||||
{"", `SELECT "id" FROM "test"`},
|
||||
// unknown field (the query shouldn't change)
|
||||
{"unknown", `SELECT "id" FROM "test"`},
|
||||
// allowed field (the query shouldn't change)
|
||||
{"test", `SELECT "id" FROM "test"`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
|
||||
r.Resolve(s.fieldName)
|
||||
|
||||
if err := r.UpdateQuery(nil); err != nil {
|
||||
t.Errorf("(%d) UpdateQuery failed with error %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
rawQuery := query.Build().SQL()
|
||||
// rawQuery := s.expectQuery
|
||||
|
||||
if rawQuery != s.expectQuery {
|
||||
t.Errorf("(%d) Expected query %v, got \n%v", i, s.expectQuery, rawQuery)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleFieldResolverResolve(t *testing.T) {
|
||||
r := search.NewSimpleFieldResolver("test", `^test_regex\d+$`, "Test columnify!")
|
||||
|
||||
scenarios := []struct {
|
||||
fieldName string
|
||||
expectError bool
|
||||
expectName string
|
||||
}{
|
||||
{"", true, ""},
|
||||
{" ", true, ""},
|
||||
{"unknown", true, ""},
|
||||
{"test", false, "[[test]]"},
|
||||
{"test.sub", true, ""},
|
||||
{"test_regex", true, ""},
|
||||
{"test_regex1", false, "[[test_regex1]]"},
|
||||
{"Test columnify!", false, "[[Testcolumnify]]"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
name, params, err := r.Resolve(s.fieldName)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if name != s.expectName {
|
||||
t.Errorf("(%d) Expected name %q, got %q", i, s.expectName, name)
|
||||
}
|
||||
|
||||
// params should be empty
|
||||
if len(params) != 0 {
|
||||
t.Errorf("(%d) Expected 0 params, got %v", i, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// sort field directions
|
||||
const (
|
||||
SortAsc string = "ASC"
|
||||
SortDesc string = "DESC"
|
||||
)
|
||||
|
||||
// SortField defines a single search sort field.
|
||||
type SortField struct {
|
||||
Name string `json:"name"`
|
||||
Direction string `json:"direction"`
|
||||
}
|
||||
|
||||
// BuildExpr resolves the sort field into a valid db sort expression.
|
||||
func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
|
||||
name, params, err := fieldResolver.Resolve(s.Name)
|
||||
|
||||
// invalidate empty fields and non-column identifiers
|
||||
if err != nil || len(params) > 0 || name == "" || strings.ToLower(name) == "null" {
|
||||
return "", fmt.Errorf("Invalid sort field %q.", s.Name)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s", name, s.Direction), nil
|
||||
}
|
||||
|
||||
// ParseSortFromString parses the provided string expression
|
||||
// into a slice of SortFields.
|
||||
//
|
||||
// Example:
|
||||
// fields := search.ParseSortFromString("-name,+created")
|
||||
func ParseSortFromString(str string) []SortField {
|
||||
result := []SortField{}
|
||||
|
||||
data := strings.Split(str, ",")
|
||||
|
||||
for _, field := range data {
|
||||
// trim whitespaces
|
||||
field = strings.TrimSpace(field)
|
||||
|
||||
var dir string
|
||||
if strings.HasPrefix(field, "-") {
|
||||
dir = SortDesc
|
||||
field = strings.TrimPrefix(field, "-")
|
||||
} else {
|
||||
dir = SortAsc
|
||||
field = strings.TrimPrefix(field, "+")
|
||||
}
|
||||
|
||||
result = append(result, SortField{field, dir})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package search_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestSortFieldBuildExpr(t *testing.T) {
|
||||
resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", "test4.sub")
|
||||
|
||||
scenarios := []struct {
|
||||
sortField search.SortField
|
||||
expectError bool
|
||||
expectExpression string
|
||||
}{
|
||||
// empty
|
||||
{search.SortField{"", search.SortDesc}, true, ""},
|
||||
// unknown field
|
||||
{search.SortField{"unknown", search.SortAsc}, true, ""},
|
||||
// placeholder field
|
||||
{search.SortField{"'test'", search.SortAsc}, true, ""},
|
||||
// null field
|
||||
{search.SortField{"null", search.SortAsc}, true, ""},
|
||||
// allowed field - asc
|
||||
{search.SortField{"test1", search.SortAsc}, false, "[[test1]] ASC"},
|
||||
// allowed field - desc
|
||||
{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.sortField.BuildExpr(resolver)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if result != s.expectExpression {
|
||||
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectExpression, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortFromString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value string
|
||||
expectedJson string
|
||||
}{
|
||||
{"", `[{"name":"","direction":"ASC"}]`},
|
||||
{"test", `[{"name":"test","direction":"ASC"}]`},
|
||||
{"+test", `[{"name":"test","direction":"ASC"}]`},
|
||||
{"-test", `[{"name":"test","direction":"DESC"}]`},
|
||||
{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := search.ParseSortFromString(s.value)
|
||||
encoded, _ := json.Marshal(result)
|
||||
|
||||
if string(encoded) != s.expectedJson {
|
||||
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectedJson, string(encoded))
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user