filter enhancements

This commit is contained in:
Gani Georgiev
2023-01-07 22:25:56 +02:00
parent d5775ff657
commit 9b880f5ab4
102 changed files with 3693 additions and 986 deletions
+275 -66
View File
@@ -85,72 +85,119 @@ func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (
}
func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
lName, lParams, lErr := f.resolveToken(expr.Left, fieldResolver)
if lName == "" || lErr != nil {
return nil, fmt.Errorf("Invalid left operand %q - %v.", expr.Left.Literal, lErr)
lResult, lErr := resolveToken(expr.Left, fieldResolver)
if lErr != nil || lResult.Identifier == "" {
return nil, fmt.Errorf("invalid left operand %q - %v", expr.Left.Literal, lErr)
}
rName, rParams, rErr := f.resolveToken(expr.Right, fieldResolver)
if rName == "" || rErr != nil {
return nil, fmt.Errorf("Invalid right operand %q - %v.", expr.Right.Literal, rErr)
rResult, rErr := resolveToken(expr.Right, fieldResolver)
if rErr != nil || rResult.Identifier == "" {
return nil, fmt.Errorf("invalid right operand %q - %v", expr.Right.Literal, rErr)
}
switch expr.Op {
case fexpr.SignEq:
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') = COALESCE(%s, '')", lName, rName), mergeParams(lParams, rParams)), nil
case fexpr.SignNeq:
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') != COALESCE(%s, '')", lName, rName), mergeParams(lParams, rParams)), nil
case fexpr.SignLike:
// the right side is a column and therefor wrap it with "%" for contains like behavior
if len(rParams) == 0 {
return dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%') ESCAPE '\\'", lName, rName), lParams), nil
}
return dbx.NewExp(fmt.Sprintf("%s LIKE %s ESCAPE '\\'", lName, rName), mergeParams(lParams, wrapLikeParams(rParams))), nil
case fexpr.SignNlike:
// the right side is a column and therefor wrap it with "%" for not-contains like behavior
if len(rParams) == 0 {
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%') ESCAPE '\\'", lName, rName), lParams), nil
}
// normalize operands and switch sides if the left operand is a number/text, but the right one is a column
// (usually this shouldn't be needed, but it's kept for backward compatibility)
if len(lParams) > 0 && len(rParams) == 0 {
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", rName, lName), wrapLikeParams(lParams)), nil
}
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", lName, rName), mergeParams(lParams, wrapLikeParams(rParams))), nil
case fexpr.SignLt:
return dbx.NewExp(fmt.Sprintf("%s < %s", lName, rName), mergeParams(lParams, rParams)), nil
case fexpr.SignLte:
return dbx.NewExp(fmt.Sprintf("%s <= %s", lName, rName), mergeParams(lParams, rParams)), nil
case fexpr.SignGt:
return dbx.NewExp(fmt.Sprintf("%s > %s", lName, rName), mergeParams(lParams, rParams)), nil
case fexpr.SignGte:
return dbx.NewExp(fmt.Sprintf("%s >= %s", lName, rName), mergeParams(lParams, rParams)), nil
}
return nil, fmt.Errorf("Unknown expression operator %q", expr.Op)
return buildExpr(lResult, expr.Op, rResult)
}
func (f FilterData) resolveToken(token fexpr.Token, fieldResolver FieldResolver) (name string, params dbx.Params, err error) {
func buildExpr(
left *ResolverResult,
op fexpr.SignOp,
right *ResolverResult,
) (dbx.Expression, error) {
var expr dbx.Expression
switch op {
case fexpr.SignEq, fexpr.SignAnyEq:
expr = dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') = COALESCE(%s, '')", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
case fexpr.SignNeq, fexpr.SignAnyNeq:
expr = dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') != COALESCE(%s, '')", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
case fexpr.SignLike, fexpr.SignAnyLike:
// the right side is a column and therefor wrap it with "%" for contains like behavior
if len(right.Params) == 0 {
expr = dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
} else {
expr = dbx.NewExp(fmt.Sprintf("%s LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
}
case fexpr.SignNlike, fexpr.SignAnyNlike:
// the right side is a column and therefor wrap it with "%" for not-contains like behavior
if len(right.Params) == 0 {
expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
} else {
expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
}
case fexpr.SignLt, fexpr.SignAnyLt:
expr = dbx.NewExp(fmt.Sprintf("%s < %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
case fexpr.SignLte, fexpr.SignAnyLte:
expr = dbx.NewExp(fmt.Sprintf("%s <= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
case fexpr.SignGt, fexpr.SignAnyGt:
expr = dbx.NewExp(fmt.Sprintf("%s > %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
case fexpr.SignGte, fexpr.SignAnyGte:
expr = dbx.NewExp(fmt.Sprintf("%s >= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
}
if expr == nil {
return nil, fmt.Errorf("unknown expression operator %q", op)
}
// multi-match expressions
if !isAnyMatchOp(op) {
if left.MultiMatchSubQuery != nil && right.MultiMatchSubQuery != nil {
mm := &manyVsManyExpr{
leftSubQuery: left.MultiMatchSubQuery,
rightSubQuery: right.MultiMatchSubQuery,
op: op,
}
expr = dbx.And(expr, mm)
} else if left.MultiMatchSubQuery != nil {
mm := &manyVsOneExpr{
subQuery: left.MultiMatchSubQuery,
op: op,
otherOperand: right,
}
expr = dbx.And(expr, mm)
} else if right.MultiMatchSubQuery != nil {
mm := &manyVsOneExpr{
subQuery: right.MultiMatchSubQuery,
op: op,
otherOperand: left,
inverse: true,
}
expr = dbx.And(expr, mm)
}
}
if left.AfterBuild != nil {
expr = left.AfterBuild(expr)
}
if right.AfterBuild != nil {
expr = right.AfterBuild(expr)
}
return expr, nil
}
func resolveToken(token fexpr.Token, fieldResolver FieldResolver) (*ResolverResult, error) {
switch token.Type {
case fexpr.TokenIdentifier:
// current datetime constant
// ---
if token.Literal == "@now" {
placeholder := "t" + security.PseudorandomString(8)
name := fmt.Sprintf("{:%s}", placeholder)
params := dbx.Params{placeholder: types.NowDateTime().String()}
placeholder := "t" + security.PseudorandomString(5)
return name, params, nil
return &ResolverResult{
Identifier: "{:" + placeholder + "}",
Params: dbx.Params{placeholder: types.NowDateTime().String()},
}, nil
}
// custom resolver
// ---
name, params, err := fieldResolver.Resolve(token.Literal)
result, err := fieldResolver.Resolve(token.Literal)
if name == "" || err != nil {
if err != nil || result.Identifier == "" {
m := map[string]string{
// if `null` field is missing, treat `null` identifier as NULL token
"null": "NULL",
@@ -160,27 +207,46 @@ func (f FilterData) resolveToken(token fexpr.Token, fieldResolver FieldResolver)
"false": "0",
}
if v, ok := m[strings.ToLower(token.Literal)]; ok {
return v, nil, nil
return &ResolverResult{Identifier: v}, nil
}
return "", nil, err
return nil, err
}
return name, params, err
return result, err
case fexpr.TokenText:
placeholder := "t" + security.PseudorandomString(8)
name := fmt.Sprintf("{:%s}", placeholder)
params := dbx.Params{placeholder: token.Literal}
placeholder := "t" + security.PseudorandomString(5)
return name, params, nil
return &ResolverResult{
Identifier: "{:" + placeholder + "}",
Params: dbx.Params{placeholder: token.Literal},
}, nil
case fexpr.TokenNumber:
placeholder := "t" + security.PseudorandomString(8)
name := fmt.Sprintf("{:%s}", placeholder)
params := dbx.Params{placeholder: cast.ToFloat64(token.Literal)}
placeholder := "t" + security.PseudorandomString(5)
return name, params, nil
return &ResolverResult{
Identifier: "{:" + placeholder + "}",
Params: dbx.Params{placeholder: cast.ToFloat64(token.Literal)},
}, nil
}
return "", nil, errors.New("Unresolvable token type.")
return nil, errors.New("unresolvable token type")
}
func isAnyMatchOp(op fexpr.SignOp) bool {
switch op {
case
fexpr.SignAnyEq,
fexpr.SignAnyNeq,
fexpr.SignAnyLike,
fexpr.SignAnyNlike,
fexpr.SignAnyLt,
fexpr.SignAnyLte,
fexpr.SignAnyGt,
fexpr.SignAnyGte:
return true
}
return false
}
// mergeParams returns new dbx.Params where each provided params item
@@ -218,18 +284,24 @@ func wrapLikeParams(params dbx.Params) dbx.Params {
// -------------------------------------------------------------------
var _ dbx.Expression = (*opExpr)(nil)
// opExpr defines an expression that contains a raw sql operator string.
type opExpr struct {
op string
}
// Build converts an expression into a SQL fragment.
// Build converts the expression into a SQL fragment.
//
// Implements [dbx.Expression] interface.
func (e *opExpr) Build(db *dbx.DB, params dbx.Params) string {
return e.op
}
// -------------------------------------------------------------------
var _ dbx.Expression = (*concatExpr)(nil)
// concatExpr defines an expression that concatenates multiple
// other expressions with a specified separator.
type concatExpr struct {
@@ -237,7 +309,7 @@ type concatExpr struct {
separator string
}
// Build converts an expression into a SQL fragment.
// Build converts the expression into a SQL fragment.
//
// Implements [dbx.Expression] interface.
func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
@@ -247,12 +319,12 @@ func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
stringParts := make([]string, 0, len(e.parts))
for _, a := range e.parts {
if a == nil {
for _, p := range e.parts {
if p == nil {
continue
}
if sql := a.Build(db, params); sql != "" {
if sql := p.Build(db, params); sql != "" {
stringParts = append(stringParts, sql)
}
}
@@ -267,3 +339,140 @@ func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
return "(" + strings.Join(stringParts, e.separator) + ")"
}
// -------------------------------------------------------------------
var _ dbx.Expression = (*manyVsManyExpr)(nil)
// manyVsManyExpr constructs a multi-match many<->many db where expression.
//
// Expects leftSubQuery and rightSubQuery to return a subquery with a
// single "multiMatchValue" column.
type manyVsManyExpr struct {
leftSubQuery dbx.Expression
rightSubQuery dbx.Expression
op fexpr.SignOp
}
// Build converts the expression into a SQL fragment.
//
// Implements [dbx.Expression] interface.
func (e *manyVsManyExpr) Build(db *dbx.DB, params dbx.Params) string {
if e.leftSubQuery == nil || e.rightSubQuery == nil {
return "0=1"
}
lAlias := "__ml" + security.PseudorandomString(5)
rAlias := "__mr" + security.PseudorandomString(5)
whereExpr, buildErr := buildExpr(
&ResolverResult{
Identifier: "[[" + lAlias + ".multiMatchValue]]",
},
e.op,
&ResolverResult{
Identifier: "[[" + rAlias + ".multiMatchValue]]",
// note: the AfterBuild needs to be handled only once and it
// doesn't matter whether it is applied on the left or right subquery operand
AfterBuild: multiMatchAfterBuildFunc(e.op, lAlias, rAlias),
},
)
if buildErr != nil {
return "0=1"
}
return fmt.Sprintf(
"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} LEFT JOIN (%s) {{%s}} WHERE %s)",
e.leftSubQuery.Build(db, params),
lAlias,
e.rightSubQuery.Build(db, params),
rAlias,
whereExpr.Build(db, params),
)
}
// -------------------------------------------------------------------
var _ dbx.Expression = (*manyVsOneExpr)(nil)
// manyVsManyExpr constructs a multi-match many<->one db where expression.
//
// Expects subQuery to return a subquery with a single "multiMatchValue" column.
//
// You can set inverse=false to reverse the condition sides (aka. one<->many).
type manyVsOneExpr struct {
subQuery dbx.Expression
op fexpr.SignOp
otherOperand *ResolverResult
inverse bool
}
// Build converts the expression into a SQL fragment.
//
// Implements [dbx.Expression] interface.
func (e *manyVsOneExpr) Build(db *dbx.DB, params dbx.Params) string {
if e.subQuery == nil {
return "0=1"
}
alias := "__sm" + security.PseudorandomString(5)
r1 := &ResolverResult{
Identifier: "[[" + alias + ".multiMatchValue]]",
AfterBuild: multiMatchAfterBuildFunc(e.op, alias),
}
r2 := &ResolverResult{
Identifier: e.otherOperand.Identifier,
Params: e.otherOperand.Params,
}
var whereExpr dbx.Expression
var buildErr error
if e.inverse {
whereExpr, buildErr = buildExpr(r2, e.op, r1)
} else {
whereExpr, buildErr = buildExpr(r1, e.op, r2)
}
if buildErr != nil {
return "0=1"
}
return fmt.Sprintf(
"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} WHERE %s)",
e.subQuery.Build(db, params),
alias,
whereExpr.Build(db, params),
)
}
func multiMatchAfterBuildFunc(op fexpr.SignOp, multiMatchAliases ...string) func(dbx.Expression) dbx.Expression {
return func(expr dbx.Expression) dbx.Expression {
expr = dbx.Not(expr) // inverse for the not-exist expression
if op == fexpr.SignEq {
return expr
}
orExprs := make([]dbx.Expression, len(multiMatchAliases)+1)
orExprs[0] = expr
// Add an optional "IS NULL" condition(s) to handle the empty rows result.
//
// For example, let's assume that some "rel" field is [nonemptyRel1, nonemptyRel2, emptyRel3],
// The filter "rel.total > 0" will ensures that the above will return true only if all relations
// are existing and match the condition.
//
// The "=" operator is excluded because it will never equal directly with NULL anyway
// and also because we want in case "rel.id = ''" is specified to allow
// matching the empty relations (they will match due to the applied COALESCE).
for i, mAlias := range multiMatchAliases {
orExprs[i+1] = dbx.NewExp("[[" + mAlias + ".multiMatchValue]] IS NULL")
}
return dbx.Or(orExprs...)
}
}
+3 -3
View File
@@ -495,12 +495,12 @@ func (t *testFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
return nil
}
func (t *testFieldResolver) Resolve(field string) (name string, placeholderParams dbx.Params, err error) {
func (t *testFieldResolver) Resolve(field string) (*ResolverResult, error) {
t.ResolveCalls++
if field == "unknown" {
return "", nil, errors.New("test error")
return nil, errors.New("test error")
}
return field, nil, nil
return &ResolverResult{Identifier: field}, nil
}
+25 -4
View File
@@ -8,6 +8,25 @@ import (
"github.com/pocketbase/pocketbase/tools/list"
)
// ResolverResult defines a single FieldResolver.Resolve() successfully parsed result.
type ResolverResult struct {
// Identifier is the plain SQL identifier/column that will be used
// in the final db expression as left or right operand.
Identifier string
// Params is a map with db placeholder->value pairs that will be added
// to the query when building both resolved operands/sides in a single expression.
Params dbx.Params
// MultiMatchSubQuery is an optional sub query expression that will be added
// in addition to the combined ResolverResult expression during build.
MultiMatchSubQuery dbx.Expression
// AfterBuild is an optional function that will be called after building
// and combining the result of both resolved operands/sides in a single expression.
AfterBuild func(expr dbx.Expression) dbx.Expression
}
// FieldResolver defines an interface for managing search fields.
type FieldResolver interface {
// UpdateQuery allows to updated the provided db query based on the
@@ -18,7 +37,7 @@ type FieldResolver interface {
// Resolve parses the provided field and returns a properly
// formatted db identifier (eg. NULL, quoted column, placeholder parameter, etc.).
Resolve(field string) (name string, placeholderParams dbx.Params, err error)
Resolve(field string) (*ResolverResult, error)
}
// NewSimpleFieldResolver creates a new `SimpleFieldResolver` with the
@@ -49,10 +68,12 @@ func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
// Resolve implements `search.Resolve` interface.
//
// Returns error if `field` is not in `r.allowedFields`.
func (r *SimpleFieldResolver) Resolve(field string) (resultName string, placeholderParams dbx.Params, err error) {
func (r *SimpleFieldResolver) Resolve(field string) (*ResolverResult, error) {
if !list.ExistInSliceWithRegex(field, r.allowedFields) {
return "", nil, fmt.Errorf("Failed to resolve field %q.", field)
return nil, fmt.Errorf("Failed to resolve field %q.", field)
}
return fmt.Sprintf("[[%s]]", inflector.Columnify(field)), nil, nil
return &ResolverResult{
Identifier: "[[" + inflector.Columnify(field) + "]]",
}, nil
}
+9 -5
View File
@@ -61,7 +61,7 @@ func TestSimpleFieldResolverResolve(t *testing.T) {
}
for i, s := range scenarios {
name, params, err := r.Resolve(s.fieldName)
r, err := r.Resolve(s.fieldName)
hasErr := err != nil
if hasErr != s.expectError {
@@ -69,13 +69,17 @@ func TestSimpleFieldResolverResolve(t *testing.T) {
continue
}
if name != s.expectName {
t.Errorf("(%d) Expected name %q, got %q", i, s.expectName, name)
if hasErr {
continue
}
if r.Identifier != s.expectName {
t.Errorf("(%d) Expected r.Identifier %q, got %q", i, s.expectName, r.Identifier)
}
// params should be empty
if len(params) != 0 {
t.Errorf("(%d) Expected 0 params, got %v", i, params)
if len(r.Params) != 0 {
t.Errorf("(%d) Expected 0 r.Params, got %v", i, r.Params)
}
}
}
+13 -6
View File
@@ -5,6 +5,8 @@ import (
"strings"
)
const randomSortKey string = "@random"
// sort field directions
const (
SortAsc string = "ASC"
@@ -19,14 +21,19 @@ type SortField struct {
// BuildExpr resolves the sort field into a valid db sort expression.
func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
name, params, err := fieldResolver.Resolve(s.Name)
// invalidate empty fields and non-column identifiers
if err != nil || len(params) > 0 || name == "" || strings.ToLower(name) == "null" {
return "", fmt.Errorf("Invalid sort field %q.", s.Name)
// special case for random sort
if s.Name == randomSortKey {
return "RANDOM()", nil
}
return fmt.Sprintf("%s %s", name, s.Direction), nil
result, err := fieldResolver.Resolve(s.Name)
// invalidate empty fields and non-column identifiers
if err != nil || len(result.Params) > 0 || result.Identifier == "" || strings.ToLower(result.Identifier) == "null" {
return "", fmt.Errorf("invalid sort field %q", s.Name)
}
return fmt.Sprintf("%s %s", result.Identifier, s.Direction), nil
}
// ParseSortFromString parses the provided string expression
+3
View File
@@ -27,6 +27,8 @@ func TestSortFieldBuildExpr(t *testing.T) {
{search.SortField{"test1", search.SortAsc}, false, "[[test1]] ASC"},
// allowed field - desc
{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
// special @random field (ignore direction)
{search.SortField{"@random", search.SortDesc}, false, "RANDOM()"},
}
for i, s := range scenarios {
@@ -54,6 +56,7 @@ func TestParseSortFromString(t *testing.T) {
{"+test", `[{"name":"test","direction":"ASC"}]`},
{"-test", `[{"name":"test","direction":"DESC"}]`},
{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
{"@random,-test", `[{"name":"@random","direction":"ASC"},{"name":"test","direction":"DESC"}]`},
}
for i, s := range scenarios {