added placeholder params support for Dao.FindRecordsByFilter and Dao.FindFirstRecordByFilter

This commit is contained in:
Gani Georgiev
2023-08-18 06:31:14 +03:00
parent e87ef431c5
commit 75f58a28ac
5 changed files with 261 additions and 80 deletions
+49 -18
View File
@@ -3,6 +3,7 @@ package search
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/ganigeorgiev/fexpr"
@@ -15,11 +16,14 @@ import (
// FilterData is a filter expression string following the `fexpr` package grammar.
//
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
//
// Example:
//
// var filter FilterData = "id = null || (name = 'test' && status = true)"
// var filter FilterData = "id = null || (name = 'test' && status = true) || (total >= {:min} && total <= {:max})"
// resolver := search.NewSimpleFieldResolver("id", "name", "status")
// expr, err := filter.BuildExpr(resolver)
// expr, err := filter.BuildExpr(resolver, dbx.Params{"min": 100, "max": 200})
type FilterData string
// parsedFilterData holds a cache with previously parsed filter data expressions
@@ -27,10 +31,33 @@ type FilterData string
var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50))
// BuildExpr parses the current filter data and returns a new db WHERE expression.
func (f FilterData) BuildExpr(fieldResolver FieldResolver) (dbx.Expression, error) {
//
// The filter string can also contain dbx placeholder parameters (eg. "title = {:name}"),
// that will be safely replaced and properly quoted inplace with the placeholderReplacements values.
func (f FilterData) BuildExpr(
fieldResolver FieldResolver,
placeholderReplacements ...dbx.Params,
) (dbx.Expression, error) {
raw := string(f)
// replace the placeholder params in the raw string filter
for _, p := range placeholderReplacements {
for key, value := range p {
var replacement string
switch v := value.(type) {
case nil:
replacement = "null"
case bool, float64, float32, int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8:
replacement = cast.ToString(v)
default:
replacement = strconv.Quote(cast.ToString(v))
}
raw = strings.ReplaceAll(raw, "{:"+key+"}", replacement)
}
}
if parsedFilterData.Has(raw) {
return f.build(parsedFilterData.Get(raw), fieldResolver)
return buildParsedFilterExpr(parsedFilterData.Get(raw), fieldResolver)
}
data, err := fexpr.Parse(raw)
if err != nil {
@@ -39,10 +66,10 @@ func (f FilterData) BuildExpr(fieldResolver FieldResolver) (dbx.Expression, erro
// store in cache
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
parsedFilterData.SetIfLessThanLimit(raw, data, 500)
return f.build(data, fieldResolver)
return buildParsedFilterExpr(data, fieldResolver)
}
func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) {
func buildParsedFilterExpr(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) {
if len(data) == 0 {
return nil, errors.New("empty filter expression")
}
@@ -55,11 +82,11 @@ func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (
switch item := group.Item.(type) {
case fexpr.Expr:
expr, exprErr = f.resolveTokenizedExpr(item, fieldResolver)
expr, exprErr = resolveTokenizedExpr(item, fieldResolver)
case fexpr.ExprGroup:
expr, exprErr = f.build([]fexpr.ExprGroup{item}, fieldResolver)
expr, exprErr = buildParsedFilterExpr([]fexpr.ExprGroup{item}, fieldResolver)
case []fexpr.ExprGroup:
expr, exprErr = f.build(item, fieldResolver)
expr, exprErr = buildParsedFilterExpr(item, fieldResolver)
default:
exprErr = errors.New("unsupported expression item")
}
@@ -84,7 +111,7 @@ func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (
return result, nil
}
func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
func resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
lResult, lErr := resolveToken(expr.Left, fieldResolver)
if lErr != nil || lResult.Identifier == "" {
return nil, fmt.Errorf("invalid left operand %q - %v", expr.Left.Literal, lErr)
@@ -95,10 +122,10 @@ func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldRes
return nil, fmt.Errorf("invalid right operand %q - %v", expr.Right.Literal, rErr)
}
return buildExpr(lResult, expr.Op, rResult)
return buildResolversExpr(lResult, expr.Op, rResult)
}
func buildExpr(
func buildResolversExpr(
left *ResolverResult,
op fexpr.SignOp,
right *ResolverResult,
@@ -179,17 +206,21 @@ func buildExpr(
return expr, nil
}
var identifierMacros = map[string]func() string{
"@now": func() string { return types.NowDateTime().String() },
}
func resolveToken(token fexpr.Token, fieldResolver FieldResolver) (*ResolverResult, error) {
switch token.Type {
case fexpr.TokenIdentifier:
// current datetime constant
// check for macros
// ---
if token.Literal == "@now" {
if f, ok := identifierMacros[token.Literal]; ok {
placeholder := "t" + security.PseudorandomString(5)
return &ResolverResult{
Identifier: "{:" + placeholder + "}",
Params: dbx.Params{placeholder: types.NowDateTime().String()},
Params: dbx.Params{placeholder: f()},
}, nil
}
@@ -469,7 +500,7 @@ func (e *manyVsManyExpr) Build(db *dbx.DB, params dbx.Params) string {
lAlias := "__ml" + security.PseudorandomString(5)
rAlias := "__mr" + security.PseudorandomString(5)
whereExpr, buildErr := buildExpr(
whereExpr, buildErr := buildResolversExpr(
&ResolverResult{
Identifier: "[[" + lAlias + ".multiMatchValue]]",
},
@@ -536,9 +567,9 @@ func (e *manyVsOneExpr) Build(db *dbx.DB, params dbx.Params) string {
var buildErr error
if e.inverse {
whereExpr, buildErr = buildExpr(r2, e.op, r1)
whereExpr, buildErr = buildResolversExpr(r2, e.op, r1)
} else {
whereExpr, buildErr = buildExpr(r1, e.op, r2)
whereExpr, buildErr = buildResolversExpr(r1, e.op, r2)
}
if buildErr != nil {
+87 -16
View File
@@ -1,8 +1,11 @@
package search_test
import (
"context"
"database/sql"
"regexp"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/search"
@@ -25,7 +28,10 @@ func TestFilterDataBuildExpr(t *testing.T) {
},
{
"invalid format",
"(test1 > 1", true, ""},
"(test1 > 1",
true,
"",
},
{
"invalid operator",
"test1 + 123",
@@ -169,24 +175,89 @@ func TestFilterDataBuildExpr(t *testing.T) {
}
for _, s := range scenarios {
expr, err := s.filterData.BuildExpr(resolver)
t.Run(s.name, func(t *testing.T) {
expr, err := s.filterData.BuildExpr(resolver)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
continue
}
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
}
if hasErr {
continue
}
if hasErr {
return
}
dummyDB := &dbx.DB{}
rawSql := expr.Build(dummyDB, map[string]any{})
dummyDB := &dbx.DB{}
pattern := regexp.MustCompile(s.expectPattern)
if !pattern.MatchString(rawSql) {
t.Errorf("[%s] Pattern %v don't match with expression: \n%v", s.name, s.expectPattern, rawSql)
}
rawSql := expr.Build(dummyDB, dbx.Params{})
pattern := regexp.MustCompile(s.expectPattern)
if !pattern.MatchString(rawSql) {
t.Fatalf("[%s] Pattern %v don't match with expression: \n%v", s.name, s.expectPattern, rawSql)
}
})
}
}
func TestFilterDataBuildExprWithParams(t *testing.T) {
// create a dummy db
sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
if err != nil {
t.Fatal(err)
}
db := dbx.NewFromDB(sqlDB, "sqlite")
calledQueries := []string{}
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
calledQueries = append(calledQueries, sql)
}
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
calledQueries = append(calledQueries, sql)
}
date, err := time.Parse("2006-01-02", "2023-01-01")
if err != nil {
t.Fatal(err)
}
resolver := search.NewSimpleFieldResolver(`^test\w+$`)
filter := search.FilterData(`
test1 = {:test1} ||
test2 = {:test2} ||
test3a = {:test3} ||
test3b = {:test3} ||
test4 = {:test4} ||
test5 = {:test5} ||
test6 = {:test6} ||
test7 = {:test7} ||
test8 = {:test8} ||
test9 = {:test9} ||
test10 = {:test10}
`)
replacements := []dbx.Params{
{"test1": true},
{"test2": false},
{"test3": 123.456},
{"test4": nil},
{"test5": "", "test6": "simple", "test7": `'single_quotes'`, "test8": `"double_quotes"`, "test9": `escape\"quote`},
{"test10": date},
}
expr, err := filter.BuildExpr(resolver, replacements...)
if err != nil {
t.Fatal(err)
}
db.Select().Where(expr).Build().Execute()
if len(calledQueries) != 1 {
t.Fatalf("Expected 1 query, got %d", len(calledQueries))
}
expectedQuery := `SELECT * WHERE ([[test1]] = 1 OR [[test2]] = 0 OR [[test3a]] = 123.456 OR [[test3b]] = 123.456 OR ([[test4]] = '' OR [[test4]] IS NULL) OR ([[test5]] = '' OR [[test5]] IS NULL) OR [[test6]] = 'simple' OR [[test7]] = '''single_quotes''' OR [[test8]] = '"double_quotes"' OR [[test9]] = 'escape\\"quote' OR [[test10]] = '2023-01-01 00:00:00 +0000 UTC')`
if expectedQuery != calledQueries[0] {
t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0])
}
}