filter enhancements
This commit is contained in:
+8
-7
@@ -9,13 +9,14 @@ import (
|
||||
|
||||
// AuthUser defines a standardized oauth2 user data structure.
|
||||
type AuthUser struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatarUrl"`
|
||||
RawUser map[string]any `json:"rawUser"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatarUrl"`
|
||||
RawUser map[string]any `json:"rawUser"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// Provider defines a common interface for an OAuth2 client.
|
||||
|
||||
@@ -63,12 +63,13 @@ func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
username := fmt.Sprintf("%s#%s", extracted.Username, extracted.Discriminator)
|
||||
|
||||
user := &AuthUser{
|
||||
Id: extracted.Id,
|
||||
Name: username,
|
||||
Username: extracted.Username,
|
||||
AvatarUrl: avatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: extracted.Id,
|
||||
Name: username,
|
||||
Username: extracted.Username,
|
||||
AvatarUrl: avatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
if extracted.Verified {
|
||||
user.Email = extracted.Email
|
||||
|
||||
@@ -54,12 +54,13 @@ func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.Picture.Data.Url,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.Picture.Data.Url,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
+7
-6
@@ -55,12 +55,13 @@ func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Login,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Login,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
if extracted.Email != "" && is.EmailFormat.Validate(extracted.Email) == nil {
|
||||
|
||||
@@ -55,13 +55,14 @@ func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Login,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Login,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
// in case user has set "Keep my email address private", send an
|
||||
|
||||
@@ -53,13 +53,14 @@ func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Username,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Username,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
@@ -52,12 +52,13 @@ func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.Picture,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.Picture,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
+6
-5
@@ -59,11 +59,12 @@ func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Username: extracted.Profile.Nickname,
|
||||
AvatarUrl: extracted.Profile.ImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Username: extracted.Profile.Nickname,
|
||||
AvatarUrl: extracted.Profile.ImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
if extracted.KakaoAccount.IsEmailValid && extracted.KakaoAccount.IsEmailVerified {
|
||||
user.Email = extracted.KakaoAccount.Email
|
||||
|
||||
@@ -53,11 +53,12 @@ func (p *Microsoft) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Email: extracted.Email,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Email: extracted.Email,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
@@ -61,10 +61,11 @@ func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
if len(extracted.Images) > 0 {
|
||||
user.AvatarUrl = extracted.Images[0].Url
|
||||
|
||||
@@ -58,12 +58,13 @@ func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.FirstName + " " + extracted.LastName,
|
||||
Username: extracted.Username,
|
||||
AvatarUrl: extracted.ProfileImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.FirstName + " " + extracted.LastName,
|
||||
Username: extracted.Username,
|
||||
AvatarUrl: extracted.ProfileImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
@@ -61,13 +61,14 @@ func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: extracted.Data[0].Id,
|
||||
Name: extracted.Data[0].DisplayName,
|
||||
Username: extracted.Data[0].Login,
|
||||
Email: extracted.Data[0].Email,
|
||||
AvatarUrl: extracted.Data[0].ProfileImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: extracted.Data[0].Id,
|
||||
Name: extracted.Data[0].DisplayName,
|
||||
Username: extracted.Data[0].Login,
|
||||
Email: extracted.Data[0].Email,
|
||||
AvatarUrl: extracted.Data[0].ProfileImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
@@ -63,12 +63,13 @@ func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: extracted.Data.Id,
|
||||
Name: extracted.Data.Name,
|
||||
Username: extracted.Data.Username,
|
||||
AvatarUrl: extracted.Data.ProfileImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
Id: extracted.Data.Id,
|
||||
Name: extracted.Data.Name,
|
||||
Username: extracted.Data.Username,
|
||||
AvatarUrl: extracted.Data.ProfileImageUrl,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
@@ -10,6 +10,20 @@ import (
|
||||
|
||||
var cachedPatterns = map[string]*regexp.Regexp{}
|
||||
|
||||
// SubtractSlice returns a new slice with only the "base" elements
|
||||
// that don't exist in "subtract".
|
||||
func SubtractSlice[T comparable](base []T, subtract []T) []T {
|
||||
var result = make([]T, 0, len(base))
|
||||
|
||||
for _, b := range base {
|
||||
if !ExistInSlice(b, subtract) {
|
||||
result = append(result, b)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ExistInSlice checks whether a comparable element exists in a slice of the same type.
|
||||
func ExistInSlice[T comparable](item T, list []T) bool {
|
||||
if len(list) == 0 {
|
||||
|
||||
@@ -1,12 +1,111 @@
|
||||
package list_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestSubtractSliceString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
base []string
|
||||
subtract []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
[]string{},
|
||||
[]string{},
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
[]string{},
|
||||
[]string{"1", "2", "3", "4"},
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
[]string{"1", "2", "3", "4"},
|
||||
[]string{},
|
||||
`["1","2","3","4"]`,
|
||||
},
|
||||
{
|
||||
[]string{"1", "2", "3", "4"},
|
||||
[]string{"1", "2", "3", "4"},
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
[]string{"1", "2", "3", "4", "7"},
|
||||
[]string{"2", "4", "5", "6"},
|
||||
`["1","3","7"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := list.SubtractSlice(s.base, s.subtract)
|
||||
|
||||
raw, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("(%d) Failed to serialize: %v", i, err)
|
||||
}
|
||||
|
||||
strResult := string(raw)
|
||||
|
||||
if strResult != s.expected {
|
||||
t.Fatalf("(%d) Expected %v, got %v", i, s.expected, strResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubtractSliceInt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
base []int
|
||||
subtract []int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
[]int{},
|
||||
[]int{},
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
[]int{},
|
||||
[]int{1, 2, 3, 4},
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
[]int{1, 2, 3, 4},
|
||||
[]int{},
|
||||
`[1,2,3,4]`,
|
||||
},
|
||||
{
|
||||
[]int{1, 2, 3, 4},
|
||||
[]int{1, 2, 3, 4},
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
[]int{1, 2, 3, 4, 7},
|
||||
[]int{2, 4, 5, 6},
|
||||
`[1,3,7]`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := list.SubtractSlice(s.base, s.subtract)
|
||||
|
||||
raw, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("(%d) Failed to serialize: %v", i, err)
|
||||
}
|
||||
|
||||
strResult := string(raw)
|
||||
|
||||
if strResult != s.expected {
|
||||
t.Fatalf("(%d) Expected %v, got %v", i, s.expected, strResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistInSliceString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
item string
|
||||
|
||||
+59
-25
@@ -2,6 +2,7 @@ package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
@@ -72,9 +73,19 @@ func (r *Runner) Run(args ...string) error {
|
||||
}
|
||||
}
|
||||
|
||||
names, err := r.lastAppliedMigrations(toRevertCount)
|
||||
if err != nil {
|
||||
color.Red(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
confirm := false
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf("Do you really want to revert the last %d applied migration(s)?", toRevertCount),
|
||||
Message: fmt.Sprintf(
|
||||
"\n%v\nDo you really want to revert the last %d applied migration(s)?",
|
||||
strings.Join(names, "\n"),
|
||||
toRevertCount,
|
||||
),
|
||||
}
|
||||
survey.AskOne(prompt, &confirm)
|
||||
if !confirm {
|
||||
@@ -138,38 +149,43 @@ func (r *Runner) Up() ([]string, error) {
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
// Down reverts the last `toRevertCount` applied migrations.
|
||||
// Down reverts the last `toRevertCount` applied migrations
|
||||
// (in the order they were applied).
|
||||
//
|
||||
// On success returns list with the reverted migrations file names.
|
||||
func (r *Runner) Down(toRevertCount int) ([]string, error) {
|
||||
reverted := make([]string, 0, toRevertCount)
|
||||
|
||||
names, appliedErr := r.lastAppliedMigrations(toRevertCount)
|
||||
if appliedErr != nil {
|
||||
return nil, appliedErr
|
||||
}
|
||||
|
||||
err := r.db.Transactional(func(tx *dbx.Tx) error {
|
||||
for i := len(r.migrationsList.Items()) - 1; i >= 0; i-- {
|
||||
m := r.migrationsList.Item(i)
|
||||
|
||||
// skip unapplied
|
||||
if !r.isMigrationApplied(tx, m.File) {
|
||||
continue
|
||||
}
|
||||
|
||||
// revert limit reached
|
||||
if toRevertCount-len(reverted) <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// ignore empty Down action
|
||||
if m.Down != nil {
|
||||
if err := m.Down(tx); err != nil {
|
||||
return fmt.Errorf("Failed to revert migration %s: %w", m.File, err)
|
||||
for _, name := range names {
|
||||
for _, m := range r.migrationsList.Items() {
|
||||
if m.File != name {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.saveRevertedMigration(tx, m.File); err != nil {
|
||||
return fmt.Errorf("Failed to save reverted migration info for %s: %w", m.File, err)
|
||||
}
|
||||
// revert limit reached
|
||||
if toRevertCount-len(reverted) <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
reverted = append(reverted, m.File)
|
||||
// ignore empty Down action
|
||||
if m.Down != nil {
|
||||
if err := m.Down(tx); err != nil {
|
||||
return fmt.Errorf("Failed to revert migration %s: %w", m.File, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.saveRevertedMigration(tx, m.File); err != nil {
|
||||
return fmt.Errorf("Failed to save reverted migration info for %s: %w", m.File, err)
|
||||
}
|
||||
|
||||
reverted = append(reverted, m.File)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -178,6 +194,7 @@ func (r *Runner) Down(toRevertCount int) ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reverted, nil
|
||||
}
|
||||
|
||||
@@ -207,7 +224,7 @@ func (r *Runner) isMigrationApplied(tx dbx.Builder, file string) bool {
|
||||
func (r *Runner) saveAppliedMigration(tx dbx.Builder, file string) error {
|
||||
_, err := tx.Insert(r.tableName, dbx.Params{
|
||||
"file": file,
|
||||
"applied": time.Now().Unix(),
|
||||
"applied": time.Now().UnixMicro(),
|
||||
}).Execute()
|
||||
|
||||
return err
|
||||
@@ -218,3 +235,20 @@ func (r *Runner) saveRevertedMigration(tx dbx.Builder, file string) error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) lastAppliedMigrations(limit int) ([]string, error) {
|
||||
var files = make([]string, 0, limit)
|
||||
|
||||
err := r.db.Select("file").
|
||||
From(r.tableName).
|
||||
Where(dbx.Not(dbx.HashExp{"applied": nil})).
|
||||
OrderBy("applied DESC", "file DESC").
|
||||
Limit(int64(limit)).
|
||||
Column(&files)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package migrate
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -52,73 +53,88 @@ func TestRunnerUpAndDown(t *testing.T) {
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
var test1UpCalled bool
|
||||
var test1DownCalled bool
|
||||
var test2UpCalled bool
|
||||
var test2DownCalled bool
|
||||
callsOrder := []string{}
|
||||
|
||||
l := MigrationsList{}
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
test1UpCalled = true
|
||||
callsOrder = append(callsOrder, "up2")
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
test1DownCalled = true
|
||||
return nil
|
||||
}, "1_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
test2UpCalled = true
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
test2DownCalled = true
|
||||
callsOrder = append(callsOrder, "down2")
|
||||
return nil
|
||||
}, "2_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "up3")
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "down3")
|
||||
return nil
|
||||
}, "3_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "up1")
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "down1")
|
||||
return nil
|
||||
}, "1_test")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// simulate partially run migration
|
||||
r.saveAppliedMigration(testDB, r.migrationsList.Item(0).File)
|
||||
// simulate partially out-of-order run migration
|
||||
r.saveAppliedMigration(testDB, "2_test")
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Up()
|
||||
// ---
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
if _, err := r.Up(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test1UpCalled {
|
||||
t.Fatalf("Didn't expect 1_test to be called")
|
||||
}
|
||||
expectedUpCallsOrder := `["up1","up3"]` // skip up2 since it was applied previously
|
||||
|
||||
if !test2UpCalled {
|
||||
t.Fatalf("Expected 2_test to be called")
|
||||
}
|
||||
|
||||
// simulate unrun migration
|
||||
var test3DownCalled bool
|
||||
r.migrationsList.Register(nil, func(db dbx.Builder) error {
|
||||
test3DownCalled = true
|
||||
return nil
|
||||
}, "3_test")
|
||||
|
||||
// Down()
|
||||
// ---
|
||||
// revert one migration
|
||||
if _, err := r.Down(1); err != nil {
|
||||
upCallsOrder, err := json.Marshal(callsOrder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test3DownCalled {
|
||||
t.Fatal("Didn't expect 3_test to be reverted.")
|
||||
if v := string(upCallsOrder); v != expectedUpCallsOrder {
|
||||
t.Fatalf("Expected Up() calls order %s, got %s", expectedUpCallsOrder, upCallsOrder)
|
||||
}
|
||||
|
||||
if !test2DownCalled {
|
||||
t.Fatal("Expected 2_test to be reverted.")
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// reset callsOrder
|
||||
callsOrder = []string{}
|
||||
|
||||
// simulate unrun migration
|
||||
r.migrationsList.Register(nil, func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "down4")
|
||||
return nil
|
||||
}, "4_test")
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Down()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
if _, err := r.Down(2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test1DownCalled {
|
||||
t.Fatal("Didn't expect 1_test to be reverted.")
|
||||
expectedDownCallsOrder := `["down3","down1"]` // revert in the applied order
|
||||
|
||||
downCallsOrder, err := json.Marshal(callsOrder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := string(downCallsOrder); v != expectedDownCallsOrder {
|
||||
t.Fatalf("Expected Down() calls order %s, got %s", expectedDownCallsOrder, downCallsOrder)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+275
-66
@@ -85,72 +85,119 @@ func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (
|
||||
}
|
||||
|
||||
func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
lName, lParams, lErr := f.resolveToken(expr.Left, fieldResolver)
|
||||
if lName == "" || lErr != nil {
|
||||
return nil, fmt.Errorf("Invalid left operand %q - %v.", expr.Left.Literal, lErr)
|
||||
lResult, lErr := resolveToken(expr.Left, fieldResolver)
|
||||
if lErr != nil || lResult.Identifier == "" {
|
||||
return nil, fmt.Errorf("invalid left operand %q - %v", expr.Left.Literal, lErr)
|
||||
}
|
||||
|
||||
rName, rParams, rErr := f.resolveToken(expr.Right, fieldResolver)
|
||||
if rName == "" || rErr != nil {
|
||||
return nil, fmt.Errorf("Invalid right operand %q - %v.", expr.Right.Literal, rErr)
|
||||
rResult, rErr := resolveToken(expr.Right, fieldResolver)
|
||||
if rErr != nil || rResult.Identifier == "" {
|
||||
return nil, fmt.Errorf("invalid right operand %q - %v", expr.Right.Literal, rErr)
|
||||
}
|
||||
|
||||
switch expr.Op {
|
||||
case fexpr.SignEq:
|
||||
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') = COALESCE(%s, '')", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignNeq:
|
||||
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') != COALESCE(%s, '')", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignLike:
|
||||
// the right side is a column and therefor wrap it with "%" for contains like behavior
|
||||
if len(rParams) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%') ESCAPE '\\'", lName, rName), lParams), nil
|
||||
}
|
||||
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s ESCAPE '\\'", lName, rName), mergeParams(lParams, wrapLikeParams(rParams))), nil
|
||||
case fexpr.SignNlike:
|
||||
// the right side is a column and therefor wrap it with "%" for not-contains like behavior
|
||||
if len(rParams) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%') ESCAPE '\\'", lName, rName), lParams), nil
|
||||
}
|
||||
|
||||
// normalize operands and switch sides if the left operand is a number/text, but the right one is a column
|
||||
// (usually this shouldn't be needed, but it's kept for backward compatibility)
|
||||
if len(lParams) > 0 && len(rParams) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", rName, lName), wrapLikeParams(lParams)), nil
|
||||
}
|
||||
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", lName, rName), mergeParams(lParams, wrapLikeParams(rParams))), nil
|
||||
case fexpr.SignLt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s < %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignLte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s <= %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignGt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s > %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignGte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s >= %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown expression operator %q", expr.Op)
|
||||
return buildExpr(lResult, expr.Op, rResult)
|
||||
}
|
||||
|
||||
func (f FilterData) resolveToken(token fexpr.Token, fieldResolver FieldResolver) (name string, params dbx.Params, err error) {
|
||||
func buildExpr(
|
||||
left *ResolverResult,
|
||||
op fexpr.SignOp,
|
||||
right *ResolverResult,
|
||||
) (dbx.Expression, error) {
|
||||
var expr dbx.Expression
|
||||
|
||||
switch op {
|
||||
case fexpr.SignEq, fexpr.SignAnyEq:
|
||||
expr = dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') = COALESCE(%s, '')", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignNeq, fexpr.SignAnyNeq:
|
||||
expr = dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') != COALESCE(%s, '')", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignLike, fexpr.SignAnyLike:
|
||||
// the right side is a column and therefor wrap it with "%" for contains like behavior
|
||||
if len(right.Params) == 0 {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
|
||||
} else {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
|
||||
}
|
||||
case fexpr.SignNlike, fexpr.SignAnyNlike:
|
||||
// the right side is a column and therefor wrap it with "%" for not-contains like behavior
|
||||
if len(right.Params) == 0 {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%') ESCAPE '\\'", left.Identifier, right.Identifier), left.Params)
|
||||
} else {
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s ESCAPE '\\'", left.Identifier, right.Identifier), mergeParams(left.Params, wrapLikeParams(right.Params)))
|
||||
}
|
||||
case fexpr.SignLt, fexpr.SignAnyLt:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s < %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignLte, fexpr.SignAnyLte:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s <= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignGt, fexpr.SignAnyGt:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s > %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
case fexpr.SignGte, fexpr.SignAnyGte:
|
||||
expr = dbx.NewExp(fmt.Sprintf("%s >= %s", left.Identifier, right.Identifier), mergeParams(left.Params, right.Params))
|
||||
}
|
||||
|
||||
if expr == nil {
|
||||
return nil, fmt.Errorf("unknown expression operator %q", op)
|
||||
}
|
||||
|
||||
// multi-match expressions
|
||||
if !isAnyMatchOp(op) {
|
||||
if left.MultiMatchSubQuery != nil && right.MultiMatchSubQuery != nil {
|
||||
mm := &manyVsManyExpr{
|
||||
leftSubQuery: left.MultiMatchSubQuery,
|
||||
rightSubQuery: right.MultiMatchSubQuery,
|
||||
op: op,
|
||||
}
|
||||
|
||||
expr = dbx.And(expr, mm)
|
||||
} else if left.MultiMatchSubQuery != nil {
|
||||
mm := &manyVsOneExpr{
|
||||
subQuery: left.MultiMatchSubQuery,
|
||||
op: op,
|
||||
otherOperand: right,
|
||||
}
|
||||
|
||||
expr = dbx.And(expr, mm)
|
||||
} else if right.MultiMatchSubQuery != nil {
|
||||
mm := &manyVsOneExpr{
|
||||
subQuery: right.MultiMatchSubQuery,
|
||||
op: op,
|
||||
otherOperand: left,
|
||||
inverse: true,
|
||||
}
|
||||
|
||||
expr = dbx.And(expr, mm)
|
||||
}
|
||||
}
|
||||
|
||||
if left.AfterBuild != nil {
|
||||
expr = left.AfterBuild(expr)
|
||||
}
|
||||
|
||||
if right.AfterBuild != nil {
|
||||
expr = right.AfterBuild(expr)
|
||||
}
|
||||
|
||||
return expr, nil
|
||||
}
|
||||
|
||||
func resolveToken(token fexpr.Token, fieldResolver FieldResolver) (*ResolverResult, error) {
|
||||
switch token.Type {
|
||||
case fexpr.TokenIdentifier:
|
||||
// current datetime constant
|
||||
// ---
|
||||
if token.Literal == "@now" {
|
||||
placeholder := "t" + security.PseudorandomString(8)
|
||||
name := fmt.Sprintf("{:%s}", placeholder)
|
||||
params := dbx.Params{placeholder: types.NowDateTime().String()}
|
||||
placeholder := "t" + security.PseudorandomString(5)
|
||||
|
||||
return name, params, nil
|
||||
return &ResolverResult{
|
||||
Identifier: "{:" + placeholder + "}",
|
||||
Params: dbx.Params{placeholder: types.NowDateTime().String()},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// custom resolver
|
||||
// ---
|
||||
name, params, err := fieldResolver.Resolve(token.Literal)
|
||||
result, err := fieldResolver.Resolve(token.Literal)
|
||||
|
||||
if name == "" || err != nil {
|
||||
if err != nil || result.Identifier == "" {
|
||||
m := map[string]string{
|
||||
// if `null` field is missing, treat `null` identifier as NULL token
|
||||
"null": "NULL",
|
||||
@@ -160,27 +207,46 @@ func (f FilterData) resolveToken(token fexpr.Token, fieldResolver FieldResolver)
|
||||
"false": "0",
|
||||
}
|
||||
if v, ok := m[strings.ToLower(token.Literal)]; ok {
|
||||
return v, nil, nil
|
||||
return &ResolverResult{Identifier: v}, nil
|
||||
}
|
||||
return "", nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return name, params, err
|
||||
return result, err
|
||||
case fexpr.TokenText:
|
||||
placeholder := "t" + security.PseudorandomString(8)
|
||||
name := fmt.Sprintf("{:%s}", placeholder)
|
||||
params := dbx.Params{placeholder: token.Literal}
|
||||
placeholder := "t" + security.PseudorandomString(5)
|
||||
|
||||
return name, params, nil
|
||||
return &ResolverResult{
|
||||
Identifier: "{:" + placeholder + "}",
|
||||
Params: dbx.Params{placeholder: token.Literal},
|
||||
}, nil
|
||||
case fexpr.TokenNumber:
|
||||
placeholder := "t" + security.PseudorandomString(8)
|
||||
name := fmt.Sprintf("{:%s}", placeholder)
|
||||
params := dbx.Params{placeholder: cast.ToFloat64(token.Literal)}
|
||||
placeholder := "t" + security.PseudorandomString(5)
|
||||
|
||||
return name, params, nil
|
||||
return &ResolverResult{
|
||||
Identifier: "{:" + placeholder + "}",
|
||||
Params: dbx.Params{placeholder: cast.ToFloat64(token.Literal)},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return "", nil, errors.New("Unresolvable token type.")
|
||||
return nil, errors.New("unresolvable token type")
|
||||
}
|
||||
|
||||
func isAnyMatchOp(op fexpr.SignOp) bool {
|
||||
switch op {
|
||||
case
|
||||
fexpr.SignAnyEq,
|
||||
fexpr.SignAnyNeq,
|
||||
fexpr.SignAnyLike,
|
||||
fexpr.SignAnyNlike,
|
||||
fexpr.SignAnyLt,
|
||||
fexpr.SignAnyLte,
|
||||
fexpr.SignAnyGt,
|
||||
fexpr.SignAnyGte:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// mergeParams returns new dbx.Params where each provided params item
|
||||
@@ -218,18 +284,24 @@ func wrapLikeParams(params dbx.Params) dbx.Params {
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*opExpr)(nil)
|
||||
|
||||
// opExpr defines an expression that contains a raw sql operator string.
|
||||
type opExpr struct {
|
||||
op string
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *opExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
return e.op
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*concatExpr)(nil)
|
||||
|
||||
// concatExpr defines an expression that concatenates multiple
|
||||
// other expressions with a specified separator.
|
||||
type concatExpr struct {
|
||||
@@ -237,7 +309,7 @@ type concatExpr struct {
|
||||
separator string
|
||||
}
|
||||
|
||||
// Build converts an expression into a SQL fragment.
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
@@ -247,12 +319,12 @@ func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
|
||||
stringParts := make([]string, 0, len(e.parts))
|
||||
|
||||
for _, a := range e.parts {
|
||||
if a == nil {
|
||||
for _, p := range e.parts {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if sql := a.Build(db, params); sql != "" {
|
||||
if sql := p.Build(db, params); sql != "" {
|
||||
stringParts = append(stringParts, sql)
|
||||
}
|
||||
}
|
||||
@@ -267,3 +339,140 @@ func (e *concatExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
|
||||
return "(" + strings.Join(stringParts, e.separator) + ")"
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*manyVsManyExpr)(nil)
|
||||
|
||||
// manyVsManyExpr constructs a multi-match many<->many db where expression.
|
||||
//
|
||||
// Expects leftSubQuery and rightSubQuery to return a subquery with a
|
||||
// single "multiMatchValue" column.
|
||||
type manyVsManyExpr struct {
|
||||
leftSubQuery dbx.Expression
|
||||
rightSubQuery dbx.Expression
|
||||
op fexpr.SignOp
|
||||
}
|
||||
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *manyVsManyExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
if e.leftSubQuery == nil || e.rightSubQuery == nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
lAlias := "__ml" + security.PseudorandomString(5)
|
||||
rAlias := "__mr" + security.PseudorandomString(5)
|
||||
|
||||
whereExpr, buildErr := buildExpr(
|
||||
&ResolverResult{
|
||||
Identifier: "[[" + lAlias + ".multiMatchValue]]",
|
||||
},
|
||||
e.op,
|
||||
&ResolverResult{
|
||||
Identifier: "[[" + rAlias + ".multiMatchValue]]",
|
||||
// note: the AfterBuild needs to be handled only once and it
|
||||
// doesn't matter whether it is applied on the left or right subquery operand
|
||||
AfterBuild: multiMatchAfterBuildFunc(e.op, lAlias, rAlias),
|
||||
},
|
||||
)
|
||||
|
||||
if buildErr != nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} LEFT JOIN (%s) {{%s}} WHERE %s)",
|
||||
e.leftSubQuery.Build(db, params),
|
||||
lAlias,
|
||||
e.rightSubQuery.Build(db, params),
|
||||
rAlias,
|
||||
whereExpr.Build(db, params),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ dbx.Expression = (*manyVsOneExpr)(nil)
|
||||
|
||||
// manyVsManyExpr constructs a multi-match many<->one db where expression.
|
||||
//
|
||||
// Expects subQuery to return a subquery with a single "multiMatchValue" column.
|
||||
//
|
||||
// You can set inverse=false to reverse the condition sides (aka. one<->many).
|
||||
type manyVsOneExpr struct {
|
||||
subQuery dbx.Expression
|
||||
op fexpr.SignOp
|
||||
otherOperand *ResolverResult
|
||||
inverse bool
|
||||
}
|
||||
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (e *manyVsOneExpr) Build(db *dbx.DB, params dbx.Params) string {
|
||||
if e.subQuery == nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
alias := "__sm" + security.PseudorandomString(5)
|
||||
|
||||
r1 := &ResolverResult{
|
||||
Identifier: "[[" + alias + ".multiMatchValue]]",
|
||||
AfterBuild: multiMatchAfterBuildFunc(e.op, alias),
|
||||
}
|
||||
|
||||
r2 := &ResolverResult{
|
||||
Identifier: e.otherOperand.Identifier,
|
||||
Params: e.otherOperand.Params,
|
||||
}
|
||||
|
||||
var whereExpr dbx.Expression
|
||||
var buildErr error
|
||||
|
||||
if e.inverse {
|
||||
whereExpr, buildErr = buildExpr(r2, e.op, r1)
|
||||
} else {
|
||||
whereExpr, buildErr = buildExpr(r1, e.op, r2)
|
||||
}
|
||||
|
||||
if buildErr != nil {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"NOT EXISTS (SELECT 1 FROM (%s) {{%s}} WHERE %s)",
|
||||
e.subQuery.Build(db, params),
|
||||
alias,
|
||||
whereExpr.Build(db, params),
|
||||
)
|
||||
}
|
||||
|
||||
func multiMatchAfterBuildFunc(op fexpr.SignOp, multiMatchAliases ...string) func(dbx.Expression) dbx.Expression {
|
||||
return func(expr dbx.Expression) dbx.Expression {
|
||||
expr = dbx.Not(expr) // inverse for the not-exist expression
|
||||
|
||||
if op == fexpr.SignEq {
|
||||
return expr
|
||||
}
|
||||
|
||||
orExprs := make([]dbx.Expression, len(multiMatchAliases)+1)
|
||||
orExprs[0] = expr
|
||||
|
||||
// Add an optional "IS NULL" condition(s) to handle the empty rows result.
|
||||
//
|
||||
// For example, let's assume that some "rel" field is [nonemptyRel1, nonemptyRel2, emptyRel3],
|
||||
// The filter "rel.total > 0" will ensures that the above will return true only if all relations
|
||||
// are existing and match the condition.
|
||||
//
|
||||
// The "=" operator is excluded because it will never equal directly with NULL anyway
|
||||
// and also because we want in case "rel.id = ''" is specified to allow
|
||||
// matching the empty relations (they will match due to the applied COALESCE).
|
||||
for i, mAlias := range multiMatchAliases {
|
||||
orExprs[i+1] = dbx.NewExp("[[" + mAlias + ".multiMatchValue]] IS NULL")
|
||||
}
|
||||
|
||||
return dbx.Or(orExprs...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,12 +495,12 @@ func (t *testFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testFieldResolver) Resolve(field string) (name string, placeholderParams dbx.Params, err error) {
|
||||
func (t *testFieldResolver) Resolve(field string) (*ResolverResult, error) {
|
||||
t.ResolveCalls++
|
||||
|
||||
if field == "unknown" {
|
||||
return "", nil, errors.New("test error")
|
||||
return nil, errors.New("test error")
|
||||
}
|
||||
|
||||
return field, nil, nil
|
||||
return &ResolverResult{Identifier: field}, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,25 @@ import (
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
// ResolverResult defines a single FieldResolver.Resolve() successfully parsed result.
|
||||
type ResolverResult struct {
|
||||
// Identifier is the plain SQL identifier/column that will be used
|
||||
// in the final db expression as left or right operand.
|
||||
Identifier string
|
||||
|
||||
// Params is a map with db placeholder->value pairs that will be added
|
||||
// to the query when building both resolved operands/sides in a single expression.
|
||||
Params dbx.Params
|
||||
|
||||
// MultiMatchSubQuery is an optional sub query expression that will be added
|
||||
// in addition to the combined ResolverResult expression during build.
|
||||
MultiMatchSubQuery dbx.Expression
|
||||
|
||||
// AfterBuild is an optional function that will be called after building
|
||||
// and combining the result of both resolved operands/sides in a single expression.
|
||||
AfterBuild func(expr dbx.Expression) dbx.Expression
|
||||
}
|
||||
|
||||
// FieldResolver defines an interface for managing search fields.
|
||||
type FieldResolver interface {
|
||||
// UpdateQuery allows to updated the provided db query based on the
|
||||
@@ -18,7 +37,7 @@ type FieldResolver interface {
|
||||
|
||||
// Resolve parses the provided field and returns a properly
|
||||
// formatted db identifier (eg. NULL, quoted column, placeholder parameter, etc.).
|
||||
Resolve(field string) (name string, placeholderParams dbx.Params, err error)
|
||||
Resolve(field string) (*ResolverResult, error)
|
||||
}
|
||||
|
||||
// NewSimpleFieldResolver creates a new `SimpleFieldResolver` with the
|
||||
@@ -49,10 +68,12 @@ func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
// Resolve implements `search.Resolve` interface.
|
||||
//
|
||||
// Returns error if `field` is not in `r.allowedFields`.
|
||||
func (r *SimpleFieldResolver) Resolve(field string) (resultName string, placeholderParams dbx.Params, err error) {
|
||||
func (r *SimpleFieldResolver) Resolve(field string) (*ResolverResult, error) {
|
||||
if !list.ExistInSliceWithRegex(field, r.allowedFields) {
|
||||
return "", nil, fmt.Errorf("Failed to resolve field %q.", field)
|
||||
return nil, fmt.Errorf("Failed to resolve field %q.", field)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[[%s]]", inflector.Columnify(field)), nil, nil
|
||||
return &ResolverResult{
|
||||
Identifier: "[[" + inflector.Columnify(field) + "]]",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestSimpleFieldResolverResolve(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
name, params, err := r.Resolve(s.fieldName)
|
||||
r, err := r.Resolve(s.fieldName)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
@@ -69,13 +69,17 @@ func TestSimpleFieldResolverResolve(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
if name != s.expectName {
|
||||
t.Errorf("(%d) Expected name %q, got %q", i, s.expectName, name)
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
if r.Identifier != s.expectName {
|
||||
t.Errorf("(%d) Expected r.Identifier %q, got %q", i, s.expectName, r.Identifier)
|
||||
}
|
||||
|
||||
// params should be empty
|
||||
if len(params) != 0 {
|
||||
t.Errorf("(%d) Expected 0 params, got %v", i, params)
|
||||
if len(r.Params) != 0 {
|
||||
t.Errorf("(%d) Expected 0 r.Params, got %v", i, r.Params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+13
-6
@@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const randomSortKey string = "@random"
|
||||
|
||||
// sort field directions
|
||||
const (
|
||||
SortAsc string = "ASC"
|
||||
@@ -19,14 +21,19 @@ type SortField struct {
|
||||
|
||||
// BuildExpr resolves the sort field into a valid db sort expression.
|
||||
func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
|
||||
name, params, err := fieldResolver.Resolve(s.Name)
|
||||
|
||||
// invalidate empty fields and non-column identifiers
|
||||
if err != nil || len(params) > 0 || name == "" || strings.ToLower(name) == "null" {
|
||||
return "", fmt.Errorf("Invalid sort field %q.", s.Name)
|
||||
// special case for random sort
|
||||
if s.Name == randomSortKey {
|
||||
return "RANDOM()", nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s", name, s.Direction), nil
|
||||
result, err := fieldResolver.Resolve(s.Name)
|
||||
|
||||
// invalidate empty fields and non-column identifiers
|
||||
if err != nil || len(result.Params) > 0 || result.Identifier == "" || strings.ToLower(result.Identifier) == "null" {
|
||||
return "", fmt.Errorf("invalid sort field %q", s.Name)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s", result.Identifier, s.Direction), nil
|
||||
}
|
||||
|
||||
// ParseSortFromString parses the provided string expression
|
||||
|
||||
@@ -27,6 +27,8 @@ func TestSortFieldBuildExpr(t *testing.T) {
|
||||
{search.SortField{"test1", search.SortAsc}, false, "[[test1]] ASC"},
|
||||
// allowed field - desc
|
||||
{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
|
||||
// special @random field (ignore direction)
|
||||
{search.SortField{"@random", search.SortDesc}, false, "RANDOM()"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
@@ -54,6 +56,7 @@ func TestParseSortFromString(t *testing.T) {
|
||||
{"+test", `[{"name":"test","direction":"ASC"}]`},
|
||||
{"-test", `[{"name":"test","direction":"DESC"}]`},
|
||||
{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
|
||||
{"@random,-test", `[{"name":"@random","direction":"ASC"},{"name":"test","direction":"DESC"}]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
|
||||
Reference in New Issue
Block a user