initial public commit
This commit is contained in:
@@ -0,0 +1,96 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// AuthUser defines a standardized oauth2 user data structure.
|
||||
type AuthUser struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatarUrl"`
|
||||
}
|
||||
|
||||
// Provider defines a common interface for an OAuth2 client.
|
||||
type Provider interface {
|
||||
// Scopes returns the provider access permissions that will be requested.
|
||||
Scopes() []string
|
||||
|
||||
// SetScopes sets the provider access permissions that will be requested later.
|
||||
SetScopes(scopes []string)
|
||||
|
||||
// ClientId returns the provider client's app ID.
|
||||
ClientId() string
|
||||
|
||||
// SetClientId sets the provider client's ID.
|
||||
SetClientId(clientId string)
|
||||
|
||||
// ClientId returns the provider client's app secret.
|
||||
ClientSecret() string
|
||||
|
||||
// SetClientSecret sets the provider client's app secret.
|
||||
SetClientSecret(secret string)
|
||||
|
||||
// RedirectUrl returns the end address to redirect the user
|
||||
// going through the OAuth flow.
|
||||
RedirectUrl() string
|
||||
|
||||
// SetRedirectUrl sets the provider's RedirectUrl.
|
||||
SetRedirectUrl(url string)
|
||||
|
||||
// AuthUrl returns the provider's authorization service url.
|
||||
AuthUrl() string
|
||||
|
||||
// SetAuthUrl sets the provider's AuthUrl.
|
||||
SetAuthUrl(url string)
|
||||
|
||||
// TokenUrl returns the provider's token exchange service url.
|
||||
TokenUrl() string
|
||||
|
||||
// SetTokenUrl sets the provider's TokenUrl.
|
||||
SetTokenUrl(url string)
|
||||
|
||||
// UserApiUrl returns the provider's user info api url.
|
||||
UserApiUrl() string
|
||||
|
||||
// SetUserApiUrl sets the provider's UserApiUrl.
|
||||
SetUserApiUrl(url string)
|
||||
|
||||
// Client returns an http client using the provided token.
|
||||
Client(token *oauth2.Token) *http.Client
|
||||
|
||||
// BuildAuthUrl returns a URL to the provider's consent page
|
||||
// that asks for permissions for the required scopes explicitly.
|
||||
BuildAuthUrl(state string, opts ...oauth2.AuthCodeOption) string
|
||||
|
||||
// FetchToken converts an authorization code to token.
|
||||
FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||
|
||||
// FetchRawUserData requests and marshalizes into `result` the
|
||||
// the OAuth user api response.
|
||||
FetchRawUserData(token *oauth2.Token, result any) error
|
||||
|
||||
// FetchAuthUser is similar to FetchRawUserData, but normalizes and
|
||||
// marshalizes the user api response into a standardized AuthUser struct.
|
||||
FetchAuthUser(token *oauth2.Token) (user *AuthUser, err error)
|
||||
}
|
||||
|
||||
// NewProviderByName returns a new preconfigured provider instance by its name identifier.
|
||||
func NewProviderByName(name string) (Provider, error) {
|
||||
switch name {
|
||||
case NameGoogle:
|
||||
return NewGoogleProvider(), nil
|
||||
case NameFacebook:
|
||||
return NewFacebookProvider(), nil
|
||||
case NameGithub:
|
||||
return NewGithubProvider(), nil
|
||||
case NameGitlab:
|
||||
return NewGitlabProvider(), nil
|
||||
default:
|
||||
return nil, errors.New("Missing provider " + name)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
)
|
||||
|
||||
func TestNewProviderByName(t *testing.T) {
|
||||
var err error
|
||||
var p auth.Provider
|
||||
|
||||
// invalid
|
||||
p, err = auth.NewProviderByName("invalid")
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
if p != nil {
|
||||
t.Errorf("Expected provider to be nil, got %v", p)
|
||||
}
|
||||
|
||||
// google
|
||||
p, err = auth.NewProviderByName(auth.NameGoogle)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil, got error %v", err)
|
||||
}
|
||||
if _, ok := p.(*auth.Google); !ok {
|
||||
t.Error("Expected to be instance of *auth.Google")
|
||||
}
|
||||
|
||||
// facebook
|
||||
p, err = auth.NewProviderByName(auth.NameFacebook)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil, got error %v", err)
|
||||
}
|
||||
if _, ok := p.(*auth.Facebook); !ok {
|
||||
t.Error("Expected to be instance of *auth.Facebook")
|
||||
}
|
||||
|
||||
// github
|
||||
p, err = auth.NewProviderByName(auth.NameGithub)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil, got error %v", err)
|
||||
}
|
||||
if _, ok := p.(*auth.Github); !ok {
|
||||
t.Error("Expected to be instance of *auth.Github")
|
||||
}
|
||||
|
||||
// gitlab
|
||||
p, err = auth.NewProviderByName(auth.NameGitlab)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil, got error %v", err)
|
||||
}
|
||||
if _, ok := p.(*auth.Gitlab); !ok {
|
||||
t.Error("Expected to be instance of *auth.Gitlab")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// baseProvider defines common fields and methods used by OAuth2 client providers.
|
||||
type baseProvider struct {
|
||||
scopes []string
|
||||
clientId string
|
||||
clientSecret string
|
||||
redirectUrl string
|
||||
authUrl string
|
||||
tokenUrl string
|
||||
userApiUrl string
|
||||
}
|
||||
|
||||
// Scopes implements Provider.Scopes interface.
|
||||
func (p *baseProvider) Scopes() []string {
|
||||
return p.scopes
|
||||
}
|
||||
|
||||
// SetScopes implements Provider.SetScopes interface.
|
||||
func (p *baseProvider) SetScopes(scopes []string) {
|
||||
p.scopes = scopes
|
||||
}
|
||||
|
||||
// ClientId implements Provider.ClientId interface.
|
||||
func (p *baseProvider) ClientId() string {
|
||||
return p.clientId
|
||||
}
|
||||
|
||||
// SetClientId implements Provider.SetClientId interface.
|
||||
func (p *baseProvider) SetClientId(clientId string) {
|
||||
p.clientId = clientId
|
||||
}
|
||||
|
||||
// ClientSecret implements Provider.ClientSecret interface.
|
||||
func (p *baseProvider) ClientSecret() string {
|
||||
return p.clientSecret
|
||||
}
|
||||
|
||||
// SetClientSecret implements Provider.SetClientSecret interface.
|
||||
func (p *baseProvider) SetClientSecret(secret string) {
|
||||
p.clientSecret = secret
|
||||
}
|
||||
|
||||
// RedirectUrl implements Provider.RedirectUrl interface.
|
||||
func (p *baseProvider) RedirectUrl() string {
|
||||
return p.redirectUrl
|
||||
}
|
||||
|
||||
// SetRedirectUrl implements Provider.SetRedirectUrl interface.
|
||||
func (p *baseProvider) SetRedirectUrl(url string) {
|
||||
p.redirectUrl = url
|
||||
}
|
||||
|
||||
// AuthUrl implements Provider.AuthUrl interface.
|
||||
func (p *baseProvider) AuthUrl() string {
|
||||
return p.authUrl
|
||||
}
|
||||
|
||||
// SetAuthUrl implements Provider.SetAuthUrl interface.
|
||||
func (p *baseProvider) SetAuthUrl(url string) {
|
||||
p.authUrl = url
|
||||
}
|
||||
|
||||
// TokenUrl implements Provider.TokenUrl interface.
|
||||
func (p *baseProvider) TokenUrl() string {
|
||||
return p.tokenUrl
|
||||
}
|
||||
|
||||
// SetTokenUrl implements Provider.SetTokenUrl interface.
|
||||
func (p *baseProvider) SetTokenUrl(url string) {
|
||||
p.tokenUrl = url
|
||||
}
|
||||
|
||||
// UserApiUrl implements Provider.UserApiUrl interface.
|
||||
func (p *baseProvider) UserApiUrl() string {
|
||||
return p.userApiUrl
|
||||
}
|
||||
|
||||
// SetUserApiUrl implements Provider.SetUserApiUrl interface.
|
||||
func (p *baseProvider) SetUserApiUrl(url string) {
|
||||
p.userApiUrl = url
|
||||
}
|
||||
|
||||
// BuildAuthUrl implements Provider.BuildAuthUrl interface.
|
||||
func (p *baseProvider) BuildAuthUrl(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
return p.oauth2Config().AuthCodeURL(state, opts...)
|
||||
}
|
||||
|
||||
// FetchToken implements Provider.FetchToken interface.
|
||||
func (p *baseProvider) FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return p.oauth2Config().Exchange(context.Background(), code, opts...)
|
||||
}
|
||||
|
||||
// Client implements Provider.Client interface.
|
||||
func (p *baseProvider) Client(token *oauth2.Token) *http.Client {
|
||||
return p.oauth2Config().Client(context.Background(), token)
|
||||
}
|
||||
|
||||
// FetchRawUserData implements Provider.FetchRawUserData interface.
|
||||
func (p *baseProvider) FetchRawUserData(token *oauth2.Token, result any) error {
|
||||
client := p.Client(token)
|
||||
|
||||
response, err := client.Get(p.userApiUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
content, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(content, &result)
|
||||
}
|
||||
|
||||
// oauth2Config constructs a oauth2.Config instance based on the provider settings.
|
||||
func (p *baseProvider) oauth2Config() *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
RedirectURL: p.redirectUrl,
|
||||
ClientID: p.clientId,
|
||||
ClientSecret: p.clientSecret,
|
||||
Scopes: p.scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: p.authUrl,
|
||||
TokenURL: p.tokenUrl,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestScopes(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
before := b.Scopes()
|
||||
if len(before) != 0 {
|
||||
t.Errorf("Expected 0 scopes, got %v", before)
|
||||
}
|
||||
|
||||
b.SetScopes([]string{"test1", "test2"})
|
||||
|
||||
after := b.Scopes()
|
||||
if len(after) != 2 {
|
||||
t.Errorf("Expected 2 scopes, got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientId(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
before := b.ClientId()
|
||||
if before != "" {
|
||||
t.Errorf("Expected clientId to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetClientId("test")
|
||||
|
||||
after := b.ClientId()
|
||||
if after != "test" {
|
||||
t.Errorf("Expected clientId to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSecret(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
before := b.ClientSecret()
|
||||
if before != "" {
|
||||
t.Errorf("Expected clientSecret to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetClientSecret("test")
|
||||
|
||||
after := b.ClientSecret()
|
||||
if after != "test" {
|
||||
t.Errorf("Expected clientSecret to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
before := b.RedirectUrl()
|
||||
if before != "" {
|
||||
t.Errorf("Expected RedirectUrl to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetRedirectUrl("test")
|
||||
|
||||
after := b.RedirectUrl()
|
||||
if after != "test" {
|
||||
t.Errorf("Expected RedirectUrl to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
before := b.AuthUrl()
|
||||
if before != "" {
|
||||
t.Errorf("Expected authUrl to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetAuthUrl("test")
|
||||
|
||||
after := b.AuthUrl()
|
||||
if after != "test" {
|
||||
t.Errorf("Expected authUrl to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
before := b.TokenUrl()
|
||||
if before != "" {
|
||||
t.Errorf("Expected tokenUrl to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetTokenUrl("test")
|
||||
|
||||
after := b.TokenUrl()
|
||||
if after != "test" {
|
||||
t.Errorf("Expected tokenUrl to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserApiUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
before := b.UserApiUrl()
|
||||
if before != "" {
|
||||
t.Errorf("Expected userApiUrl to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetUserApiUrl("test")
|
||||
|
||||
after := b.UserApiUrl()
|
||||
if after != "test" {
|
||||
t.Errorf("Expected userApiUrl to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthUrl(t *testing.T) {
|
||||
b := baseProvider{
|
||||
authUrl: "authUrl_test",
|
||||
tokenUrl: "tokenUrl_test",
|
||||
redirectUrl: "redirectUrl_test",
|
||||
clientId: "clientId_test",
|
||||
clientSecret: "clientSecret_test",
|
||||
scopes: []string{"test_scope"},
|
||||
}
|
||||
|
||||
expected := "authUrl_test?access_type=offline&client_id=clientId_test&prompt=consent&redirect_uri=redirectUrl_test&response_type=code&scope=test_scope&state=state_test"
|
||||
result := b.BuildAuthUrl("state_test", oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected auth url %q, got %q", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
|
||||
result := b.Client(&oauth2.Token{})
|
||||
if result == nil {
|
||||
t.Error("Expected *http.Client instance, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauth2Config(t *testing.T) {
|
||||
b := baseProvider{
|
||||
authUrl: "authUrl_test",
|
||||
tokenUrl: "tokenUrl_test",
|
||||
redirectUrl: "redirectUrl_test",
|
||||
clientId: "clientId_test",
|
||||
clientSecret: "clientSecret_test",
|
||||
scopes: []string{"test"},
|
||||
}
|
||||
|
||||
result := b.oauth2Config()
|
||||
|
||||
if result.RedirectURL != b.RedirectUrl() {
|
||||
t.Errorf("Expected redirectUrl %s, got %s", b.RedirectUrl(), result.RedirectURL)
|
||||
}
|
||||
|
||||
if result.ClientID != b.ClientId() {
|
||||
t.Errorf("Expected clientId %s, got %s", b.ClientId(), result.ClientID)
|
||||
}
|
||||
|
||||
if result.ClientSecret != b.ClientSecret() {
|
||||
t.Errorf("Expected clientSecret %s, got %s", b.ClientSecret(), result.ClientSecret)
|
||||
}
|
||||
|
||||
if result.Endpoint.AuthURL != b.AuthUrl() {
|
||||
t.Errorf("Expected authUrl %s, got %s", b.AuthUrl(), result.Endpoint.AuthURL)
|
||||
}
|
||||
|
||||
if result.Endpoint.TokenURL != b.TokenUrl() {
|
||||
t.Errorf("Expected authUrl %s, got %s", b.TokenUrl(), result.Endpoint.TokenURL)
|
||||
}
|
||||
|
||||
if len(result.Scopes) != len(b.Scopes()) || result.Scopes[0] != b.Scopes()[0] {
|
||||
t.Errorf("Expected scopes %s, got %s", b.Scopes(), result.Scopes)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var _ Provider = (*Facebook)(nil)
|
||||
|
||||
// NameFacebook is the unique name of the Facebook provider.
|
||||
const NameFacebook string = "facebook"
|
||||
|
||||
// Facebook allows authentication via Facebook OAuth2.
|
||||
type Facebook struct {
|
||||
*baseProvider
|
||||
}
|
||||
|
||||
// NewFacebookProvider creates new Facebook provider instance with some defaults.
|
||||
func NewFacebookProvider() *Facebook {
|
||||
return &Facebook{&baseProvider{
|
||||
scopes: []string{"email"},
|
||||
authUrl: "https://www.facebook.com/dialog/oauth",
|
||||
tokenUrl: "https://graph.facebook.com/oauth/access_token",
|
||||
userApiUrl: "https://graph.facebook.com/me?fields=name,email,picture.type(large)",
|
||||
}}
|
||||
}
|
||||
|
||||
// FetchAuthUser returns an AuthUser instance based on the Facebook's user api.
|
||||
func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
// https://developers.facebook.com/docs/graph-api/reference/user/
|
||||
rawData := struct {
|
||||
Id string
|
||||
Name string
|
||||
Email string
|
||||
Picture struct {
|
||||
Data struct{ Url string }
|
||||
}
|
||||
}{}
|
||||
|
||||
if err := p.FetchRawUserData(token, &rawData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: rawData.Id,
|
||||
Name: rawData.Name,
|
||||
Email: rawData.Email,
|
||||
AvatarUrl: rawData.Picture.Data.Url,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var _ Provider = (*Github)(nil)
|
||||
|
||||
// NameGithub is the unique name of the Github provider.
|
||||
const NameGithub string = "github"
|
||||
|
||||
// Github allows authentication via Github OAuth2.
|
||||
type Github struct {
|
||||
*baseProvider
|
||||
}
|
||||
|
||||
// NewGithubProvider creates new Github provider instance with some defaults.
|
||||
func NewGithubProvider() *Github {
|
||||
return &Github{&baseProvider{
|
||||
scopes: []string{"user"},
|
||||
authUrl: "https://github.com/login/oauth/authorize",
|
||||
tokenUrl: "https://github.com/login/oauth/access_token",
|
||||
userApiUrl: "https://api.github.com/user",
|
||||
}}
|
||||
}
|
||||
|
||||
// FetchAuthUser returns an AuthUser instance based the Github's user api.
|
||||
func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
// https://docs.github.com/en/rest/reference/users#get-the-authenticated-user
|
||||
rawData := struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
}{}
|
||||
|
||||
if err := p.FetchRawUserData(token, &rawData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(rawData.Id),
|
||||
Name: rawData.Name,
|
||||
Email: rawData.Email,
|
||||
AvatarUrl: rawData.AvatarUrl,
|
||||
}
|
||||
|
||||
// in case user set "Keep my email address private",
|
||||
// email should be retrieved via extra API request
|
||||
if user.Email == "" {
|
||||
client := p.Client(token)
|
||||
|
||||
response, err := client.Get(p.userApiUrl + "/emails")
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
content, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
emails := []struct {
|
||||
Email string
|
||||
Verified bool
|
||||
Primary bool
|
||||
}{}
|
||||
if err := json.Unmarshal(content, &emails); err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
// extract the verified primary email
|
||||
for _, email := range emails {
|
||||
if email.Verified && email.Primary {
|
||||
user.Email = email.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var _ Provider = (*Gitlab)(nil)
|
||||
|
||||
// NameGitlab is the unique name of the Gitlab provider.
|
||||
const NameGitlab string = "gitlab"
|
||||
|
||||
// Gitlab allows authentication via Gitlab OAuth2.
|
||||
type Gitlab struct {
|
||||
*baseProvider
|
||||
}
|
||||
|
||||
// NewGitlabProvider creates new Gitlab provider instance with some defaults.
|
||||
func NewGitlabProvider() *Gitlab {
|
||||
return &Gitlab{&baseProvider{
|
||||
scopes: []string{"read_user"},
|
||||
authUrl: "https://gitlab.com/oauth/authorize",
|
||||
tokenUrl: "https://gitlab.com/oauth/token",
|
||||
userApiUrl: "https://gitlab.com/api/v4/user",
|
||||
}}
|
||||
}
|
||||
|
||||
// FetchAuthUser returns an AuthUser instance based the Gitlab's user api.
|
||||
func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
// https://docs.gitlab.com/ee/api/users.html#for-admin
|
||||
rawData := struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
}{}
|
||||
|
||||
if err := p.FetchRawUserData(token, &rawData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(rawData.Id),
|
||||
Name: rawData.Name,
|
||||
Email: rawData.Email,
|
||||
AvatarUrl: rawData.AvatarUrl,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var _ Provider = (*Google)(nil)
|
||||
|
||||
// NameGoogle is the unique name of the Google provider.
|
||||
const NameGoogle string = "google"
|
||||
|
||||
// Google allows authentication via Google OAuth2.
|
||||
type Google struct {
|
||||
*baseProvider
|
||||
}
|
||||
|
||||
// NewGoogleProvider creates new Google provider instance with some defaults.
|
||||
func NewGoogleProvider() *Google {
|
||||
return &Google{&baseProvider{
|
||||
scopes: []string{
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
},
|
||||
authUrl: "https://accounts.google.com/o/oauth2/auth",
|
||||
tokenUrl: "https://accounts.google.com/o/oauth2/token",
|
||||
userApiUrl: "https://www.googleapis.com/oauth2/v1/userinfo",
|
||||
}}
|
||||
}
|
||||
|
||||
// FetchAuthUser returns an AuthUser instance based the Google's user api.
|
||||
func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
// https://cloud.google.com/identity-platform/docs/reference/rest/v1/UserInfo
|
||||
rawData := struct {
|
||||
LocalId string `json:"localId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
PhotoUrl string `json:"photoUrl"`
|
||||
}{}
|
||||
|
||||
if err := p.FetchRawUserData(token, &rawData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := &AuthUser{
|
||||
Id: rawData.LocalId,
|
||||
Name: rawData.DisplayName,
|
||||
Email: rawData.Email,
|
||||
AvatarUrl: rawData.PhotoUrl,
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/disintegration/imaging"
|
||||
"gocloud.dev/blob"
|
||||
"gocloud.dev/blob/fileblob"
|
||||
"gocloud.dev/blob/s3blob"
|
||||
)
|
||||
|
||||
type System struct {
|
||||
ctx context.Context
|
||||
bucket *blob.Bucket
|
||||
}
|
||||
|
||||
// NewS3 initializes an S3 filesystem instance.
|
||||
//
|
||||
// NB! Make sure to call `Close()` after you are done working with it.
|
||||
func NewS3(
|
||||
bucketName string,
|
||||
region string,
|
||||
endpoint string,
|
||||
accessKey string,
|
||||
secretKey string,
|
||||
) (*System, error) {
|
||||
ctx := context.Background() // default context
|
||||
|
||||
cred := credentials.NewStaticCredentials(accessKey, secretKey, "")
|
||||
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Region: aws.String(region),
|
||||
Endpoint: aws.String(endpoint),
|
||||
Credentials: cred,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bucket, err := s3blob.OpenBucket(ctx, sess, bucketName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &System{ctx: ctx, bucket: bucket}, nil
|
||||
}
|
||||
|
||||
// NewLocal initializes a new local filesystem instance.
|
||||
//
|
||||
// NB! Make sure to call `Close()` after you are done working with it.
|
||||
func NewLocal(dirPath string) (*System, error) {
|
||||
ctx := context.Background() // default context
|
||||
|
||||
// makes sure that the directory exist
|
||||
if err := os.MkdirAll(dirPath, os.ModePerm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bucket, err := fileblob.OpenBucket(dirPath, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &System{ctx: ctx, bucket: bucket}, nil
|
||||
}
|
||||
|
||||
// Close releases any resources used for the related filesystem.
|
||||
func (s *System) Close() error {
|
||||
return s.bucket.Close()
|
||||
}
|
||||
|
||||
// Exists checks if file with fileKey path exists or not.
|
||||
func (s *System) Exists(fileKey string) (bool, error) {
|
||||
return s.bucket.Exists(s.ctx, fileKey)
|
||||
}
|
||||
|
||||
// Attributes returns the attributes for the file with fileKey path.
|
||||
func (s *System) Attributes(fileKey string) (*blob.Attributes, error) {
|
||||
return s.bucket.Attributes(s.ctx, fileKey)
|
||||
}
|
||||
|
||||
// Upload writes content into the fileKey location.
|
||||
func (s *System) Upload(content []byte, fileKey string) error {
|
||||
w, writerErr := s.bucket.NewWriter(s.ctx, fileKey, nil)
|
||||
if writerErr != nil {
|
||||
return writerErr
|
||||
}
|
||||
|
||||
if _, err := w.Write(content); err != nil {
|
||||
w.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return w.Close()
|
||||
}
|
||||
|
||||
// Delete deletes stored file at fileKey location.
|
||||
func (s *System) Delete(fileKey string) error {
|
||||
return s.bucket.Delete(s.ctx, fileKey)
|
||||
}
|
||||
|
||||
// DeletePrefix deletes everything starting with the specified prefix.
|
||||
func (s *System) DeletePrefix(prefix string) []error {
|
||||
failed := []error{}
|
||||
|
||||
if prefix == "" {
|
||||
failed = append(failed, errors.New("Prefix mustn't be empty."))
|
||||
return failed
|
||||
}
|
||||
|
||||
dirsMap := map[string]struct{}{}
|
||||
dirsMap[prefix] = struct{}{}
|
||||
|
||||
opts := blob.ListOptions{
|
||||
Prefix: prefix,
|
||||
}
|
||||
|
||||
// delete all files witht the prefix
|
||||
// ---
|
||||
iter := s.bucket.List(&opts)
|
||||
for {
|
||||
obj, err := iter.Next(s.ctx)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
failed = append(failed, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.Delete(obj.Key); err != nil {
|
||||
failed = append(failed, err)
|
||||
} else {
|
||||
dirsMap[filepath.Dir(obj.Key)] = struct{}{}
|
||||
}
|
||||
}
|
||||
// ---
|
||||
|
||||
// try to delete the empty remaining dir objects
|
||||
// (this operation usually is optional and there is no need to strictly check the result)
|
||||
// ---
|
||||
// fill dirs slice
|
||||
dirs := []string{}
|
||||
for d := range dirsMap {
|
||||
dirs = append(dirs, d)
|
||||
}
|
||||
|
||||
// sort the child dirs first, aka. ["a/b/c", "a/b", "a"]
|
||||
sort.SliceStable(dirs, func(i, j int) bool {
|
||||
return len(strings.Split(dirs[i], "/")) > len(strings.Split(dirs[j], "/"))
|
||||
})
|
||||
|
||||
// delete dirs
|
||||
for _, d := range dirs {
|
||||
if d != "" {
|
||||
s.Delete(d)
|
||||
}
|
||||
}
|
||||
// ---
|
||||
|
||||
return failed
|
||||
}
|
||||
|
||||
// Serve serves the file at fileKey location to an HTTP response.
|
||||
func (s *System) Serve(response http.ResponseWriter, fileKey string, name string) error {
|
||||
r, readErr := s.bucket.NewReader(s.ctx, fileKey, nil)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// All HTTP date/time stamps MUST be represented in Greenwich Mean Time (GMT)
|
||||
// (see https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.3.1)
|
||||
location, _ := time.LoadLocation("GMT")
|
||||
|
||||
response.Header().Set("Content-Disposition", "attachment; filename="+name)
|
||||
response.Header().Set("Content-Type", r.ContentType())
|
||||
response.Header().Set("Content-Length", strconv.FormatInt(r.Size(), 10))
|
||||
response.Header().Set("Last-Modified", r.ModTime().In(location).Format("Mon, 02 Jan 06 15:04:05 MST"))
|
||||
|
||||
// copy from the read range to response.
|
||||
_, err := io.Copy(response, r)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateThumb creates a new thumb image for the file at originalKey location.
|
||||
// The new thumb file is stored at thumbKey location.
|
||||
//
|
||||
// thumbSize is in the format "WxH", eg. "100x50".
|
||||
func (s *System) CreateThumb(originalKey string, thumbKey, thumbSize string, cropCenter bool) error {
|
||||
thumbSizeParts := strings.SplitN(thumbSize, "x", 2)
|
||||
if len(thumbSizeParts) != 2 {
|
||||
return errors.New("Thumb size must be in WxH format.")
|
||||
}
|
||||
|
||||
width, _ := strconv.Atoi(thumbSizeParts[0])
|
||||
height, _ := strconv.Atoi(thumbSizeParts[1])
|
||||
|
||||
// fetch the original
|
||||
r, readErr := s.bucket.NewReader(s.ctx, originalKey, nil)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// create imaging object from the origial reader
|
||||
img, decodeErr := imaging.Decode(r)
|
||||
if decodeErr != nil {
|
||||
return decodeErr
|
||||
}
|
||||
|
||||
// determine crop anchor
|
||||
cropAnchor := imaging.Center
|
||||
if !cropCenter {
|
||||
cropAnchor = imaging.Top
|
||||
}
|
||||
|
||||
// create thumb imaging object
|
||||
thumbImg := imaging.Fill(img, width, height, cropAnchor, imaging.CatmullRom)
|
||||
|
||||
// open a thumb storage writer (aka. prepare for upload)
|
||||
w, writerErr := s.bucket.NewWriter(s.ctx, thumbKey, nil)
|
||||
if writerErr != nil {
|
||||
return writerErr
|
||||
}
|
||||
|
||||
// thumb encode (aka. upload)
|
||||
if err := imaging.Encode(w, thumbImg, imaging.PNG); err != nil {
|
||||
w.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// check for close errors to ensure that the thumb was really saved
|
||||
return w.Close()
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package filesystem_test
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/png"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
)
|
||||
|
||||
func TestFileSystemExists(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
file string
|
||||
exists bool
|
||||
}{
|
||||
{"sub1.txt", false},
|
||||
{"test/sub1.txt", true},
|
||||
{"test/sub2.txt", true},
|
||||
{"file.png", true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
exists, _ := fs.Exists(scenario.file)
|
||||
|
||||
if exists != scenario.exists {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.exists, exists)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemAttributes(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
file string
|
||||
expectError bool
|
||||
}{
|
||||
{"sub1.txt", true},
|
||||
{"test/sub1.txt", false},
|
||||
{"test/sub2.txt", false},
|
||||
{"file.png", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
attr, err := fs.Attributes(scenario.file)
|
||||
|
||||
if err == nil && scenario.expectError {
|
||||
t.Errorf("(%d) Expected error, got nil", i)
|
||||
}
|
||||
|
||||
if err != nil && !scenario.expectError {
|
||||
t.Errorf("(%d) Expected nil, got error, %v", i, err)
|
||||
}
|
||||
|
||||
if err == nil && attr.ContentType != "application/octet-stream" {
|
||||
t.Errorf("(%d) Expected attr.ContentType to be %q, got %q", i, "application/octet-stream", attr.ContentType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemDelete(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
if err := fs.Delete("missing.txt"); err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if err := fs.Delete("file.png"); err != nil {
|
||||
t.Fatalf("Expected nil, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemDeletePrefix(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
if errs := fs.DeletePrefix(""); len(errs) == 0 {
|
||||
t.Fatal("Expected error, got nil", errs)
|
||||
}
|
||||
|
||||
if errs := fs.DeletePrefix("missing/"); len(errs) != 0 {
|
||||
t.Fatalf("Not existing prefix shouldn't error, got %v", errs)
|
||||
}
|
||||
|
||||
if errs := fs.DeletePrefix("test"); len(errs) != 0 {
|
||||
t.Fatalf("Expected nil, got errors %v", errs)
|
||||
}
|
||||
|
||||
// ensure that the test/ files are deleted
|
||||
if exists, _ := fs.Exists("test/sub1.txt"); exists {
|
||||
t.Fatalf("Expected test/sub1.txt to be deleted")
|
||||
}
|
||||
if exists, _ := fs.Exists("test/sub2.txt"); exists {
|
||||
t.Fatalf("Expected test/sub2.txt to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemUpload(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
uploadErr := fs.Upload([]byte("demo"), "newdir/newkey.txt")
|
||||
if uploadErr != nil {
|
||||
t.Fatal(uploadErr)
|
||||
}
|
||||
|
||||
if exists, _ := fs.Exists("newdir/newkey.txt"); !exists {
|
||||
t.Fatalf("Expected newdir/newkey.txt to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemServe(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
r := httptest.NewRecorder()
|
||||
|
||||
// serve missing file
|
||||
if err := fs.Serve(r, "missing.txt", "download.txt"); err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
// serve existing file
|
||||
if err := fs.Serve(r, "test/sub1.txt", "download.txt"); err != nil {
|
||||
t.Fatal("Expected nil, got error")
|
||||
}
|
||||
|
||||
result := r.Result()
|
||||
|
||||
// check headers
|
||||
scenarios := []struct {
|
||||
header string
|
||||
expected string
|
||||
}{
|
||||
{"Content-Disposition", "attachment; filename=download.txt"},
|
||||
{"Content-Type", "application/octet-stream"},
|
||||
{"Content-Length", "0"},
|
||||
}
|
||||
for i, scenario := range scenarios {
|
||||
v := result.Header.Get(scenario.header)
|
||||
if v != scenario.expected {
|
||||
t.Errorf("(%d) Expected value %q for header %q, got %q", i, scenario.expected, scenario.header, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemCreateThumb(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
file string
|
||||
thumb string
|
||||
cropCenter bool
|
||||
expectError bool
|
||||
}{
|
||||
// missing
|
||||
{"missing.txt", "thumb_test_missing", true, true},
|
||||
// non-image existing file
|
||||
{"test/sub1.txt", "thumb_test_sub1", true, true},
|
||||
// existing image file - crop center
|
||||
{"file.png", "thumb_file_center", true, false},
|
||||
// existing image file - crop top
|
||||
{"file.png", "thumb_file_top", false, false},
|
||||
// existing image file with existing thumb path = should fail
|
||||
{"file.png", "test", true, true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
err := fs.CreateThumb(scenario.file, scenario.thumb, "100x100", scenario.cropCenter)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
continue
|
||||
}
|
||||
|
||||
if exists, _ := fs.Exists(scenario.thumb); !exists {
|
||||
t.Errorf("(%d) Couldn't find %q thumb", i, scenario.thumb)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
|
||||
func createTestDir(t *testing.T) string {
|
||||
dir, err := os.MkdirTemp(os.TempDir(), "pb_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "test"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
file1, err := os.OpenFile(filepath.Join(dir, "test/sub1.txt"), os.O_WRONLY|os.O_CREATE, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
file1.Close()
|
||||
|
||||
file2, err := os.OpenFile(filepath.Join(dir, "test/sub2.txt"), os.O_WRONLY|os.O_CREATE, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
file2.Close()
|
||||
|
||||
file3, err := os.OpenFile(filepath.Join(dir, "file.png"), os.O_WRONLY|os.O_CREATE, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// tiny 1x1 png
|
||||
imgRect := image.Rect(0, 0, 1, 1)
|
||||
png.Encode(file3, imgRect)
|
||||
file3.Close()
|
||||
|
||||
return dir
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var StopPropagation = errors.New("Event hook propagation stopped")
|
||||
|
||||
// Handler defines a hook handler function.
|
||||
type Handler[T any] func(data T) error
|
||||
|
||||
// Hook defines a concurrent safe structure for handling event hooks
|
||||
// (aka. callbacks propagation).
|
||||
type Hook[T any] struct {
|
||||
mux sync.RWMutex
|
||||
handlers []Handler[T]
|
||||
}
|
||||
|
||||
// Add registers a new handler to the hook.
|
||||
func (h *Hook[T]) Add(fn Handler[T]) {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.handlers = append(h.handlers, fn)
|
||||
}
|
||||
|
||||
// Reset removes all registered handlers.
|
||||
func (h *Hook[T]) Reset() {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
h.handlers = nil
|
||||
}
|
||||
|
||||
// Trigger executes all registered hook handlers one by one
|
||||
// with the specified `data` as an argument.
|
||||
//
|
||||
// Optionally, this method allows also to register additional one off
|
||||
// handlers that will be temporary appended to the handlers queue.
|
||||
//
|
||||
// The execution stops when:
|
||||
// - hook.StopPropagation is returned in one of the handlers
|
||||
// - any non-nil error is returned in one of the handlers
|
||||
func (h *Hook[T]) Trigger(data T, oneOffHandlers ...Handler[T]) error {
|
||||
h.mux.Lock()
|
||||
handlers := append(h.handlers, oneOffHandlers...)
|
||||
h.mux.Unlock() // unlock is not deferred to avoid deadlocks when Trigger is called recursive in the handlers
|
||||
|
||||
for _, fn := range handlers {
|
||||
err := fn(data)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if errors.Is(err, StopPropagation) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
h := Hook[int]{}
|
||||
|
||||
if total := len(h.handlers); total != 0 {
|
||||
t.Fatalf("Expected no handlers, found %d", total)
|
||||
}
|
||||
|
||||
h.Add(func(data int) error { return nil })
|
||||
h.Add(func(data int) error { return nil })
|
||||
|
||||
if total := len(h.handlers); total != 2 {
|
||||
t.Fatalf("Expected 2 handlers, found %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReset(t *testing.T) {
|
||||
h := Hook[int]{}
|
||||
|
||||
h.Reset() // should do nothing and not panic
|
||||
|
||||
h.Add(func(data int) error { return nil })
|
||||
h.Add(func(data int) error { return nil })
|
||||
|
||||
if total := len(h.handlers); total != 2 {
|
||||
t.Fatalf("Expected 2 handlers before Reset, found %d", total)
|
||||
}
|
||||
|
||||
h.Reset()
|
||||
|
||||
if total := len(h.handlers); total != 0 {
|
||||
t.Fatalf("Expected no handlers after Reset, found %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrigger(t *testing.T) {
|
||||
err1 := errors.New("demo")
|
||||
err2 := errors.New("demo")
|
||||
|
||||
scenarios := []struct {
|
||||
handlers []Handler[int]
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
[]Handler[int]{
|
||||
func(data int) error { return nil },
|
||||
func(data int) error { return nil },
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]Handler[int]{
|
||||
func(data int) error { return nil },
|
||||
func(data int) error { return err1 },
|
||||
func(data int) error { return err2 },
|
||||
},
|
||||
err1,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
h := Hook[int]{}
|
||||
for _, handler := range scenario.handlers {
|
||||
h.Add(handler)
|
||||
}
|
||||
result := h.Trigger(1)
|
||||
if result != scenario.expectedError {
|
||||
t.Fatalf("(%d) Expected %v, got %v", i, scenario.expectedError, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriggerStopPropagation(t *testing.T) {
|
||||
called1 := false
|
||||
f1 := func(data int) error { called1 = true; return nil }
|
||||
|
||||
called2 := false
|
||||
f2 := func(data int) error { called2 = true; return nil }
|
||||
|
||||
called3 := false
|
||||
f3 := func(data int) error { called3 = true; return nil }
|
||||
|
||||
called4 := false
|
||||
f4 := func(data int) error { called4 = true; return StopPropagation }
|
||||
|
||||
called5 := false
|
||||
f5 := func(data int) error { called5 = true; return nil }
|
||||
|
||||
called6 := false
|
||||
f6 := func(data int) error { called6 = true; return nil }
|
||||
|
||||
h := Hook[int]{}
|
||||
h.Add(f1)
|
||||
h.Add(f2)
|
||||
|
||||
result := h.Trigger(123, f3, f4, f5, f6)
|
||||
|
||||
if result != nil {
|
||||
t.Fatalf("Expected nil after StopPropagation, got %v", result)
|
||||
}
|
||||
|
||||
// ensure that the trigger handler were not persisted
|
||||
if total := len(h.handlers); total != 2 {
|
||||
t.Fatalf("Expected 2 handlers, found %d", total)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
called bool
|
||||
expected bool
|
||||
}{
|
||||
{called1, true},
|
||||
{called2, true},
|
||||
{called3, true},
|
||||
{called4, true}, // StopPropagation
|
||||
{called5, false},
|
||||
{called6, false},
|
||||
}
|
||||
for i, scenario := range scenarios {
|
||||
if scenario.called != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, scenario.called)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package inflector
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var columnifyRemoveRegex = regexp.MustCompile(`[^\w\.\*\-\_\@\#]+`)
|
||||
var snakecaseSplitRegex = regexp.MustCompile(`[\W_]+`)
|
||||
var usernamifySplitRegex = regexp.MustCompile(`\W+`)
|
||||
|
||||
// UcFirst converts the first character of a string into uppercase.
|
||||
func UcFirst(str string) string {
|
||||
if str == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
s := []rune(str)
|
||||
|
||||
return string(unicode.ToUpper(s[0])) + string(s[1:])
|
||||
}
|
||||
|
||||
// Columnify strips invalid db identifier characters.
|
||||
func Columnify(str string) string {
|
||||
return columnifyRemoveRegex.ReplaceAllString(str, "")
|
||||
}
|
||||
|
||||
// Sentenize converts and normalizes string into a sentence.
|
||||
func Sentenize(str string) string {
|
||||
str = strings.TrimSpace(str)
|
||||
if str == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
s := []rune(str)
|
||||
sentence := string(unicode.ToUpper(s[0])) + string(s[1:])
|
||||
|
||||
lastChar := string(s[len(s)-1:])
|
||||
if lastChar != "." && lastChar != "?" && lastChar != "!" {
|
||||
return sentence + "."
|
||||
}
|
||||
|
||||
return sentence
|
||||
}
|
||||
|
||||
// Sanitize sanitizes `str` by removing all characters satisfying `removePattern`.
|
||||
// Returns an error if the pattern is not valid regex string.
|
||||
func Sanitize(str string, removePattern string) (string, error) {
|
||||
exp, err := regexp.Compile(removePattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return exp.ReplaceAllString(str, ""), nil
|
||||
}
|
||||
|
||||
// Snakecase removes all non word characters and converts any english text into a snakecase.
|
||||
// "ABBREVIATIONS" are preserved, eg. "myTestDB" will become "my_test_db".
|
||||
func Snakecase(str string) string {
|
||||
var result strings.Builder
|
||||
|
||||
// split at any non word character and underscore
|
||||
words := snakecaseSplitRegex.Split(str, -1)
|
||||
|
||||
for _, word := range words {
|
||||
if word == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if result.Len() > 0 {
|
||||
result.WriteString("_")
|
||||
}
|
||||
|
||||
for i, c := range word {
|
||||
if unicode.IsUpper(c) && i > 0 &&
|
||||
// is not a following uppercase character
|
||||
!unicode.IsUpper(rune(word[i-1])) {
|
||||
result.WriteString("_")
|
||||
}
|
||||
|
||||
result.WriteRune(c)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// Usernamify generates a properly formatted username from the provided string.
|
||||
// Returns "unknown" if `str` is empty or contains only non word characters.
|
||||
//
|
||||
// ```go
|
||||
// Usernamify("John Doe, hello") // "john.doe.hello"
|
||||
// ```
|
||||
func Usernamify(str string) string {
|
||||
// split at any non word character
|
||||
words := usernamifySplitRegex.Split(strings.ToLower(str), -1)
|
||||
|
||||
// concatenate any non empty word with a dot
|
||||
var result strings.Builder
|
||||
for _, word := range words {
|
||||
if word == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if result.Len() > 0 {
|
||||
result.WriteString(".")
|
||||
}
|
||||
|
||||
result.WriteString(word)
|
||||
}
|
||||
|
||||
if result.Len() == 0 {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package inflector_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
)
|
||||
|
||||
func TestUcFirst(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
val string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"Test", "Test"},
|
||||
{"test", "Test"},
|
||||
{"test test2", "Test test2"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.UcFirst(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumnify(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
val string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{" ", ""},
|
||||
{"123", "123"},
|
||||
{"Test.", "Test."},
|
||||
{" test ", "test"},
|
||||
{"test1.test2", "test1.test2"},
|
||||
{"@test!abc", "@testabc"},
|
||||
{"#test?abc", "#testabc"},
|
||||
{"123test(123)#", "123test123#"},
|
||||
{"test1--test2", "test1--test2"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.Columnify(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSentenize(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
val string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{" ", ""},
|
||||
{"Test", "Test."},
|
||||
{" test ", "Test."},
|
||||
{"hello world", "Hello world."},
|
||||
{"hello world.", "Hello world."},
|
||||
{"hello world!", "Hello world!"},
|
||||
{"hello world?", "Hello world?"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.Sentenize(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitize(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
val string
|
||||
pattern string
|
||||
expected string
|
||||
expectErr bool
|
||||
}{
|
||||
{"", ``, "", false},
|
||||
{" ", ``, " ", false},
|
||||
{" ", ` `, "", false},
|
||||
{"", `[A-Z]`, "", false},
|
||||
{"abcABC", `[A-Z]`, "abc", false},
|
||||
{"abcABC", `[A-Z`, "", true}, // invlid pattern
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := inflector.Sanitize(scenario.val, scenario.pattern)
|
||||
hasErr := err != nil
|
||||
|
||||
if scenario.expectErr != hasErr {
|
||||
if scenario.expectErr {
|
||||
t.Errorf("(%d) Expected error, got nil", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Didn't expect error, got", err)
|
||||
}
|
||||
}
|
||||
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnakecase(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
val string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{" ", ""},
|
||||
{"!@#$%^", ""},
|
||||
{"...", ""},
|
||||
{"_", ""},
|
||||
{"John Doe", "john_doe"},
|
||||
{"John_Doe", "john_doe"},
|
||||
{".a!b@c#d$e%123. ", "a_b_c_d_e_123"},
|
||||
{"HelloWorld", "hello_world"},
|
||||
{"HelloWorld1HelloWorld2", "hello_world1_hello_world2"},
|
||||
{"TEST", "test"},
|
||||
{"testABR", "test_abr"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.Snakecase(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsernamify(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
val string
|
||||
expected string
|
||||
}{
|
||||
{"", "unknown"},
|
||||
{" ", "unknown"},
|
||||
{"!@#$%^", "unknown"},
|
||||
{"...", "unknown"},
|
||||
{"_", "_"}, // underscore is valid word character
|
||||
{"John Doe", "john.doe"},
|
||||
{"John_Doe", "john_doe"},
|
||||
{".a!b@c#d$e%123. ", "a.b.c.d.e.123"},
|
||||
{"Hello, world", "hello.world"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.Usernamify(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var cachedPatterns = map[string]*regexp.Regexp{}
|
||||
|
||||
// ExustInSlice checks whether a comparable element exists in a slice of the same type.
|
||||
func ExistInSlice[T comparable](item T, list []T) bool {
|
||||
if len(list) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, v := range list {
|
||||
if v == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ExistInSliceWithRegex checks whether a string exists in a slice
|
||||
// either by direct match, or by a regular expression (eg. `^\w+$`).
|
||||
//
|
||||
// _Note: Only list items starting with '^' and ending with '$' are treated as regular expressions!_
|
||||
func ExistInSliceWithRegex(str string, list []string) bool {
|
||||
for _, field := range list {
|
||||
isRegex := strings.HasPrefix(field, "^") && strings.HasSuffix(field, "$")
|
||||
|
||||
if !isRegex {
|
||||
// check for direct match
|
||||
if str == field {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// check for regex match
|
||||
pattern, ok := cachedPatterns[field]
|
||||
if !ok {
|
||||
var patternErr error
|
||||
pattern, patternErr = regexp.Compile(field)
|
||||
if patternErr != nil {
|
||||
continue
|
||||
}
|
||||
// "cache" the pattern to avoid compiling it every time
|
||||
cachedPatterns[field] = pattern
|
||||
}
|
||||
if pattern != nil && pattern.MatchString(str) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ToInterfaceSlice converts a generic slice to slice of interfaces.
|
||||
func ToInterfaceSlice[T any](list []T) []any {
|
||||
result := make([]any, len(list))
|
||||
|
||||
for i := range list {
|
||||
result[i] = list[i]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// NonzeroUniques returns only the nonzero unique values from a slice.
|
||||
func NonzeroUniques[T comparable](list []T) []T {
|
||||
result := []T{}
|
||||
existMap := map[T]bool{}
|
||||
|
||||
var zeroVal T
|
||||
|
||||
for _, val := range list {
|
||||
if !existMap[val] && val != zeroVal {
|
||||
existMap[val] = true
|
||||
result = append(result, val)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToUniqueStringSlice casts `value` to a slice of non-zero unique strings.
|
||||
func ToUniqueStringSlice(value any) []string {
|
||||
strings := []string{}
|
||||
|
||||
switch val := value.(type) {
|
||||
case nil:
|
||||
// nothing to cast
|
||||
case []string:
|
||||
strings = val
|
||||
case string:
|
||||
if val == "" {
|
||||
break
|
||||
}
|
||||
|
||||
// check if it is a json encoded array of strings
|
||||
if err := json.Unmarshal([]byte(val), &strings); err != nil {
|
||||
// not a json array, just add the string as single array element
|
||||
strings = append(strings, val)
|
||||
}
|
||||
case json.Marshaler: // eg. JsonArray
|
||||
raw, _ := val.MarshalJSON()
|
||||
json.Unmarshal(raw, &strings)
|
||||
default:
|
||||
strings = cast.ToStringSlice(value)
|
||||
}
|
||||
|
||||
return NonzeroUniques(strings)
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package list_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestExistInSliceString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
item string
|
||||
list []string
|
||||
expected bool
|
||||
}{
|
||||
{"", []string{""}, true},
|
||||
{"", []string{"1", "2", "test 123"}, false},
|
||||
{"test", []string{}, false},
|
||||
{"test", []string{"TEST"}, false},
|
||||
{"test", []string{"1", "2", "test 123"}, false},
|
||||
{"test", []string{"1", "2", "test"}, true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ExistInSlice(scenario.item, scenario.list)
|
||||
if result != scenario.expected {
|
||||
if scenario.expected {
|
||||
t.Errorf("(%d) Expected to exist in the list", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Expected NOT to exist in the list", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistInSliceInt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
item int
|
||||
list []int
|
||||
expected bool
|
||||
}{
|
||||
{0, []int{}, false},
|
||||
{0, []int{0}, true},
|
||||
{4, []int{1, 2, 3}, false},
|
||||
{1, []int{1, 2, 3}, true},
|
||||
{-1, []int{0, 1, 2, 3}, false},
|
||||
{-1, []int{0, -1, -2, -3, -4}, true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ExistInSlice(scenario.item, scenario.list)
|
||||
if result != scenario.expected {
|
||||
if scenario.expected {
|
||||
t.Errorf("(%d) Expected to exist in the list", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Expected NOT to exist in the list", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistInSliceWithRegex(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
item string
|
||||
list []string
|
||||
expected bool
|
||||
}{
|
||||
{"", []string{``}, true},
|
||||
{"", []string{`^\W+$`}, false},
|
||||
{" ", []string{`^\W+$`}, true},
|
||||
{"test", []string{`^\invalid[+$`}, false}, // invalid regex
|
||||
{"test", []string{`^\W+$`, "test"}, true},
|
||||
{`^\W+$`, []string{`^\W+$`, "test"}, false}, // direct match shouldn't work for this case
|
||||
{`\W+$`, []string{`\W+$`, "test"}, true}, // direct match should work for this case because it is not an actual supported pattern format
|
||||
{"!?@", []string{`\W+$`, "test"}, false}, // the method requires the pattern elems to start with '^'
|
||||
{"!?@", []string{`^\W+`, "test"}, false}, // the method requires the pattern elems to end with '$'
|
||||
{"!?@", []string{`^\W+$`, "test"}, true},
|
||||
{"!?@test", []string{`^\W+$`, "test"}, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ExistInSliceWithRegex(scenario.item, scenario.list)
|
||||
if result != scenario.expected {
|
||||
if scenario.expected {
|
||||
t.Errorf("(%d) Expected the string to exist in the list", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Expected the string NOT to exist in the list", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToInterfaceSlice(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
items []string
|
||||
}{
|
||||
{[]string{}},
|
||||
{[]string{""}},
|
||||
{[]string{"1", "test"}},
|
||||
{[]string{"test1", "test2", "test3"}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ToInterfaceSlice(scenario.items)
|
||||
|
||||
if len(result) != len(scenario.items) {
|
||||
t.Errorf("(%d) Result list length doesn't match with the original list", i)
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != scenario.items[j] {
|
||||
t.Errorf("(%d:%d) Result list item should match with the original list item", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonzeroUniquesString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
items []string
|
||||
expected []string
|
||||
}{
|
||||
{[]string{}, []string{}},
|
||||
{[]string{""}, []string{}},
|
||||
{[]string{"1", "test"}, []string{"1", "test"}},
|
||||
{[]string{"test1", "", "test2", "Test2", "test1", "test3"}, []string{"test1", "test2", "Test2", "test3"}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.NonzeroUniques(scenario.items)
|
||||
|
||||
if len(result) != len(scenario.expected) {
|
||||
t.Errorf("(%d) Result list length doesn't match with the expected list", i)
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != scenario.expected[j] {
|
||||
t.Errorf("(%d:%d) Result list item should match with the expected list item", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToUniqueStringSlice(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expected []string
|
||||
}{
|
||||
{nil, []string{}},
|
||||
{"", []string{}},
|
||||
{[]any{}, []string{}},
|
||||
{[]int{}, []string{}},
|
||||
{"test", []string{"test"}},
|
||||
{[]int{1, 2, 3}, []string{"1", "2", "3"}},
|
||||
{[]any{0, 1, "test", ""}, []string{"0", "1", "test"}},
|
||||
{[]string{"test1", "test2", "test1"}, []string{"test1", "test2"}},
|
||||
{`["test1", "test2", "test2"]`, []string{"test1", "test2"}},
|
||||
{types.JsonArray{"test1", "test2", "test1"}, []string{"test1", "test2"}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ToUniqueStringSlice(scenario.value)
|
||||
|
||||
if len(result) != len(scenario.expected) {
|
||||
t.Errorf("(%d) Result list length doesn't match with the expected list", i)
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != scenario.expected[j] {
|
||||
t.Errorf("(%d:%d) Result list item should match with the expected list item", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package mailer
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/mail"
|
||||
)
|
||||
|
||||
// Mailer defines a base mail client interface.
|
||||
type Mailer interface {
|
||||
// Send sends an email with HTML body to the specified recipient.
|
||||
Send(
|
||||
fromEmail mail.Address,
|
||||
toEmail mail.Address,
|
||||
subject string,
|
||||
htmlBody string,
|
||||
attachments map[string]io.Reader,
|
||||
) error
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package mailer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
var _ Mailer = (*Sendmail)(nil)
|
||||
|
||||
// Sendmail implements `mailer.Mailer` interface and defines a mail
|
||||
// client that sends emails via the `sendmail` *nix command.
|
||||
//
|
||||
// This client is usually recommended only for development and testing.
|
||||
type Sendmail struct {
|
||||
}
|
||||
|
||||
// Send implements `mailer.Mailer` interface.
|
||||
//
|
||||
// Attachments are currently not supported.
|
||||
func (m *Sendmail) Send(
|
||||
fromEmail mail.Address,
|
||||
toEmail mail.Address,
|
||||
subject string,
|
||||
htmlBody string,
|
||||
attachments map[string]io.Reader,
|
||||
) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Subject", mime.QEncoding.Encode("utf-8", subject))
|
||||
headers.Set("From", fromEmail.String())
|
||||
headers.Set("To", toEmail.String())
|
||||
headers.Set("Content-Type", "text/html; charset=UTF-8")
|
||||
|
||||
cmdPath, err := findSendmailPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buffer bytes.Buffer
|
||||
|
||||
// write
|
||||
// ---
|
||||
if err := headers.Write(&buffer); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := buffer.Write([]byte("\r\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := buffer.Write([]byte(htmlBody)); err != nil {
|
||||
return err
|
||||
}
|
||||
// ---
|
||||
|
||||
sendmail := exec.Command(cmdPath, toEmail.Address)
|
||||
sendmail.Stdin = &buffer
|
||||
|
||||
return sendmail.Run()
|
||||
}
|
||||
|
||||
func findSendmailPath() (string, error) {
|
||||
options := []string{
|
||||
"/usr/sbin/sendmail",
|
||||
"/usr/bin/sendmail",
|
||||
"sendmail",
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
path, err := exec.LookPath(option)
|
||||
if err == nil {
|
||||
return path, err
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("Failed to locate a sendmail executable path.")
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package mailer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/mail"
|
||||
"net/smtp"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/domodwyer/mailyak/v3"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
)
|
||||
|
||||
var _ Mailer = (*SmtpClient)(nil)
|
||||
|
||||
// regex to select all tabs
|
||||
var tabsRegex = regexp.MustCompile(`\t+`)
|
||||
|
||||
// NewSmtpClient creates new `SmtpClient` with the provided configuration.
|
||||
func NewSmtpClient(
|
||||
host string,
|
||||
port int,
|
||||
username string,
|
||||
password string,
|
||||
tls bool,
|
||||
) *SmtpClient {
|
||||
return &SmtpClient{
|
||||
host: host,
|
||||
port: port,
|
||||
username: username,
|
||||
password: password,
|
||||
tls: tls,
|
||||
}
|
||||
}
|
||||
|
||||
// SmtpClient defines a SMTP mail client structure that implements
|
||||
// `mailer.Mailer` interface.
|
||||
type SmtpClient struct {
|
||||
mail *mailyak.MailYak
|
||||
|
||||
host string
|
||||
port int
|
||||
username string
|
||||
password string
|
||||
tls bool
|
||||
}
|
||||
|
||||
// Send implements `mailer.Mailer` interface.
|
||||
func (m *SmtpClient) Send(
|
||||
fromEmail mail.Address,
|
||||
toEmail mail.Address,
|
||||
subject string,
|
||||
htmlBody string,
|
||||
attachments map[string]io.Reader,
|
||||
) error {
|
||||
smtpAuth := smtp.PlainAuth("", m.username, m.password, m.host)
|
||||
|
||||
// create mail instance
|
||||
var yak *mailyak.MailYak
|
||||
if m.tls {
|
||||
var tlsErr error
|
||||
yak, tlsErr = mailyak.NewWithTLS(fmt.Sprintf("%s:%d", m.host, m.port), smtpAuth, nil)
|
||||
if tlsErr != nil {
|
||||
return tlsErr
|
||||
}
|
||||
} else {
|
||||
yak = mailyak.New(fmt.Sprintf("%s:%d", m.host, m.port), smtpAuth)
|
||||
}
|
||||
|
||||
if fromEmail.Name != "" {
|
||||
yak.FromName(fromEmail.Name)
|
||||
}
|
||||
yak.From(fromEmail.Address)
|
||||
yak.To(toEmail.Address)
|
||||
yak.Subject(subject)
|
||||
yak.HTML().Set(htmlBody)
|
||||
|
||||
// set also plain text content
|
||||
policy := bluemonday.StrictPolicy() // strips all tags
|
||||
yak.Plain().Set(strings.TrimSpace(tabsRegex.ReplaceAllString(policy.Sanitize(htmlBody), "")))
|
||||
|
||||
for name, data := range attachments {
|
||||
yak.Attach(name, data)
|
||||
}
|
||||
|
||||
return yak.Send()
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type migration struct {
|
||||
file string
|
||||
up func(db dbx.Builder) error
|
||||
down func(db dbx.Builder) error
|
||||
}
|
||||
|
||||
// MigrationsList defines a list with migration definitions
|
||||
type MigrationsList struct {
|
||||
list []*migration
|
||||
}
|
||||
|
||||
// Item returns a single migration from the list by its index.
|
||||
func (l *MigrationsList) Item(index int) *migration {
|
||||
return l.list[index]
|
||||
}
|
||||
|
||||
// Items returns the internal migrations list slice.
|
||||
func (l *MigrationsList) Items() []*migration {
|
||||
return l.list
|
||||
}
|
||||
|
||||
// Register adds new migration definition to the list.
|
||||
//
|
||||
// If `optFilename` is not provided, it will try to get the name from its .go file.
|
||||
//
|
||||
// The list will be sorted automatically based on the migrations file name.
|
||||
func (l *MigrationsList) Register(
|
||||
up func(db dbx.Builder) error,
|
||||
down func(db dbx.Builder) error,
|
||||
optFilename ...string,
|
||||
) {
|
||||
var file string
|
||||
if len(optFilename) > 0 {
|
||||
file = optFilename[0]
|
||||
} else {
|
||||
_, path, _, _ := runtime.Caller(1)
|
||||
file = filepath.Base(path)
|
||||
}
|
||||
|
||||
l.list = append(l.list, &migration{
|
||||
file: file,
|
||||
up: up,
|
||||
down: down,
|
||||
})
|
||||
|
||||
sort.Slice(l.list, func(i int, j int) bool {
|
||||
return l.list[i].file < l.list[j].file
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMigrationsList(t *testing.T) {
|
||||
l := MigrationsList{}
|
||||
|
||||
l.Register(nil, nil, "3_test.go")
|
||||
l.Register(nil, nil, "1_test.go")
|
||||
l.Register(nil, nil, "2_test.go")
|
||||
l.Register(nil, nil /* auto detect file name */)
|
||||
|
||||
expected := []string{
|
||||
"1_test.go",
|
||||
"2_test.go",
|
||||
"3_test.go",
|
||||
"list_test.go",
|
||||
}
|
||||
|
||||
items := l.Items()
|
||||
if len(items) != len(expected) {
|
||||
t.Fatalf("Expected %d items, got %d: \n%#v", len(expected), len(items), items)
|
||||
}
|
||||
|
||||
for i, name := range expected {
|
||||
item := l.Item(i)
|
||||
if item.file != name {
|
||||
t.Fatalf("Expected name %s for index %d, got %s", name, i, item.file)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
const migrationsTable = "_migrations"
|
||||
|
||||
// Runner defines a simple struct for managing the execution of db migrations.
|
||||
type Runner struct {
|
||||
db *dbx.DB
|
||||
migrationsList MigrationsList
|
||||
tableName string
|
||||
}
|
||||
|
||||
// NewRunner creates and initializes a new db migrations Runner instance.
|
||||
func NewRunner(db *dbx.DB, migrationsList MigrationsList) (*Runner, error) {
|
||||
runner := &Runner{
|
||||
db: db,
|
||||
migrationsList: migrationsList,
|
||||
tableName: migrationsTable,
|
||||
}
|
||||
|
||||
if err := runner.createMigrationsTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
// Run interactively executes the current runner with the provided args.
|
||||
//
|
||||
// The following commands are supported:
|
||||
// - up - applies all migrations
|
||||
// - down [n] - reverts the last n applied migrations
|
||||
// - create NEW_MIGRATION_NAME - create NEW_MIGRATION_NAME.go file from a migration template
|
||||
func (r *Runner) Run(args ...string) error {
|
||||
cmd := "up"
|
||||
if len(args) > 0 {
|
||||
cmd = args[0]
|
||||
}
|
||||
|
||||
switch cmd {
|
||||
case "up":
|
||||
applied, err := r.Up()
|
||||
if err != nil {
|
||||
color.Red(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if len(applied) == 0 {
|
||||
color.Green("No new migrations to apply.")
|
||||
} else {
|
||||
for _, file := range applied {
|
||||
color.Green("Applied %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "down":
|
||||
toRevertCount := 1
|
||||
if len(args) > 1 {
|
||||
toRevertCount = cast.ToInt(args[1])
|
||||
if toRevertCount < 0 {
|
||||
// revert all applied migrations
|
||||
toRevertCount = len(r.migrationsList.Items())
|
||||
}
|
||||
}
|
||||
|
||||
confirm := false
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf("Do you really want to revert the last %d applied migration(s)?", toRevertCount),
|
||||
}
|
||||
survey.AskOne(prompt, &confirm)
|
||||
if !confirm {
|
||||
fmt.Println("The command has been cancelled")
|
||||
return nil
|
||||
}
|
||||
|
||||
reverted, err := r.Down(toRevertCount)
|
||||
if err != nil {
|
||||
color.Red(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if len(reverted) == 0 {
|
||||
color.Green("No migrations to revert.")
|
||||
} else {
|
||||
for _, file := range reverted {
|
||||
color.Green("Reverted %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "create":
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("Missing migration file name")
|
||||
}
|
||||
|
||||
name := args[1]
|
||||
|
||||
var dir string
|
||||
if len(args) == 3 {
|
||||
dir = args[2]
|
||||
}
|
||||
if dir == "" {
|
||||
// If not specified, auto point to the default migrations folder.
|
||||
//
|
||||
// NB!
|
||||
// Since the create command makes sense only during development,
|
||||
// it is expected the user to be in the app working directory
|
||||
// and to be using `go run ...`
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dir = path.Join(wd, "migrations")
|
||||
}
|
||||
|
||||
resultFilePath := path.Join(
|
||||
dir,
|
||||
fmt.Sprintf("%d_%s.go", time.Now().Unix(), inflector.Snakecase(name)),
|
||||
)
|
||||
|
||||
confirm := false
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf("Do you really want to create migration %q?", resultFilePath),
|
||||
}
|
||||
survey.AskOne(prompt, &confirm)
|
||||
if !confirm {
|
||||
fmt.Println("The command has been cancelled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensure that migrations dir exist
|
||||
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resultFilePath, []byte(createTemplateContent), 0644); err != nil {
|
||||
return fmt.Errorf("Failed to save migration file %q\n", resultFilePath)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully created file %q\n", resultFilePath)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Unsupported command: %q\n", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Up executes all unapplied migrations for the provided runner.
|
||||
//
|
||||
// On success returns list with the applied migrations file names.
|
||||
func (r *Runner) Up() ([]string, error) {
|
||||
applied := []string{}
|
||||
|
||||
err := r.db.Transactional(func(tx *dbx.Tx) error {
|
||||
for _, m := range r.migrationsList.Items() {
|
||||
// skip applied
|
||||
if r.isMigrationApplied(tx, m.file) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.up(tx); err != nil {
|
||||
return fmt.Errorf("Failed to apply migration %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
if err := r.saveAppliedMigration(tx, m.file); err != nil {
|
||||
return fmt.Errorf("Failed to save applied migration info for %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
applied = append(applied, m.file)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
// Down reverts the last `toRevertCount` applied migrations.
|
||||
//
|
||||
// On success returns list with the reverted migrations file names.
|
||||
func (r *Runner) Down(toRevertCount int) ([]string, error) {
|
||||
applied := []string{}
|
||||
|
||||
err := r.db.Transactional(func(tx *dbx.Tx) error {
|
||||
totalReverted := 0
|
||||
|
||||
for i := len(r.migrationsList.Items()) - 1; i >= 0; i-- {
|
||||
m := r.migrationsList.Item(i)
|
||||
|
||||
// skip unapplied
|
||||
if !r.isMigrationApplied(tx, m.file) {
|
||||
continue
|
||||
}
|
||||
|
||||
// revert limit reached
|
||||
if toRevertCount-totalReverted <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if err := m.down(tx); err != nil {
|
||||
return fmt.Errorf("Failed to revert migration %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
if err := r.saveRevertedMigration(tx, m.file); err != nil {
|
||||
return fmt.Errorf("Failed to save reverted migration info for %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
applied = append(applied, m.file)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func (r *Runner) createMigrationsTable() error {
|
||||
rawQuery := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %v (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
|
||||
r.db.QuoteTableName(r.tableName),
|
||||
)
|
||||
|
||||
_, err := r.db.NewQuery(rawQuery).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) isMigrationApplied(tx dbx.Builder, file string) bool {
|
||||
var exists bool
|
||||
|
||||
err := tx.Select("count(*)").
|
||||
From(r.tableName).
|
||||
Where(dbx.HashExp{"file": file}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && exists
|
||||
}
|
||||
|
||||
func (r *Runner) saveAppliedMigration(tx dbx.Builder, file string) error {
|
||||
_, err := tx.Insert(r.tableName, dbx.Params{
|
||||
"file": file,
|
||||
"applied": time.Now().Unix(),
|
||||
}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) saveRevertedMigration(tx dbx.Builder, file string) error {
|
||||
_, err := tx.Delete(r.tableName, dbx.HashExp{"file": file}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestNewRunner(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
l := MigrationsList{}
|
||||
l.Register(nil, nil, "1_test.go")
|
||||
l.Register(nil, nil, "2_test.go")
|
||||
l.Register(nil, nil, "3_test.go")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(r.migrationsList.Items()) != len(l.Items()) {
|
||||
t.Fatalf("Expected the same migrations list to be assigned, got \n%#v", r.migrationsList)
|
||||
}
|
||||
|
||||
expectedQueries := []string{
|
||||
"CREATE TABLE IF NOT EXISTS `_migrations` (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
|
||||
}
|
||||
if len(expectedQueries) != len(testDB.CalledQueries) {
|
||||
t.Fatalf("Expected %d queries, got %d: \n%v", len(expectedQueries), len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
for _, q := range expectedQueries {
|
||||
if !list.ExistInSlice(q, testDB.CalledQueries) {
|
||||
t.Fatalf("Query %s was not found in \n%v", q, testDB.CalledQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerUpAndDown(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
var test1UpCalled bool
|
||||
var test1DownCalled bool
|
||||
var test2UpCalled bool
|
||||
var test2DownCalled bool
|
||||
|
||||
l := MigrationsList{}
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
test1UpCalled = true
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
test1DownCalled = true
|
||||
return nil
|
||||
}, "1_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
test2UpCalled = true
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
test2DownCalled = true
|
||||
return nil
|
||||
}, "2_test")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// simulate partially run migration
|
||||
r.saveAppliedMigration(testDB, r.migrationsList.Item(0).file)
|
||||
|
||||
// Up()
|
||||
// ---
|
||||
if _, err := r.Up(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test1UpCalled {
|
||||
t.Fatalf("Didn't expect 1_test to be called")
|
||||
}
|
||||
|
||||
if !test2UpCalled {
|
||||
t.Fatalf("Expected 2_test to be called")
|
||||
}
|
||||
|
||||
// simulate unrun migration
|
||||
var test3DownCalled bool
|
||||
r.migrationsList.Register(nil, func(db dbx.Builder) error {
|
||||
test3DownCalled = true
|
||||
return nil
|
||||
}, "3_test")
|
||||
|
||||
// Down()
|
||||
// ---
|
||||
if _, err := r.Down(2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test3DownCalled {
|
||||
t.Fatal("Didn't expect 3_test to be reverted.")
|
||||
}
|
||||
|
||||
if !test1DownCalled || !test2DownCalled {
|
||||
t.Fatalf("Expected 1_test and 2_test to be reverted, got %v and %v", test1DownCalled, test2DownCalled)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testDB struct {
|
||||
*dbx.DB
|
||||
CalledQueries []string
|
||||
}
|
||||
|
||||
// NB! Don't forget to call `db.Close()` at the end of the test.
|
||||
func createTestDB() (*testDB, error) {
|
||||
sqlDB, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package migrate
|
||||
|
||||
const createTemplateContent = `package migrations
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/dbx"
|
||||
m "github.com/pocketbase/pocketbase/migrations"
|
||||
)
|
||||
|
||||
func init() {
|
||||
m.Register(func(db dbx.Builder) error {
|
||||
// add up queries...
|
||||
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
// add down queries...
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
@@ -0,0 +1,107 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
)
|
||||
|
||||
// ApiError defines the properties for a basic api error response.
|
||||
type ApiError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]any `json:"data"`
|
||||
|
||||
// stores unformatted error data (could be an internal error, text, etc.)
|
||||
rawData any
|
||||
}
|
||||
|
||||
// Error makes it compatible with the `error` interface.
|
||||
func (e *ApiError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e *ApiError) RawData() any {
|
||||
return e.rawData
|
||||
}
|
||||
|
||||
// NewNotFoundError creates and returns 404 `ApiError`.
|
||||
func NewNotFoundError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "The requested resource wasn't found."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusNotFound, message, data)
|
||||
}
|
||||
|
||||
// NewBadRequestError creates and returns 400 `ApiError`.
|
||||
func NewBadRequestError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Something went wrong while processing your request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusBadRequest, message, data)
|
||||
}
|
||||
|
||||
// NewForbiddenError creates and returns 403 `ApiError`.
|
||||
func NewForbiddenError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "You are not allowed to perform this request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusForbidden, message, data)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError creates and returns 401 `ApiError`.
|
||||
func NewUnauthorizedError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Missing or invalid authentication token."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusUnauthorized, message, data)
|
||||
}
|
||||
|
||||
// NewApiError creates and returns new normalized `ApiError` instance.
|
||||
func NewApiError(status int, message string, data any) *ApiError {
|
||||
message = inflector.Sentenize(message)
|
||||
|
||||
formattedData := map[string]any{}
|
||||
|
||||
if v, ok := data.(validation.Errors); ok {
|
||||
formattedData = resolveValidationErrors(v)
|
||||
}
|
||||
|
||||
return &ApiError{
|
||||
rawData: data,
|
||||
Data: formattedData,
|
||||
Code: status,
|
||||
Message: strings.TrimSpace(message),
|
||||
}
|
||||
}
|
||||
|
||||
func resolveValidationErrors(validationErrors validation.Errors) map[string]any {
|
||||
result := map[string]any{}
|
||||
|
||||
// extract from each validation error its error code and message.
|
||||
for name, err := range validationErrors {
|
||||
// check for nested errors
|
||||
if nestedErrs, ok := err.(validation.Errors); ok {
|
||||
result[name] = resolveValidationErrors(nestedErrs)
|
||||
continue
|
||||
}
|
||||
|
||||
errCode := "validation_invalid_value" // default
|
||||
if errObj, ok := err.(validation.ErrorObject); ok {
|
||||
errCode = errObj.Code()
|
||||
}
|
||||
|
||||
result[name] = map[string]string{
|
||||
"code": errCode,
|
||||
"message": inflector.Sentenize(err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
)
|
||||
|
||||
func TestNewApiErrorWithRawData(t *testing.T) {
|
||||
e := rest.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
"rawData_test",
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"code":300,"message":"Message_test.","data":{}}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() != "rawData_test" {
|
||||
t.Errorf("Expected rawData %v, got %v", "rawData_test", e.RawData())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewApiErrorWithValidationData(t *testing.T) {
|
||||
e := rest.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
validation.Errors{
|
||||
"err1": errors.New("test error"),
|
||||
"err2": validation.ErrRequired,
|
||||
"err3": validation.Errors{
|
||||
"sub1": errors.New("test error"),
|
||||
"sub2": validation.ErrRequired,
|
||||
"sub3": validation.Errors{
|
||||
"sub11": validation.ErrRequired,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"code":300,"message":"Message_test.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"sub1":{"code":"validation_invalid_value","message":"Test error."},"sub2":{"code":"validation_required","message":"Cannot be blank."},"sub3":{"sub11":{"code":"validation_required","message":"Cannot be blank."}}}}}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() == nil {
|
||||
t.Error("Expected non-nil rawData")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":404,"message":"The requested resource wasn't found.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":404,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":404,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewNotFoundError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadRequestError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":400,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":400,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewBadRequestError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForbiddenError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":403,"message":"You are not allowed to perform this request.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":403,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":403,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewForbiddenError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUnauthorizedError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":401,"message":"Missing or invalid authentication token.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":401,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":401,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewUnauthorizedError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
)
|
||||
|
||||
// BindBody binds request body content to i.
|
||||
//
|
||||
// This is similar to `echo.BindBody()`, but for JSON requests uses
|
||||
// custom json reader that **copies** the request body, allowing multiple reads.
|
||||
func BindBody(c echo.Context, i interface{}) error {
|
||||
req := c.Request()
|
||||
if req.ContentLength == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctype := req.Header.Get(echo.HeaderContentType)
|
||||
switch {
|
||||
case strings.HasPrefix(ctype, echo.MIMEApplicationJSON):
|
||||
err := ReadJsonBodyCopy(c.Request(), i)
|
||||
if err != nil {
|
||||
return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
// fallback to the default binder
|
||||
return echo.BindBody(c, i)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadJsonBodyCopy reads the request body into i by
|
||||
// creating a copy of `r.Body` to allow multiple reads.
|
||||
func ReadJsonBodyCopy(r *http.Request, i interface{}) error {
|
||||
body := r.Body
|
||||
|
||||
// this usually shouldn't be needed because the Server calls close for us
|
||||
// but we are changing the request body with a new reader
|
||||
defer body.Close()
|
||||
|
||||
limitReader := io.LimitReader(body, DefaultMaxMemory)
|
||||
|
||||
bodyBytes, readErr := io.ReadAll(limitReader)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
|
||||
err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(i)
|
||||
|
||||
// set new body reader
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
)
|
||||
|
||||
func TestBindBody(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
body io.Reader
|
||||
contentType string
|
||||
result map[string]string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
strings.NewReader(""),
|
||||
echo.MIMEApplicationJSON,
|
||||
map[string]string{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
strings.NewReader(`{"test":"invalid`),
|
||||
echo.MIMEApplicationJSON,
|
||||
map[string]string{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
strings.NewReader(`{"test":"test123"}`),
|
||||
echo.MIMEApplicationJSON,
|
||||
map[string]string{"test": "test123"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
strings.NewReader(url.Values{"test": []string{"test123"}}.Encode()),
|
||||
echo.MIMEApplicationForm,
|
||||
map[string]string{"test": "test123"},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", scenario.body)
|
||||
req.Header.Set(echo.HeaderContentType, scenario.contentType)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
result := map[string]string{}
|
||||
err := rest.BindBody(c, &result)
|
||||
|
||||
if err == nil && scenario.expectError {
|
||||
t.Errorf("(%d) Expected error, got nil", i)
|
||||
}
|
||||
|
||||
if err != nil && !scenario.expectError {
|
||||
t.Errorf("(%d) Expected nil, got error %v", i, err)
|
||||
}
|
||||
|
||||
if len(result) != len(scenario.result) {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.result, result)
|
||||
}
|
||||
|
||||
for k, v := range result {
|
||||
if sv, ok := scenario.result[k]; !ok || v != sv {
|
||||
t.Errorf("(%d) Expected value %v for key %s, got %v", i, sv, k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadJsonBodyCopy(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"test":"test123"}`))
|
||||
|
||||
// simulate multiple reads from the same request
|
||||
result1 := map[string]string{}
|
||||
rest.ReadJsonBodyCopy(req, &result1)
|
||||
result2 := map[string]string{}
|
||||
rest.ReadJsonBodyCopy(req, &result2)
|
||||
|
||||
if len(result1) == 0 {
|
||||
t.Error("Expected result1 to be filled")
|
||||
}
|
||||
|
||||
if len(result2) == 0 {
|
||||
t.Error("Expected result2 to be filled")
|
||||
}
|
||||
|
||||
if v, ok := result1["test"]; !ok || v != "test123" {
|
||||
t.Errorf("Expected result1.test to be %q, got %q", "test123", v)
|
||||
}
|
||||
|
||||
if v, ok := result2["test"]; !ok || v != "test123" {
|
||||
t.Errorf("Expected result2.test to be %q, got %q", "test123", v)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// DefaultMaxMemory defines the default max memory bytes that
|
||||
// will be used when parsing a form request body.
|
||||
const DefaultMaxMemory = 32 << 20 // 32mb
|
||||
|
||||
// UploadedFile defines a single multipart uploaded file instance.
|
||||
type UploadedFile struct {
|
||||
name string
|
||||
header *multipart.FileHeader
|
||||
bytes []byte
|
||||
}
|
||||
|
||||
// Name returns an assigned unique name to the uploaded file.
|
||||
func (f *UploadedFile) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
// Header returns the file header that comes with the multipart request.
|
||||
func (f *UploadedFile) Header() *multipart.FileHeader {
|
||||
return f.header
|
||||
}
|
||||
|
||||
// Bytes returns a slice with the file content.
|
||||
func (f *UploadedFile) Bytes() []byte {
|
||||
return f.bytes
|
||||
}
|
||||
|
||||
// FindUploadedFiles extracts all form files of `key` from a http request
|
||||
// and returns a slice with `UploadedFile` instances (if any).
|
||||
func FindUploadedFiles(r *http.Request, key string) ([]*UploadedFile, error) {
|
||||
if r.MultipartForm == nil {
|
||||
err := r.ParseMultipartForm(DefaultMaxMemory)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if r.MultipartForm == nil || r.MultipartForm.File == nil || len(r.MultipartForm.File[key]) == 0 {
|
||||
return nil, http.ErrMissingFile
|
||||
}
|
||||
|
||||
result := make([]*UploadedFile, len(r.MultipartForm.File[key]))
|
||||
|
||||
for i, fh := range r.MultipartForm.File[key] {
|
||||
file, err := fh.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
if _, err := io.Copy(buf, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result[i] = &UploadedFile{
|
||||
name: fmt.Sprintf("%s%s", security.RandomString(32), filepath.Ext(fh.Filename)),
|
||||
header: fh,
|
||||
bytes: buf.Bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
)
|
||||
|
||||
func TestFindUploadedFiles(t *testing.T) {
|
||||
// create a test temporary file
|
||||
tmpFile, err := os.CreateTemp(os.TempDir(), "tmpfile-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := tmpFile.Write([]byte("test")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tmpFile.Seek(0, 0)
|
||||
defer tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
// ---
|
||||
|
||||
// stub multipart form file body
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
w, err := mp.CreateFormFile("test", tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.Copy(w, tmpFile); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mp.Close()
|
||||
// ---
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
result, err := rest.FindUploadedFiles(req, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 file, got %d", len(result))
|
||||
}
|
||||
|
||||
if result[0].Header().Size != 4 {
|
||||
t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Header().Size)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(result[0].Name(), ".txt") {
|
||||
t.Fatalf("Expected the file name to have suffix .txt - %v", result[0].Name())
|
||||
}
|
||||
|
||||
if string(result[0].Bytes()) != "test" {
|
||||
t.Fatalf("Expected the file content to be %q, got %q", "test", string(result[0].Bytes()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUploadedFilesMissing(t *testing.T) {
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
mp.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
result, err := rest.FindUploadedFiles(req, "test")
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected result to be nil, got %v", result)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package routine
|
||||
|
||||
import (
|
||||
"log"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FireAndForget executes `f()` in a new go routine and auto recovers if panic.
|
||||
//
|
||||
// **Note:** Use this only if you are not interested in the result of `f()`
|
||||
// and don't want to block the parent go routine.
|
||||
func FireAndForget(f func(), wg ...*sync.WaitGroup) {
|
||||
if len(wg) > 0 && wg[0] != nil {
|
||||
wg[0].Add(1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if len(wg) > 0 && wg[0] != nil {
|
||||
defer wg[0].Done()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("RECOVERED FROM PANIC: %v", err)
|
||||
log.Printf("%s\n", string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
f()
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package routine_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
func TestFireAndForget(t *testing.T) {
|
||||
called := false
|
||||
|
||||
fn := func() {
|
||||
called = true
|
||||
panic("test")
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
routine.FireAndForget(fn, wg)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if !called {
|
||||
t.Error("Expected fn to be called.")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ganigeorgiev/fexpr"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// FilterData is a filter expession string following the `fexpr` package grammar.
|
||||
//
|
||||
// Example:
|
||||
// var filter FilterData = "id = null || (name = 'test' && status = true)"
|
||||
// resolver := search.NewSimpleFieldResolver("id", "name", "status")
|
||||
// expr, err := filter.BuildExpr(resolver)
|
||||
type FilterData string
|
||||
|
||||
// parsedFilterData holds a cache with previously parsed filter data expressions
|
||||
// (initialized with some prealocated empty data map)
|
||||
var parsedFilterData = store.New(make(map[string][]fexpr.ExprGroup, 50))
|
||||
|
||||
// BuildExpr parses the current filter data and returns a new db WHERE expression.
|
||||
func (f FilterData) BuildExpr(fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
raw := string(f)
|
||||
var data []fexpr.ExprGroup
|
||||
|
||||
if parsedFilterData.Has(raw) {
|
||||
data = parsedFilterData.Get(raw)
|
||||
} else {
|
||||
var err error
|
||||
data, err = fexpr.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// store in cache
|
||||
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
|
||||
parsedFilterData.SetIfLessThanLimit(raw, data, 500)
|
||||
}
|
||||
|
||||
return f.build(data, fieldResolver)
|
||||
}
|
||||
|
||||
func (f FilterData) build(data []fexpr.ExprGroup, fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("Empty filter expression.")
|
||||
}
|
||||
|
||||
var result dbx.Expression
|
||||
|
||||
for _, group := range data {
|
||||
var expr dbx.Expression
|
||||
var exprErr error
|
||||
|
||||
switch item := group.Item.(type) {
|
||||
case fexpr.Expr:
|
||||
expr, exprErr = f.resolveTokenizedExpr(item, fieldResolver)
|
||||
case fexpr.ExprGroup:
|
||||
expr, exprErr = f.build([]fexpr.ExprGroup{item}, fieldResolver)
|
||||
case []fexpr.ExprGroup:
|
||||
expr, exprErr = f.build(item, fieldResolver)
|
||||
default:
|
||||
exprErr = errors.New("Unsupported expression item.")
|
||||
}
|
||||
|
||||
if exprErr != nil {
|
||||
return nil, exprErr
|
||||
}
|
||||
|
||||
if group.Join == fexpr.JoinAnd {
|
||||
result = dbx.And(result, expr)
|
||||
} else {
|
||||
result = dbx.Or(result, expr)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldResolver) (dbx.Expression, error) {
|
||||
lName, lParams, lErr := f.resolveToken(expr.Left, fieldResolver)
|
||||
if lName == "" || lErr != nil {
|
||||
return nil, fmt.Errorf("Invalid left operand %q - %v.", expr.Left.Literal, lErr)
|
||||
}
|
||||
|
||||
rName, rParams, rErr := f.resolveToken(expr.Right, fieldResolver)
|
||||
if rName == "" || rErr != nil {
|
||||
return nil, fmt.Errorf("Invalid right operand %q - %v.", expr.Right.Literal, rErr)
|
||||
}
|
||||
|
||||
// merge both operands parameters (if any)
|
||||
params := dbx.Params{}
|
||||
if len(lParams) > 0 {
|
||||
for k, v := range lParams {
|
||||
params[k] = v
|
||||
}
|
||||
}
|
||||
if len(rParams) > 0 {
|
||||
for k, v := range rParams {
|
||||
params[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
switch expr.Op {
|
||||
case fexpr.SignEq:
|
||||
op := "="
|
||||
if strings.ToLower(lName) == "null" || strings.ToLower(rName) == "null" {
|
||||
op = "IS"
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s %s %s", lName, op, rName), params), nil
|
||||
case fexpr.SignNeq:
|
||||
op := "!="
|
||||
if strings.ToLower(lName) == "null" || strings.ToLower(rName) == "null" {
|
||||
op = "IS NOT"
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s %s %s", lName, op, rName), params), nil
|
||||
case fexpr.SignLike:
|
||||
// normalize operands and switch sides if the left operand is a number or text
|
||||
if len(lParams) > 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s", rName, lName), f.normalizeLikeParams(params)), nil
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s", lName, rName), f.normalizeLikeParams(params)), nil
|
||||
case fexpr.SignNlike:
|
||||
// normalize operands and switch sides if the left operand is a number or text
|
||||
if len(lParams) > 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", rName, lName), f.normalizeLikeParams(params)), nil
|
||||
}
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", lName, rName), f.normalizeLikeParams(params)), nil
|
||||
case fexpr.SignLt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s < %s", lName, rName), params), nil
|
||||
case fexpr.SignLte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s <= %s", lName, rName), params), nil
|
||||
case fexpr.SignGt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s > %s", lName, rName), params), nil
|
||||
case fexpr.SignGte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s >= %s", lName, rName), params), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown expression operator %q", expr.Op)
|
||||
}
|
||||
|
||||
func (f FilterData) resolveToken(token fexpr.Token, fieldResolver FieldResolver) (name string, params dbx.Params, err error) {
|
||||
if token.Type == fexpr.TokenIdentifier {
|
||||
name, params, err := fieldResolver.Resolve(token.Literal)
|
||||
|
||||
if name == "" || err != nil {
|
||||
// if `null` field is missing, treat `null` identifier as NULL token
|
||||
if strings.ToLower(token.Literal) == "null" {
|
||||
return "NULL", nil, nil
|
||||
}
|
||||
|
||||
// if `true` field is missing, treat `true` identifier as TRUE token
|
||||
if strings.ToLower(token.Literal) == "true" {
|
||||
return "1", nil, nil
|
||||
}
|
||||
|
||||
// if `false` field is missing, treat `false` identifier as FALSE token
|
||||
if strings.ToLower(token.Literal) == "false" {
|
||||
return "0", nil, nil
|
||||
}
|
||||
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return name, params, err
|
||||
}
|
||||
|
||||
if token.Type == fexpr.TokenNumber || token.Type == fexpr.TokenText {
|
||||
placeholder := "t" + security.RandomString(7)
|
||||
name := fmt.Sprintf("{:%s}", placeholder)
|
||||
params := dbx.Params{placeholder: token.Literal}
|
||||
return name, params, nil
|
||||
}
|
||||
|
||||
return "", nil, errors.New("Unresolvable token type.")
|
||||
}
|
||||
|
||||
func (f FilterData) normalizeLikeParams(params dbx.Params) dbx.Params {
|
||||
result := dbx.Params{}
|
||||
|
||||
if len(params) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
for k, v := range params {
|
||||
vStr := cast.ToString(v)
|
||||
if !strings.Contains(vStr, "%") {
|
||||
vStr = "%" + vStr + "%"
|
||||
}
|
||||
result[k] = vStr
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package search_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestFilterDataBuildExpr(t *testing.T) {
|
||||
resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", "test4.sub")
|
||||
|
||||
scenarios := []struct {
|
||||
filterData search.FilterData
|
||||
expectError bool
|
||||
expectPattern string
|
||||
}{
|
||||
// empty
|
||||
{"", true, ""},
|
||||
// invalid format
|
||||
{"(test1 > 1", true, ""},
|
||||
// invalid operator
|
||||
{"test1 + 123", true, ""},
|
||||
// unknown field
|
||||
{"test1 = 'example' && unknown > 1", true, ""},
|
||||
// simple expression
|
||||
{"test1 > 1", false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("[[test1]] > {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}") +
|
||||
"$",
|
||||
},
|
||||
// complex expression
|
||||
{
|
||||
"((test1 > 1) || (test2 != 2)) && test3 ~ '%%example' && test4.sub = null",
|
||||
false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("((([[test1]] > {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}) OR ([[test2]] != {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test4.sub]] IS NULL)") +
|
||||
"$",
|
||||
},
|
||||
// combination of special literals (null, true, false)
|
||||
{
|
||||
"test1=true && test2 != false && test3 = null || test4.sub != null",
|
||||
false,
|
||||
"^" + regexp.QuoteMeta("((([[test1]] = 1) AND ([[test2]] != 0)) AND ([[test3]] IS NULL)) OR ([[test4.sub]] IS NOT NULL)") + "$",
|
||||
},
|
||||
// all operators
|
||||
{
|
||||
"(test1 = test2 || test2 != test3) && (test2 ~ 'example' || test2 !~ '%%abc') && 'switch1%%' ~ test1 && 'switch2' !~ test2 && test3 > 1 && test3 >= 0 && test3 <= 4 && 2 < 5",
|
||||
false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("(((((((([[test1]] = [[test2]]) OR ([[test2]] != [[test3]])) AND (([[test2]] LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}) OR ([[test2]] NOT LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}))) AND ([[test1]] LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test2]] NOT LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] > {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] >= {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] <= {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ({:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("} < {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})") +
|
||||
"$",
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
expr, err := s.filterData.BuildExpr(resolver)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
dummyDB := &dbx.DB{}
|
||||
rawSql := expr.Build(dummyDB, map[string]any{})
|
||||
|
||||
pattern := regexp.MustCompile(s.expectPattern)
|
||||
if !pattern.MatchString(rawSql) {
|
||||
t.Errorf("(%d) Pattern %v don't match with expression: \n%v", i, s.expectPattern, rawSql)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// DefaultPerPage specifies the default returned search result items.
|
||||
const DefaultPerPage int = 30
|
||||
|
||||
// MaxPerPage specifies the maximum allowed search result items returned in a single page.
|
||||
const MaxPerPage int = 200
|
||||
|
||||
// url search query params
|
||||
const (
|
||||
PageQueryParam string = "page"
|
||||
PerPageQueryParam string = "perPage"
|
||||
SortQueryParam string = "sort"
|
||||
FilterQueryParam string = "filter"
|
||||
)
|
||||
|
||||
// Result defines the returned search result structure.
|
||||
type Result struct {
|
||||
Page int `json:"page"`
|
||||
PerPage int `json:"perPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
Items any `json:"items"`
|
||||
}
|
||||
|
||||
// Provider represents a single configured search provider instance.
|
||||
type Provider struct {
|
||||
fieldResolver FieldResolver
|
||||
query *dbx.SelectQuery
|
||||
page int
|
||||
perPage int
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
}
|
||||
|
||||
// NewProvider creates and returns a new search provider.
|
||||
//
|
||||
// Example:
|
||||
// baseQuery := db.Select("*").From("user")
|
||||
// fieldResolver := search.NewSimpleFieldResolver("id", "name")
|
||||
// models := []*YourDataStruct{}
|
||||
//
|
||||
// result, err := search.NewProvider(fieldResolver).
|
||||
// Query(baseQuery).
|
||||
// ParseAndExec("page=2&filter=id>0&sort=-name", &models)
|
||||
func NewProvider(fieldResolver FieldResolver) *Provider {
|
||||
return &Provider{
|
||||
fieldResolver: fieldResolver,
|
||||
page: 1,
|
||||
perPage: DefaultPerPage,
|
||||
sort: []SortField{},
|
||||
filter: []FilterData{},
|
||||
}
|
||||
}
|
||||
|
||||
// Query sets the base query that will be used to fetch the search items.
|
||||
func (s *Provider) Query(query *dbx.SelectQuery) *Provider {
|
||||
s.query = query
|
||||
return s
|
||||
}
|
||||
|
||||
// Page sets the `page` field of the current search provider.
|
||||
//
|
||||
// Normalization on the `page` value is done during `Exec()`.
|
||||
func (s *Provider) Page(page int) *Provider {
|
||||
s.page = page
|
||||
return s
|
||||
}
|
||||
|
||||
// PerPage sets the `perPage` field of the current search provider.
|
||||
//
|
||||
// Normalization on the `perPage` value is done during `Exec()`.
|
||||
func (s *Provider) PerPage(perPage int) *Provider {
|
||||
s.perPage = perPage
|
||||
return s
|
||||
}
|
||||
|
||||
// Sort sets the `sort` field of the current search provider.
|
||||
func (s *Provider) Sort(sort []SortField) *Provider {
|
||||
s.sort = sort
|
||||
return s
|
||||
}
|
||||
|
||||
// AddSort appends the provided SortField to the existing provider's sort field.
|
||||
func (s *Provider) AddSort(field SortField) *Provider {
|
||||
s.sort = append(s.sort, field)
|
||||
return s
|
||||
}
|
||||
|
||||
// Filter sets the `filter` field of the current search provider.
|
||||
func (s *Provider) Filter(filter []FilterData) *Provider {
|
||||
s.filter = filter
|
||||
return s
|
||||
}
|
||||
|
||||
// AddFilter appends the provided FilterData to the existing provider's filter field.
|
||||
func (s *Provider) AddFilter(filter FilterData) *Provider {
|
||||
if filter != "" {
|
||||
s.filter = append(s.filter, filter)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Parse parses the search query parameter from the provided query string
|
||||
// and assigns the found fields to the current search provider.
|
||||
//
|
||||
// The data from the "sort" and "filter" query parameters are appended
|
||||
// to the existing provider's `sort` and `filter` fields
|
||||
// (aka. using `AddSort` and `AddFilter`).
|
||||
func (s *Provider) Parse(urlQuery string) error {
|
||||
params, err := url.ParseQuery(urlQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawPage := params.Get(PageQueryParam)
|
||||
if rawPage != "" {
|
||||
page, err := strconv.Atoi(rawPage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Page(page)
|
||||
}
|
||||
|
||||
rawPerPage := params.Get(PerPageQueryParam)
|
||||
if rawPerPage != "" {
|
||||
perPage, err := strconv.Atoi(rawPerPage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.PerPage(perPage)
|
||||
}
|
||||
|
||||
rawSort := params.Get(SortQueryParam)
|
||||
if rawSort != "" {
|
||||
for _, sortField := range ParseSortFromString(rawSort) {
|
||||
s.AddSort(sortField)
|
||||
}
|
||||
}
|
||||
|
||||
rawFilter := params.Get(FilterQueryParam)
|
||||
if rawFilter != "" {
|
||||
s.AddFilter(FilterData(rawFilter))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec executes the search provider and fills/scans
|
||||
// the provided `items` slice with the found models.
|
||||
func (s *Provider) Exec(items any) (*Result, error) {
|
||||
if s.query == nil {
|
||||
return nil, errors.New("Query is not set.")
|
||||
}
|
||||
|
||||
// clone provider's query
|
||||
modelsQuery := *s.query
|
||||
|
||||
// apply filters
|
||||
if len(s.filter) > 0 {
|
||||
for _, f := range s.filter {
|
||||
expr, err := f.BuildExpr(s.fieldResolver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expr != nil {
|
||||
modelsQuery.AndWhere(expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply sorting
|
||||
if len(s.sort) > 0 {
|
||||
for _, sortField := range s.sort {
|
||||
expr, err := sortField.BuildExpr(s.fieldResolver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expr != "" {
|
||||
modelsQuery.AndOrderBy(expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply field resolver query modifications (if any)
|
||||
updateQueryErr := s.fieldResolver.UpdateQuery(&modelsQuery)
|
||||
if updateQueryErr != nil {
|
||||
return nil, updateQueryErr
|
||||
}
|
||||
|
||||
// count
|
||||
var totalCount int64
|
||||
countQuery := modelsQuery
|
||||
if err := countQuery.Select("count(*)").Row(&totalCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// normalize perPage
|
||||
if s.perPage <= 0 {
|
||||
s.perPage = DefaultPerPage
|
||||
} else if s.perPage > MaxPerPage {
|
||||
s.perPage = MaxPerPage
|
||||
}
|
||||
|
||||
// normalize page accoring to the total count
|
||||
if s.page <= 0 || totalCount == 0 {
|
||||
s.page = 1
|
||||
} else if totalPages := int(math.Ceil(float64(totalCount) / float64(s.perPage))); s.page > totalPages {
|
||||
s.page = totalPages
|
||||
}
|
||||
|
||||
// apply pagination
|
||||
modelsQuery.Limit(int64(s.perPage))
|
||||
modelsQuery.Offset(int64(s.perPage * (s.page - 1)))
|
||||
|
||||
// fetch models
|
||||
if err := modelsQuery.All(items); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Result{
|
||||
Page: s.page,
|
||||
PerPage: s.perPage,
|
||||
TotalItems: int(totalCount),
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseAndExec is a short conventient method to trigger both
|
||||
// `Parse()` and `Exec()` in a single call.
|
||||
func (s *Provider) ParseAndExec(urlQuery string, modelsSlice any) (*Result, error) {
|
||||
if err := s.Parse(urlQuery); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Exec(modelsSlice)
|
||||
}
|
||||
@@ -0,0 +1,505 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r)
|
||||
|
||||
if p.page != 1 {
|
||||
t.Fatalf("Expected page %d, got %d", 1, p.page)
|
||||
}
|
||||
|
||||
if p.perPage != DefaultPerPage {
|
||||
t.Fatalf("Expected perPage %d, got %d", DefaultPerPage, p.perPage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderQuery(t *testing.T) {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
querySql := query.Build().SQL()
|
||||
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).Query(query)
|
||||
|
||||
expected := p.query.Build().SQL()
|
||||
|
||||
if querySql != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, querySql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderPage(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).Page(10)
|
||||
|
||||
if p.page != 10 {
|
||||
t.Fatalf("Expected page %v, got %v", 10, p.page)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderPerPage(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).PerPage(456)
|
||||
|
||||
if p.perPage != 456 {
|
||||
t.Fatalf("Expected perPage %v, got %v", 456, p.perPage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderSort(t *testing.T) {
|
||||
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Sort(initialSort).
|
||||
AddSort(SortField{"test3", SortDesc})
|
||||
|
||||
encoded, _ := json.Marshal(p.sort)
|
||||
expected := `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"test3","direction":"DESC"}]`
|
||||
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected sort %v, got \n%v", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFilter(t *testing.T) {
|
||||
initialFilter := []FilterData{"test1", "test2"}
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Filter(initialFilter).
|
||||
AddFilter("test3")
|
||||
|
||||
encoded, _ := json.Marshal(p.filter)
|
||||
expected := `["test1","test2","test3"]`
|
||||
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected filter %v, got \n%v", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderParse(t *testing.T) {
|
||||
initialPage := 2
|
||||
initialPerPage := 123
|
||||
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
|
||||
initialFilter := []FilterData{"test1", "test2"}
|
||||
|
||||
scenarios := []struct {
|
||||
query string
|
||||
expectError bool
|
||||
expectPage int
|
||||
expectPerPage int
|
||||
expectSort string
|
||||
expectFilter string
|
||||
}{
|
||||
// empty
|
||||
{
|
||||
"",
|
||||
false,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid query
|
||||
{
|
||||
"invalid;",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid page
|
||||
{
|
||||
"page=a",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// invalid perPage
|
||||
{
|
||||
"perPage=a",
|
||||
true,
|
||||
initialPage,
|
||||
initialPerPage,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"}]`,
|
||||
`["test1","test2"]`,
|
||||
},
|
||||
// valid query parameters
|
||||
{
|
||||
"page=3&perPage=456&filter=test3&sort=-a,b,+c&other=123",
|
||||
false,
|
||||
3,
|
||||
456,
|
||||
`[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"ASC"},{"name":"a","direction":"DESC"},{"name":"b","direction":"ASC"},{"name":"c","direction":"ASC"}]`,
|
||||
`["test1","test2","test3"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Page(initialPage).
|
||||
PerPage(initialPerPage).
|
||||
Sort(initialSort).
|
||||
Filter(initialFilter)
|
||||
|
||||
err := p.Parse(s.query)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if p.page != s.expectPage {
|
||||
t.Errorf("(%d) Expected page %v, got %v", i, s.expectPage, p.page)
|
||||
}
|
||||
|
||||
if p.perPage != s.expectPerPage {
|
||||
t.Errorf("(%d) Expected perPage %v, got %v", i, s.expectPerPage, p.perPage)
|
||||
}
|
||||
|
||||
encodedSort, _ := json.Marshal(p.sort)
|
||||
if string(encodedSort) != s.expectSort {
|
||||
t.Errorf("(%d) Expected sort %v, got \n%v", i, s.expectSort, string(encodedSort))
|
||||
}
|
||||
|
||||
encodedFilter, _ := json.Marshal(p.filter)
|
||||
if string(encodedFilter) != s.expectFilter {
|
||||
t.Errorf("(%d) Expected filter %v, got \n%v", i, s.expectFilter, string(encodedFilter))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderExecEmptyQuery(t *testing.T) {
|
||||
p := NewProvider(&testFieldResolver{}).
|
||||
Query(nil)
|
||||
|
||||
_, err := p.Exec(&[]testTableStruct{})
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error with empty query, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
query := testDB.Select("*").
|
||||
From("test").
|
||||
Where(dbx.Not(dbx.HashExp{"test1": nil})).
|
||||
OrderBy("test1 ASC")
|
||||
|
||||
scenarios := []struct {
|
||||
page int
|
||||
perPage int
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
expectError bool
|
||||
expectResult string
|
||||
expectQueries []string
|
||||
}{
|
||||
// page normalization
|
||||
{
|
||||
-1,
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
// perPage normalization
|
||||
{
|
||||
10, // will be capped by total count
|
||||
0, // fallback to default
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
`{"page":1,"perPage":30,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30",
|
||||
},
|
||||
},
|
||||
// invalid sort field
|
||||
{
|
||||
1,
|
||||
10,
|
||||
[]SortField{{"unknown", SortAsc}},
|
||||
[]FilterData{},
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
// invalid filter
|
||||
{
|
||||
1,
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{"test2 = 'test2.1'", "invalid"},
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
},
|
||||
// valid sort and filter fields
|
||||
{
|
||||
1,
|
||||
5555, // will be limited by MaxPerPage
|
||||
[]SortField{{"test2", SortDesc}},
|
||||
[]FilterData{"test2 != null", "test1 >= 2"},
|
||||
false,
|
||||
`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (test2 IS NOT null)) AND (test1 >= '2') ORDER BY `test1` ASC, `test2` DESC",
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (test2 IS NOT null)) AND (test1 >= '2') ORDER BY `test1` ASC, `test2` DESC LIMIT 200",
|
||||
},
|
||||
},
|
||||
// valid sort and filter fields (zero results)
|
||||
{
|
||||
1,
|
||||
10,
|
||||
[]SortField{{"test3", SortAsc}},
|
||||
[]FilterData{"test3 != ''"},
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":0,"items":[]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (test3 != '') ORDER BY `test1` ASC, `test3` ASC",
|
||||
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (test3 != '') ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
// pagination test
|
||||
{
|
||||
3,
|
||||
1,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
false,
|
||||
`{"page":2,"perPage":1,"totalItems":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT count(*) FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
|
||||
testResolver := &testFieldResolver{}
|
||||
p := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(s.page).
|
||||
PerPage(s.perPage).
|
||||
Sort(s.sort).
|
||||
Filter(s.filter)
|
||||
|
||||
result, err := p.Exec(&[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Errorf("(%d) Expected resolver.Update to be called %d, got %d", i, 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Errorf("(%d) Expected result %v, got \n%v", i, s.expectResult, string(encoded))
|
||||
}
|
||||
|
||||
if len(s.expectQueries) != len(testDB.CalledQueries) {
|
||||
t.Errorf("(%d) Expected %d queries, got %d: \n%v", i, len(s.expectQueries), len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, q := range testDB.CalledQueries {
|
||||
if !list.ExistInSliceWithRegex(q, s.expectQueries) {
|
||||
t.Errorf("(%d) Didn't expect query %v", i, q)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderParseAndExec(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
query := testDB.Select("*").
|
||||
From("test").
|
||||
Where(dbx.Not(dbx.HashExp{"test1": nil})).
|
||||
OrderBy("test1 ASC")
|
||||
|
||||
scenarios := []struct {
|
||||
queryString string
|
||||
expectError bool
|
||||
expectResult string
|
||||
}{
|
||||
// empty
|
||||
{
|
||||
"",
|
||||
false,
|
||||
`{"page":1,"perPage":123,"totalItems":2,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
},
|
||||
// invalid query
|
||||
{
|
||||
"invalid;",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid page
|
||||
{
|
||||
"page=a",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid perPage
|
||||
{
|
||||
"perPage=a",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid sorting field
|
||||
{
|
||||
"sort=-unknown",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// invalid filter field
|
||||
{
|
||||
"filter=unknown>1",
|
||||
true,
|
||||
"",
|
||||
},
|
||||
// valid query params
|
||||
{
|
||||
"page=3&perPage=555&filter=test1>1&sort=-test2,test3",
|
||||
false,
|
||||
`{"page":1,"perPage":200,"totalItems":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
|
||||
testResolver := &testFieldResolver{}
|
||||
provider := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(2).
|
||||
PerPage(123).
|
||||
Sort([]SortField{{"test2", SortAsc}}).
|
||||
Filter([]FilterData{"test1 > 0"})
|
||||
|
||||
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Errorf("(%d) Expected resolver.Update to be called %d, got %d", i, 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
|
||||
if len(testDB.CalledQueries) != 2 {
|
||||
t.Errorf("(%d) Expected %d db queries, got %d: \n%v", i, 2, len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Errorf("(%d) Expected result %v, got \n%v", i, s.expectResult, string(encoded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testTableStruct struct {
|
||||
Test1 int `db:"test1" json:"test1"`
|
||||
Test2 string `db:"test2" json:"test2"`
|
||||
Test3 string `db:"test3" json:"test3"`
|
||||
}
|
||||
|
||||
type testDB struct {
|
||||
*dbx.DB
|
||||
CalledQueries []string
|
||||
}
|
||||
|
||||
// NB! Don't forget to call `db.Close()` at the end of the test.
|
||||
func createTestDB() (*testDB, error) {
|
||||
sqlDB, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
|
||||
db.CreateTable("test", map[string]string{"test1": "int default 0", "test2": "text default ''", "test3": "text default ''"}).Execute()
|
||||
db.Insert("test", dbx.Params{"test1": 1, "test2": "test2.1"}).Execute()
|
||||
db.Insert("test", dbx.Params{"test1": 2, "test2": "test2.2"}).Execute()
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
|
||||
// ---
|
||||
|
||||
type testFieldResolver struct {
|
||||
UpdateQueryCalls int
|
||||
ResolveCalls int
|
||||
}
|
||||
|
||||
func (t *testFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
t.UpdateQueryCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testFieldResolver) Resolve(field string) (name string, placeholderParams dbx.Params, err error) {
|
||||
t.ResolveCalls++
|
||||
|
||||
if field == "unknown" {
|
||||
return "", nil, errors.New("test error")
|
||||
}
|
||||
|
||||
return field, nil, nil
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
// FieldResolver defines an interface for managing search fields.
|
||||
type FieldResolver interface {
|
||||
// UpdateQuery allows to updated the provided db query based on the
|
||||
// resolved search fields (eg. adding joins aliases, etc.).
|
||||
//
|
||||
// Called internally by `search.Provider` before executing the search request.
|
||||
UpdateQuery(query *dbx.SelectQuery) error
|
||||
|
||||
// Resolve parses the provided field and returns a properly
|
||||
// formatted db identifier (eg. NULL, quoted column, placeholder parameter, etc.).
|
||||
Resolve(field string) (name string, placeholderParams dbx.Params, err error)
|
||||
}
|
||||
|
||||
// NewSimpleFieldResolver creates a new `SimpleFieldResolver` with the
|
||||
// provided `allowedFields`.
|
||||
//
|
||||
// Each `allowedFields` could be a plain string (eg. "name")
|
||||
// or a regexp pattern (eg. `^\w+[\w\.]*$`).
|
||||
func NewSimpleFieldResolver(allowedFields ...string) *SimpleFieldResolver {
|
||||
return &SimpleFieldResolver{
|
||||
allowedFields: allowedFields,
|
||||
}
|
||||
}
|
||||
|
||||
// SimpleFieldResolver defines a generic search resolver that allows
|
||||
// only its listed fields to be resolved and take part in a search query.
|
||||
//
|
||||
// If `allowedFields` are empty no fields filtering is applied.
|
||||
type SimpleFieldResolver struct {
|
||||
allowedFields []string
|
||||
}
|
||||
|
||||
// UpdateQuery implements `search.UpdateQuery` interface.
|
||||
func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
// nothing to update...
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve implements `search.Resolve` interface.
|
||||
//
|
||||
// Returns error if `field` is not in `r.allowedFields`.
|
||||
func (r *SimpleFieldResolver) Resolve(field string) (resultName string, placeholderParams dbx.Params, err error) {
|
||||
if !list.ExistInSliceWithRegex(field, r.allowedFields) {
|
||||
return "", nil, fmt.Errorf("Failed to resolve field %q.", field)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[[%s]]", inflector.Columnify(field)), nil, nil
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package search_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestSimpleFieldResolverUpdateQuery(t *testing.T) {
|
||||
r := search.NewSimpleFieldResolver("test")
|
||||
|
||||
scenarios := []struct {
|
||||
fieldName string
|
||||
expectQuery string
|
||||
}{
|
||||
// missing field (the query shouldn't change)
|
||||
{"", `SELECT "id" FROM "test"`},
|
||||
// unknown field (the query shouldn't change)
|
||||
{"unknown", `SELECT "id" FROM "test"`},
|
||||
// allowed field (the query shouldn't change)
|
||||
{"test", `SELECT "id" FROM "test"`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
|
||||
r.Resolve(s.fieldName)
|
||||
|
||||
if err := r.UpdateQuery(nil); err != nil {
|
||||
t.Errorf("(%d) UpdateQuery failed with error %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
rawQuery := query.Build().SQL()
|
||||
// rawQuery := s.expectQuery
|
||||
|
||||
if rawQuery != s.expectQuery {
|
||||
t.Errorf("(%d) Expected query %v, got \n%v", i, s.expectQuery, rawQuery)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleFieldResolverResolve(t *testing.T) {
|
||||
r := search.NewSimpleFieldResolver("test", `^test_regex\d+$`, "Test columnify!")
|
||||
|
||||
scenarios := []struct {
|
||||
fieldName string
|
||||
expectError bool
|
||||
expectName string
|
||||
}{
|
||||
{"", true, ""},
|
||||
{" ", true, ""},
|
||||
{"unknown", true, ""},
|
||||
{"test", false, "[[test]]"},
|
||||
{"test.sub", true, ""},
|
||||
{"test_regex", true, ""},
|
||||
{"test_regex1", false, "[[test_regex1]]"},
|
||||
{"Test columnify!", false, "[[Testcolumnify]]"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
name, params, err := r.Resolve(s.fieldName)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if name != s.expectName {
|
||||
t.Errorf("(%d) Expected name %q, got %q", i, s.expectName, name)
|
||||
}
|
||||
|
||||
// params should be empty
|
||||
if len(params) != 0 {
|
||||
t.Errorf("(%d) Expected 0 params, got %v", i, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// sort field directions
|
||||
const (
|
||||
SortAsc string = "ASC"
|
||||
SortDesc string = "DESC"
|
||||
)
|
||||
|
||||
// SortField defines a single search sort field.
|
||||
type SortField struct {
|
||||
Name string `json:"name"`
|
||||
Direction string `json:"direction"`
|
||||
}
|
||||
|
||||
// BuildExpr resolves the sort field into a valid db sort expression.
|
||||
func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
|
||||
name, params, err := fieldResolver.Resolve(s.Name)
|
||||
|
||||
// invalidate empty fields and non-column identifiers
|
||||
if err != nil || len(params) > 0 || name == "" || strings.ToLower(name) == "null" {
|
||||
return "", fmt.Errorf("Invalid sort field %q.", s.Name)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s", name, s.Direction), nil
|
||||
}
|
||||
|
||||
// ParseSortFromString parses the provided string expression
|
||||
// into a slice of SortFields.
|
||||
//
|
||||
// Example:
|
||||
// fields := search.ParseSortFromString("-name,+created")
|
||||
func ParseSortFromString(str string) []SortField {
|
||||
result := []SortField{}
|
||||
|
||||
data := strings.Split(str, ",")
|
||||
|
||||
for _, field := range data {
|
||||
// trim whitespaces
|
||||
field = strings.TrimSpace(field)
|
||||
|
||||
var dir string
|
||||
if strings.HasPrefix(field, "-") {
|
||||
dir = SortDesc
|
||||
field = strings.TrimPrefix(field, "-")
|
||||
} else {
|
||||
dir = SortAsc
|
||||
field = strings.TrimPrefix(field, "+")
|
||||
}
|
||||
|
||||
result = append(result, SortField{field, dir})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package search_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestSortFieldBuildExpr(t *testing.T) {
|
||||
resolver := search.NewSimpleFieldResolver("test1", "test2", "test3", "test4.sub")
|
||||
|
||||
scenarios := []struct {
|
||||
sortField search.SortField
|
||||
expectError bool
|
||||
expectExpression string
|
||||
}{
|
||||
// empty
|
||||
{search.SortField{"", search.SortDesc}, true, ""},
|
||||
// unknown field
|
||||
{search.SortField{"unknown", search.SortAsc}, true, ""},
|
||||
// placeholder field
|
||||
{search.SortField{"'test'", search.SortAsc}, true, ""},
|
||||
// null field
|
||||
{search.SortField{"null", search.SortAsc}, true, ""},
|
||||
// allowed field - asc
|
||||
{search.SortField{"test1", search.SortAsc}, false, "[[test1]] ASC"},
|
||||
// allowed field - desc
|
||||
{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.sortField.BuildExpr(resolver)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if result != s.expectExpression {
|
||||
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectExpression, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortFromString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value string
|
||||
expectedJson string
|
||||
}{
|
||||
{"", `[{"name":"","direction":"ASC"}]`},
|
||||
{"test", `[{"name":"test","direction":"ASC"}]`},
|
||||
{"+test", `[{"name":"test","direction":"ASC"}]`},
|
||||
{"-test", `[{"name":"test","direction":"DESC"}]`},
|
||||
{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := search.ParseSortFromString(s.value)
|
||||
encoded, _ := json.Marshal(result)
|
||||
|
||||
if string(encoded) != s.expectedJson {
|
||||
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectedJson, string(encoded))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
crand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// S256Challenge creates base64 encoded sha256 challenge string derived from code.
|
||||
// The padding of the result base64 string is stripped per [RFC 7636].
|
||||
//
|
||||
// [RFC 7636]: https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
|
||||
func S256Challenge(code string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(code))
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(h.Sum(nil)), "=")
|
||||
}
|
||||
|
||||
// Encrypt encrypts data with key (must be valid 32 char aes key).
|
||||
func Encrypt(data []byte, key string) (string, error) {
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
|
||||
// populates the nonce with a cryptographically secure random sequence
|
||||
if _, err := io.ReadFull(crand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cipherByte := gcm.Seal(nonce, nonce, data, nil)
|
||||
|
||||
result := base64.StdEncoding.EncodeToString(cipherByte)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts encrypted text with key (must be valid 32 chars aes key).
|
||||
func Decrypt(cipherText string, key string) ([]byte, error) {
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
|
||||
cipherByte, err := base64.StdEncoding.DecodeString(cipherText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce, cipherByteClean := cipherByte[:nonceSize], cipherByte[nonceSize:]
|
||||
plainData, err := gcm.Open(nil, nonce, cipherByteClean, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plainData, nil
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestS256Challenge(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{"", "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU"},
|
||||
{"123", "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := security.S256Challenge(scenario.code)
|
||||
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
data string
|
||||
key string
|
||||
expectError bool
|
||||
}{
|
||||
{"", "", true},
|
||||
{"123", "test", true}, // key must be valid 32 char aes string
|
||||
{"123", "abcdabcdabcdabcdabcdabcdabcdabcd", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Encrypt([]byte(scenario.data), scenario.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
|
||||
if scenario.expectError && result != "" {
|
||||
t.Errorf("(%d) Expected empty string, got %q", i, result)
|
||||
}
|
||||
if !scenario.expectError && result == "" {
|
||||
t.Errorf("(%d) Expected non empty encrypted result string", i)
|
||||
}
|
||||
|
||||
// try to decrypt
|
||||
if result != "" {
|
||||
decrypted, _ := security.Decrypt(result, scenario.key)
|
||||
if string(decrypted) != scenario.data {
|
||||
t.Errorf("(%d) Expected decrypted value to match with the data input, got %q", i, decrypted)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
cipher string
|
||||
key string
|
||||
expectError bool
|
||||
expectedData string
|
||||
}{
|
||||
{"", "", true, ""},
|
||||
{"123", "test", true, ""}, // key must be valid 32 char aes string
|
||||
{"8kcEqilvvYKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", true, ""}, // illegal base64 encoded cipherText
|
||||
{"8kcEqilvv+YKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", false, "123"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Decrypt(scenario.cipher, scenario.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
|
||||
resultStr := string(result)
|
||||
if resultStr != scenario.expectedData {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expectedData, resultStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// ParseUnverifiedJWT parses JWT token and returns its claims
|
||||
// but DOES NOT verify the signature.
|
||||
func ParseUnverifiedJWT(token string) (jwt.MapClaims, error) {
|
||||
claims := jwt.MapClaims{}
|
||||
|
||||
parser := &jwt.Parser{}
|
||||
_, _, err := parser.ParseUnverified(token, claims)
|
||||
|
||||
if err == nil {
|
||||
err = claims.Valid()
|
||||
}
|
||||
|
||||
return claims, err
|
||||
}
|
||||
|
||||
// ParseJWT verifies and parses JWT token and returns its claims.
|
||||
func ParseJWT(token string, verificationKey string) (jwt.MapClaims, error) {
|
||||
parser := &jwt.Parser{
|
||||
ValidMethods: []string{"HS256"},
|
||||
}
|
||||
|
||||
parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) {
|
||||
return []byte(verificationKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok && parsedToken.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("Unable to parse token.")
|
||||
}
|
||||
|
||||
// NewToken generates and returns new HS256 signed JWT token.
|
||||
func NewToken(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
|
||||
seconds := time.Duration(secondsDuration) * time.Second
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"exp": time.Now().Add(seconds).Unix(),
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
for k, v := range payload {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(signingKey))
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestParseUnverifiedJWT(t *testing.T) {
|
||||
// invalid formatted JWT token
|
||||
result1, err1 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9")
|
||||
if err1 == nil {
|
||||
t.Error("Expected error got nil")
|
||||
}
|
||||
if len(result1) > 0 {
|
||||
t.Error("Expected no parsed claims, got", result1)
|
||||
}
|
||||
|
||||
// properly formatted JWT token with INVALID claims
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
result2, err2 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU")
|
||||
if err2 == nil {
|
||||
t.Error("Expected error got nil")
|
||||
}
|
||||
if len(result2) != 2 || result2["name"] != "test" {
|
||||
t.Errorf("Expected to have 2 claims, got %v", result2)
|
||||
}
|
||||
|
||||
// properly formatted JWT token with VALID claims
|
||||
// {"name": "test"}
|
||||
result3, err3 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9.ml0QsTms3K9wMygTu41ZhKlTyjmW9zHQtoS8FUsCCjU")
|
||||
if err3 != nil {
|
||||
t.Error("Expected nil, got", err3)
|
||||
}
|
||||
if len(result3) != 1 || result3["name"] != "test" {
|
||||
t.Errorf("Expected to have 2 claims, got %v", result3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJWT(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
token string
|
||||
secret string
|
||||
expectError bool
|
||||
expectClaims jwt.MapClaims
|
||||
}{
|
||||
// invalid formatted JWT token
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with INVALID claims and INVALID secret
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU",
|
||||
"invalid",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with INVALID claims and VALID secret
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with VALID claims and INVALID secret
|
||||
// {"name": "test", "exp": 1898636137}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTg5ODYzNjEzN30.gqRkHjpK5s1PxxBn9qPaWEWxTbpc1PPSD-an83TsXRY",
|
||||
"invalid",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted EXPIRED JWT token with VALID secret
|
||||
// {"name": "test", "exp": 1652097610}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6OTU3ODczMzc0fQ.0oUUKUnsQHs4nZO1pnxQHahKtcHspHu4_AplN2sGC4A",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with VALID claims and VALID secret
|
||||
// {"name": "test", "exp": 1898636137}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTg5ODYzNjEzN30.gqRkHjpK5s1PxxBn9qPaWEWxTbpc1PPSD-an83TsXRY",
|
||||
"test",
|
||||
false,
|
||||
jwt.MapClaims{"name": "test", "exp": 1898636137.0},
|
||||
},
|
||||
// properly formatted JWT token with VALID claims (without exp) and VALID secret
|
||||
// {"name": "test"}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9.ml0QsTms3K9wMygTu41ZhKlTyjmW9zHQtoS8FUsCCjU",
|
||||
"test",
|
||||
false,
|
||||
jwt.MapClaims{"name": "test"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.ParseJWT(scenario.token, scenario.secret)
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
if len(result) != len(scenario.expectClaims) {
|
||||
t.Errorf("(%d) Expected %v got %v", i, scenario.expectClaims, result)
|
||||
}
|
||||
for k, v := range scenario.expectClaims {
|
||||
v2, ok := result[k]
|
||||
if !ok {
|
||||
t.Errorf("(%d) Missing expected claim %q", i, k)
|
||||
}
|
||||
if v != v2 {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, v, k, v2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToken(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
claims jwt.MapClaims
|
||||
key string
|
||||
duration int64
|
||||
expectError bool
|
||||
}{
|
||||
// empty, zero duration
|
||||
{jwt.MapClaims{}, "", 0, true},
|
||||
// empty, 10 seconds duration
|
||||
{jwt.MapClaims{}, "", 10, false},
|
||||
// non-empty, 10 seconds duration
|
||||
{jwt.MapClaims{"name": "test"}, "test", 10, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
token, tokenErr := security.NewToken(scenario.claims, scenario.key, scenario.duration)
|
||||
if tokenErr != nil {
|
||||
t.Errorf("(%d) Expected NewToken to succeed, got error %v", i, tokenErr)
|
||||
continue
|
||||
}
|
||||
|
||||
claims, parseErr := security.ParseJWT(token, scenario.key)
|
||||
|
||||
hasParseErr := parseErr != nil
|
||||
if hasParseErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasParseErr to be %v, got %v (%v)", i, scenario.expectError, hasParseErr, parseErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
t.Errorf("(%d) Missing required claim exp, got %v", i, claims)
|
||||
}
|
||||
|
||||
// clear exp claim to match with the scenario ones
|
||||
delete(claims, "exp")
|
||||
|
||||
if len(claims) != len(scenario.claims) {
|
||||
t.Errorf("(%d) Expected %v claims, got %v", i, scenario.claims, claims)
|
||||
}
|
||||
|
||||
for j, k := range claims {
|
||||
if claims[j] != scenario.claims[j] {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, claims[j], k, scenario.claims[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
)
|
||||
|
||||
// RandomString generates a random string of specified length.
|
||||
//
|
||||
// The generated string is cryptographically random and matches
|
||||
// [A-Za-z0-9]+ (aka. it's transparent to URL-encoding).
|
||||
func RandomString(length int) string {
|
||||
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
for i, b := range bytes {
|
||||
bytes[i] = alphabet[b%byte(len(alphabet))]
|
||||
}
|
||||
|
||||
return string(bytes)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestRandomString(t *testing.T) {
|
||||
generated := []string{}
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
length := 5 + i
|
||||
result := security.RandomString(length)
|
||||
|
||||
if len(result) != length {
|
||||
t.Errorf("(%d) Expected the length of the string to be %d, got %d", i, length, len(result))
|
||||
}
|
||||
|
||||
if match, _ := regexp.MatchString("[a-zA-Z0-9]+", result); !match {
|
||||
t.Errorf("(%d) The generated strings should have only [a-zA-Z0-9]+ characters, got %q", i, result)
|
||||
}
|
||||
|
||||
for _, str := range generated {
|
||||
if str == result {
|
||||
t.Errorf("(%d) Repeating random string - found %q in \n%v", i, result, generated)
|
||||
}
|
||||
}
|
||||
|
||||
generated = append(generated, result)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package store
|
||||
|
||||
import "sync"
|
||||
|
||||
// Store defines a concurrent safe in memory key-value data store.
|
||||
type Store[T any] struct {
|
||||
mux sync.RWMutex
|
||||
data map[string]T
|
||||
}
|
||||
|
||||
// New creates a new Store[T] instance.
|
||||
func New[T any](data map[string]T) *Store[T] {
|
||||
return &Store[T]{data: data}
|
||||
}
|
||||
|
||||
// Remove removes a single entry from the store.
|
||||
//
|
||||
// Remove does nothing if key doesn't exist in the store.
|
||||
func (s *Store[T]) Remove(key string) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
delete(s.data, key)
|
||||
}
|
||||
|
||||
// Has checks if element with the specified key exist or not.
|
||||
func (s *Store[T]) Has(key string) bool {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
_, ok := s.data[key]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// Get returns a single element value from the store.
|
||||
//
|
||||
// If key is not set, the zero T value is returned.
|
||||
func (s *Store[T]) Get(key string) T {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
return s.data[key]
|
||||
}
|
||||
|
||||
// Set sets (or overwrite if already exist) a new value for key.
|
||||
func (s *Store[T]) Set(key string, value T) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[string]T)
|
||||
}
|
||||
|
||||
s.data[key] = value
|
||||
}
|
||||
|
||||
// SetIfLessThanLimit sets (or overwrite if already exist) a new value for key.
|
||||
//
|
||||
// This is method is similar to Set() but **it will skip adding new elements**
|
||||
// to the store if the store length has reached the specified limit.
|
||||
// `false` is returned if maxAllowedElements limit is reached.
|
||||
func (s *Store[T]) SetIfLessThanLimit(key string, value T, maxAllowedElements int) bool {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
// init map if not already
|
||||
if s.data == nil {
|
||||
s.data = make(map[string]T)
|
||||
}
|
||||
|
||||
// check for existing item
|
||||
_, ok := s.data[key]
|
||||
|
||||
if !ok && len(s.data) >= maxAllowedElements {
|
||||
// cannot add more items
|
||||
return false
|
||||
}
|
||||
|
||||
// add/overwrite item
|
||||
s.data[key] = value
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
s := store.New(map[string]int{"test": 1})
|
||||
|
||||
if s.Get("test") != 1 {
|
||||
t.Error("Expected the initizialized store map to be loaded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
s := store.New(map[string]bool{"test": true})
|
||||
|
||||
keys := []string{"test", "missing"}
|
||||
|
||||
for i, key := range keys {
|
||||
s.Remove(key)
|
||||
if s.Has(key) {
|
||||
t.Errorf("(%d) Expected %q to be removed", i, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHas(t *testing.T) {
|
||||
s := store.New(map[string]int{"test1": 0, "test2": 1})
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
exist bool
|
||||
}{
|
||||
{"test1", true},
|
||||
{"test2", true},
|
||||
{"missing", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
exist := s.Has(scenario.key)
|
||||
if exist != scenario.exist {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.exist, exist)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
s := store.New(map[string]int{"test1": 0, "test2": 1})
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
expect int
|
||||
}{
|
||||
{"test1", 0},
|
||||
{"test2", 1},
|
||||
{"missing", 0}, // should auto fallback to the zero value
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
val := s.Get(scenario.key)
|
||||
if val != scenario.expect {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expect, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
s := store.New[int](nil)
|
||||
|
||||
data := map[string]int{"test1": 0, "test2": 1, "test3": 3}
|
||||
|
||||
// set values
|
||||
for k, v := range data {
|
||||
s.Set(k, v)
|
||||
}
|
||||
|
||||
// verify that the values are set
|
||||
for k, v := range data {
|
||||
if !s.Has(k) {
|
||||
t.Errorf("Expected key %q", k)
|
||||
}
|
||||
|
||||
val := s.Get(k)
|
||||
if val != v {
|
||||
t.Errorf("Expected %v, got %v for key %q", v, val, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIfLessThanLimit(t *testing.T) {
|
||||
s := store.New[int](nil)
|
||||
|
||||
limit := 2
|
||||
|
||||
// set values
|
||||
scenarios := []struct {
|
||||
key string
|
||||
value int
|
||||
expected bool
|
||||
}{
|
||||
{"test1", 1, true},
|
||||
{"test2", 2, true},
|
||||
{"test3", 3, false},
|
||||
{"test2", 4, true}, // overwrite
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := s.SetIfLessThanLimit(scenario.key, scenario.value, limit)
|
||||
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected result %v, got %v", i, scenario.expected, result)
|
||||
}
|
||||
|
||||
if !scenario.expected && s.Has(scenario.key) {
|
||||
t.Errorf("(%d) Expected key %q to not be set", i, scenario.key)
|
||||
}
|
||||
|
||||
val := s.Get(scenario.key)
|
||||
if scenario.expected && val != scenario.value {
|
||||
t.Errorf("(%d) Expected value %v, got %v", i, scenario.value, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Broker defines a struct for managing subscriptions clients.
|
||||
type Broker struct {
|
||||
mux sync.RWMutex
|
||||
clients map[string]Client
|
||||
}
|
||||
|
||||
// NewBroker initializes and returns a new Broker instance.
|
||||
func NewBroker() *Broker {
|
||||
return &Broker{
|
||||
clients: make(map[string]Client),
|
||||
}
|
||||
}
|
||||
|
||||
// Clients returns all registered clients.
|
||||
func (b *Broker) Clients() map[string]Client {
|
||||
return b.clients
|
||||
}
|
||||
|
||||
// ClientById finds a registered client by its id.
|
||||
//
|
||||
// Returns non-nil error when client with clientId is not registered.
|
||||
func (b *Broker) ClientById(clientId string) (Client, error) {
|
||||
client, ok := b.clients[clientId]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("No client associated with connection ID %q", clientId)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Register adds a new client to the broker instance.
|
||||
func (b *Broker) Register(client Client) {
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
|
||||
b.clients[client.Id()] = client
|
||||
}
|
||||
|
||||
// Unregister removes a single client by its id.
|
||||
//
|
||||
// If client with clientId doesn't exist, this method does nothing.
|
||||
func (b *Broker) Unregister(clientId string) {
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
|
||||
// Note:
|
||||
// There is no need to explicitly close the client's channel since it will be GC-ed anyway.
|
||||
// Addinitionally, closing the channel explicitly could panic when there are several
|
||||
// subscriptions attached to the client that needs to receive the same event.
|
||||
delete(b.clients, clientId)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package subscriptions_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
func TestNewBroker(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
if b.Clients() == nil {
|
||||
t.Fatal("Expected clients map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
if total := len(b.Clients()); total != 0 {
|
||||
t.Fatalf("Expected no clients, got %v", total)
|
||||
}
|
||||
|
||||
b.Register(subscriptions.NewDefaultClient())
|
||||
b.Register(subscriptions.NewDefaultClient())
|
||||
|
||||
if total := len(b.Clients()); total != 2 {
|
||||
t.Fatalf("Expected 2 clients, got %v", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientById(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
clientA := subscriptions.NewDefaultClient()
|
||||
clientB := subscriptions.NewDefaultClient()
|
||||
b.Register(clientA)
|
||||
b.Register(clientB)
|
||||
|
||||
resultClient, err := b.ClientById(clientA.Id())
|
||||
if err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", clientA.Id(), err)
|
||||
}
|
||||
if resultClient.Id() != clientA.Id() {
|
||||
t.Fatalf("Expected client %s, got %s", clientA.Id(), resultClient.Id())
|
||||
}
|
||||
|
||||
if c, err := b.ClientById("missing"); err == nil {
|
||||
t.Fatalf("Expected error, found client %v", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
b.Register(client)
|
||||
|
||||
if _, err := b.ClientById(client.Id()); err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", client.Id(), err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnregister(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
clientA := subscriptions.NewDefaultClient()
|
||||
clientB := subscriptions.NewDefaultClient()
|
||||
b.Register(clientA)
|
||||
b.Register(clientB)
|
||||
|
||||
if _, err := b.ClientById(clientA.Id()); err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", clientA.Id(), err)
|
||||
}
|
||||
|
||||
b.Unregister(clientA.Id())
|
||||
|
||||
if c, err := b.ClientById(clientA.Id()); err == nil {
|
||||
t.Fatalf("Expected error, found client %v", c)
|
||||
}
|
||||
|
||||
// clientB shouldn't have been removed
|
||||
if _, err := b.ClientById(clientB.Id()); err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", clientB.Id(), err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// Message defines a client's channel data.
|
||||
type Message struct {
|
||||
Name string
|
||||
Data string
|
||||
}
|
||||
|
||||
// Client is an interface for a generic subscription client.
|
||||
type Client interface {
|
||||
// Id Returns the unique id of the client.
|
||||
Id() string
|
||||
|
||||
// Channel returns the client's communication channel.
|
||||
Channel() chan Message
|
||||
|
||||
// Subscriptions returns all subscriptions to which the client has subscribed to.
|
||||
Subscriptions() map[string]struct{}
|
||||
|
||||
// Subscribe subscribes the client to the provided subscriptions list.
|
||||
Subscribe(subs ...string)
|
||||
|
||||
// Unsubscribe unsubscribes the client from the provided subscriptions list.
|
||||
Unsubscribe(subs ...string)
|
||||
|
||||
// HasSubscription checks if the client is subscribed to `sub`.
|
||||
HasSubscription(sub string) bool
|
||||
|
||||
// Set stores any value to the client's context.
|
||||
Set(key string, value any)
|
||||
|
||||
// Get retrieves the key value from the client's context.
|
||||
Get(key string) any
|
||||
}
|
||||
|
||||
// ensures that DefaultClient satisfies the Client interface
|
||||
var _ Client = (*DefaultClient)(nil)
|
||||
|
||||
// DefaultClient defines a generic subscription client.
|
||||
type DefaultClient struct {
|
||||
mux sync.RWMutex
|
||||
id string
|
||||
store map[string]any
|
||||
channel chan Message
|
||||
subscriptions map[string]struct{}
|
||||
}
|
||||
|
||||
// NewDefaultClient creates and returns a new DefaultClient instance.
|
||||
func NewDefaultClient() *DefaultClient {
|
||||
return &DefaultClient{
|
||||
id: security.RandomString(40),
|
||||
store: map[string]any{},
|
||||
channel: make(chan Message),
|
||||
subscriptions: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Id implements the Client.Id interface method.
|
||||
func (c *DefaultClient) Id() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
// Channel implements the Client.Channel interface method.
|
||||
func (c *DefaultClient) Channel() chan Message {
|
||||
return c.channel
|
||||
}
|
||||
|
||||
// Subscriptions implements the Client.Subscriptions interface method.
|
||||
func (c *DefaultClient) Subscriptions() map[string]struct{} {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
return c.subscriptions
|
||||
}
|
||||
|
||||
// Subscribe implements the Client.Subscribe interface method.
|
||||
//
|
||||
// Empty subscriptions (aka. "") are ignored.
|
||||
func (c *DefaultClient) Subscribe(subs ...string) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
for _, s := range subs {
|
||||
if s == "" {
|
||||
continue // skip empty
|
||||
}
|
||||
|
||||
c.subscriptions[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe implements the Client.Unsubscribe interface method.
|
||||
//
|
||||
// If subs is not set, this method removes all registered client's subscriptions.
|
||||
func (c *DefaultClient) Unsubscribe(subs ...string) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if len(subs) > 0 {
|
||||
for _, s := range subs {
|
||||
delete(c.subscriptions, s)
|
||||
}
|
||||
} else {
|
||||
// unsubsribe all
|
||||
for s := range c.subscriptions {
|
||||
delete(c.subscriptions, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasSubscription implements the Client.HasSubscription interface method.
|
||||
func (c *DefaultClient) HasSubscription(sub string) bool {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
_, ok := c.subscriptions[sub]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// Get implements the Client.Get interface method.
|
||||
func (c *DefaultClient) Get(key string) any {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
return c.store[key]
|
||||
}
|
||||
|
||||
// Set implements the Client.Set interface method.
|
||||
func (c *DefaultClient) Set(key string, value any) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.store[key] = value
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package subscriptions_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
func TestNewDefaultClient(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.Channel() == nil {
|
||||
t.Errorf("Expected channel to be initialized")
|
||||
}
|
||||
|
||||
if c.Subscriptions() == nil {
|
||||
t.Errorf("Expected subscriptions map to be initialized")
|
||||
}
|
||||
|
||||
if c.Id() == "" {
|
||||
t.Errorf("Expected unique id to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestId(t *testing.T) {
|
||||
clients := []*subscriptions.DefaultClient{
|
||||
subscriptions.NewDefaultClient(),
|
||||
subscriptions.NewDefaultClient(),
|
||||
subscriptions.NewDefaultClient(),
|
||||
subscriptions.NewDefaultClient(),
|
||||
}
|
||||
|
||||
ids := map[string]struct{}{}
|
||||
for i, c := range clients {
|
||||
// check uniqueness
|
||||
if _, ok := ids[c.Id()]; ok {
|
||||
t.Errorf("(%d) Expected unique id, got %v", i, c.Id())
|
||||
} else {
|
||||
ids[c.Id()] = struct{}{}
|
||||
}
|
||||
|
||||
// check length
|
||||
if len(c.Id()) != 40 {
|
||||
t.Errorf("(%d) Expected unique id to have 40 chars length, got %v", i, c.Id())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannel(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.Channel() == nil {
|
||||
t.Errorf("Expected channel to be initialized, got")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptions(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected subscriptions to be empty")
|
||||
}
|
||||
|
||||
c.Subscribe("sub1", "sub2", "sub3")
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
subs := []string{"", "sub1", "sub2", "sub3"}
|
||||
expected := []string{"sub1", "sub2", "sub3"}
|
||||
|
||||
c.Subscribe(subs...) // empty string should be skipped
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
}
|
||||
|
||||
for i, s := range expected {
|
||||
if !c.HasSubscription(s) {
|
||||
t.Errorf("(%d) Expected sub %s", i, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
c.Subscribe("sub1", "sub2", "sub3")
|
||||
|
||||
c.Unsubscribe("sub1")
|
||||
|
||||
if c.HasSubscription("sub1") {
|
||||
t.Error("Expected sub1 to be removed")
|
||||
}
|
||||
|
||||
c.Unsubscribe( /* all */ )
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSubscription(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.HasSubscription("missing") {
|
||||
t.Error("Expected false, got true")
|
||||
}
|
||||
|
||||
c.Subscribe("sub")
|
||||
|
||||
if !c.HasSubscription("sub") {
|
||||
t.Error("Expected true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAndGet(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
c.Set("demo", 1)
|
||||
|
||||
result, _ := c.Get("demo").(int)
|
||||
|
||||
if result != 1 {
|
||||
t.Errorf("Expected 1, got %v", result)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// DefaultDateLayout specifies the default app date strings layout.
|
||||
const DefaultDateLayout = "2006-01-02 15:04:05.000"
|
||||
|
||||
// NowDateTime returns new DateTime instance with the current local time.
|
||||
func NowDateTime() DateTime {
|
||||
return DateTime{t: time.Now()}
|
||||
}
|
||||
|
||||
// ParseDateTime creates a new DateTime from the provided value
|
||||
// (could be [cast.ToTime] supported string, [time.Time], etc.).
|
||||
func ParseDateTime(value any) (DateTime, error) {
|
||||
d := DateTime{}
|
||||
err := d.Scan(value)
|
||||
return d, err
|
||||
}
|
||||
|
||||
// DateTime represents a [time.Time] instance in UTC that is wrapped
|
||||
// and serialized using the app default date layout.
|
||||
type DateTime struct {
|
||||
t time.Time
|
||||
}
|
||||
|
||||
// Time returns the internal [time.Time] instance.
|
||||
func (d DateTime) Time() time.Time {
|
||||
return d.t
|
||||
}
|
||||
|
||||
// IsZero checks whether the current DateTime instance has zero time value.
|
||||
func (d DateTime) IsZero() bool {
|
||||
return d.Time().IsZero()
|
||||
}
|
||||
|
||||
// String serializes the current DateTime instance into a formated
|
||||
// UTC date string.
|
||||
//
|
||||
// The zero value is serialized to an empty string.
|
||||
func (d DateTime) String() string {
|
||||
if d.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return d.Time().UTC().Format(DefaultDateLayout)
|
||||
}
|
||||
|
||||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
func (d DateTime) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the [json.Unmarshaler] interface.
|
||||
func (d *DateTime) UnmarshalJSON(b []byte) error {
|
||||
var raw string
|
||||
if err := json.Unmarshal(b, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
return d.Scan(raw)
|
||||
}
|
||||
|
||||
// Value implements the [driver.Valuer] interface.
|
||||
func (d DateTime) Value() (driver.Value, error) {
|
||||
return d.String(), nil
|
||||
}
|
||||
|
||||
// Scan implements [sql.Scanner] interface to scan the provided value
|
||||
// into the current DateTime instance.
|
||||
func (d *DateTime) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case DateTime:
|
||||
d.t = v.Time()
|
||||
case time.Time:
|
||||
d.t = v
|
||||
case int:
|
||||
d.t = cast.ToTime(v)
|
||||
default:
|
||||
str := cast.ToString(v)
|
||||
if str == "" {
|
||||
d.t = time.Time{}
|
||||
} else {
|
||||
d.t = cast.ToTime(str)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNowDateTime(t *testing.T) {
|
||||
now := time.Now().UTC().Format("2006-01-02 15:04:05") // without ms part for test consistency
|
||||
dt := types.NowDateTime()
|
||||
|
||||
if !strings.Contains(dt.String(), now) {
|
||||
t.Fatalf("Expected %q, got %q", now, dt.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDateTime(t *testing.T) {
|
||||
nowTime := time.Now().UTC()
|
||||
nowDateTime, _ := types.ParseDateTime(nowTime)
|
||||
nowStr := nowTime.Format(types.DefaultDateLayout)
|
||||
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expected string
|
||||
}{
|
||||
{nil, ""},
|
||||
{"", ""},
|
||||
{"invalid", ""},
|
||||
{nowDateTime, nowStr},
|
||||
{nowTime, nowStr},
|
||||
{1641024040, "2022-01-01 08:00:40.000"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
dt, err := types.ParseDateTime(s.value)
|
||||
if err != nil {
|
||||
t.Errorf("(%d) Failed to parse %v: %v", i, s.value, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if dt.String() != s.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expected, dt.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeTime(t *testing.T) {
|
||||
str := "2022-01-01 11:23:45.678"
|
||||
|
||||
expected, err := time.Parse(types.DefaultDateLayout, str)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dt, err := types.ParseDateTime(str)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result := dt.Time()
|
||||
|
||||
if !expected.Equal(result) {
|
||||
t.Errorf("Expected time %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeIsZero(t *testing.T) {
|
||||
dt0 := types.DateTime{}
|
||||
if !dt0.IsZero() {
|
||||
t.Fatalf("Expected zero datatime, got %v", dt0)
|
||||
}
|
||||
|
||||
dt1 := types.NowDateTime()
|
||||
if dt1.IsZero() {
|
||||
t.Fatalf("Expected non-zero datatime, got %v", dt1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeString(t *testing.T) {
|
||||
dt0 := types.DateTime{}
|
||||
if dt0.String() != "" {
|
||||
t.Fatalf("Expected empty string for zer datetime, got %q", dt0.String())
|
||||
}
|
||||
|
||||
expected := "2022-01-01 11:23:45.678"
|
||||
dt1, _ := types.ParseDateTime(expected)
|
||||
if dt1.String() != expected {
|
||||
t.Fatalf("Expected %q, got %v", expected, dt1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeMarshalJSON(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
date string
|
||||
expected string
|
||||
}{
|
||||
{"", `""`},
|
||||
{"2022-01-01 11:23:45.678", `"2022-01-01 11:23:45.678"`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
dt, err := types.ParseDateTime(s.date)
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
}
|
||||
|
||||
result, err := dt.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
}
|
||||
|
||||
if string(result) != s.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeUnmarshalJSON(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
date string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"invalid_json", ""},
|
||||
{"'123'", ""},
|
||||
{"2022-01-01 11:23:45.678", ""},
|
||||
{`"2022-01-01 11:23:45.678"`, "2022-01-01 11:23:45.678"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
dt := types.DateTime{}
|
||||
dt.UnmarshalJSON([]byte(s.date))
|
||||
|
||||
if dt.String() != s.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expected, dt.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeValue(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"invalid", ""},
|
||||
{1641024040, "2022-01-01 08:00:40.000"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678"},
|
||||
{types.NowDateTime(), types.NowDateTime().String()},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
dt, _ := types.ParseDateTime(s.value)
|
||||
result, err := dt.Value()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if result != s.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTimeScan(t *testing.T) {
|
||||
now := time.Now().UTC().Format("2006-01-02 15:04:05") // without ms part for test consistency
|
||||
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expected string
|
||||
}{
|
||||
{nil, ""},
|
||||
{"", ""},
|
||||
{"invalid", ""},
|
||||
{types.NowDateTime(), now},
|
||||
{time.Now(), now},
|
||||
{1641024040, "2022-01-01 08:00:40.000"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
dt := types.DateTime{}
|
||||
|
||||
err := dt.Scan(s.value)
|
||||
if err != nil {
|
||||
t.Errorf("(%d) Failed to parse %v: %v", i, s.value, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.Contains(dt.String(), s.expected) {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expected, dt.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// JsonArray defines a slice that is safe for json and db read/write.
|
||||
type JsonArray []any
|
||||
|
||||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
func (m JsonArray) MarshalJSON() ([]byte, error) {
|
||||
type alias JsonArray // prevent recursion
|
||||
|
||||
// inialize an empty map to ensure that `[]` is returned as json
|
||||
if m == nil {
|
||||
m = JsonArray{}
|
||||
}
|
||||
|
||||
return json.Marshal(alias(m))
|
||||
}
|
||||
|
||||
// Value implements the [driver.Valuer] interface.
|
||||
func (m JsonArray) Value() (driver.Value, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
|
||||
return string(data), err
|
||||
}
|
||||
|
||||
// Scan implements [sql.Scanner] interface to scan the provided value
|
||||
// into the current `JsonArray` instance.
|
||||
func (m *JsonArray) Scan(value any) error {
|
||||
var data []byte
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
// no cast needed
|
||||
case []byte:
|
||||
data = v
|
||||
case string:
|
||||
data = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("Failed to unmarshal JsonArray value: %q.", value)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
data = []byte("[]")
|
||||
}
|
||||
|
||||
return json.Unmarshal(data, m)
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestJsonArrayMarshalJSON(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json types.JsonArray
|
||||
expected string
|
||||
}{
|
||||
{nil, "[]"},
|
||||
{types.JsonArray{}, `[]`},
|
||||
{types.JsonArray{1, 2, 3}, `[1,2,3]`},
|
||||
{types.JsonArray{"test1", "test2", "test3"}, `["test1","test2","test3"]`},
|
||||
{types.JsonArray{1, "test"}, `[1,"test"]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.json.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
if string(result) != s.expected {
|
||||
t.Errorf("(%d) Expected %s, got %s", i, s.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonArrayValue(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json types.JsonArray
|
||||
expected driver.Value
|
||||
}{
|
||||
{nil, nil},
|
||||
{types.JsonArray{}, `[]`},
|
||||
{types.JsonArray{1, 2, 3}, `[1,2,3]`},
|
||||
{types.JsonArray{"test1", "test2", "test3"}, `["test1","test2","test3"]`},
|
||||
{types.JsonArray{1, "test"}, `[1,"test"]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.json.Value()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
if result != s.expected {
|
||||
t.Errorf("(%d) Expected %s, got %v", i, s.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonArrayScan(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expectError bool
|
||||
expectJson string
|
||||
}{
|
||||
{``, false, `[]`},
|
||||
{[]byte{}, false, `[]`},
|
||||
{nil, false, `[]`},
|
||||
{123, true, `[]`},
|
||||
{`""`, true, `[]`},
|
||||
{`invalid_json`, true, `[]`},
|
||||
{`"test"`, true, `[]`},
|
||||
{`1,2,3`, true, `[]`},
|
||||
{`[1, 2, 3`, true, `[]`},
|
||||
{`[1, 2, 3]`, false, `[1,2,3]`},
|
||||
{[]byte(`[1, 2, 3]`), false, `[1,2,3]`},
|
||||
{`[1, "test"]`, false, `[1,"test"]`},
|
||||
{`[]`, false, `[]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
arr := types.JsonArray{}
|
||||
scanErr := arr.Scan(s.value)
|
||||
|
||||
hasErr := scanErr != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, scanErr)
|
||||
continue
|
||||
}
|
||||
|
||||
result, _ := arr.MarshalJSON()
|
||||
|
||||
if string(result) != s.expectJson {
|
||||
t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// JsonMap defines a map that is safe for json and db read/write.
|
||||
type JsonMap map[string]any
|
||||
|
||||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
func (m JsonMap) MarshalJSON() ([]byte, error) {
|
||||
type alias JsonMap // prevent recursion
|
||||
|
||||
// inialize an empty map to ensure that `{}` is returned as json
|
||||
if m == nil {
|
||||
m = JsonMap{}
|
||||
}
|
||||
|
||||
return json.Marshal(alias(m))
|
||||
}
|
||||
|
||||
// Value implements the [driver.Valuer] interface.
|
||||
func (m JsonMap) Value() (driver.Value, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
|
||||
return string(data), err
|
||||
}
|
||||
|
||||
// Scan implements [sql.Scanner] interface to scan the provided value
|
||||
// into the current `JsonMap` instance.
|
||||
func (m *JsonMap) Scan(value any) error {
|
||||
var data []byte
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
// no cast needed
|
||||
case []byte:
|
||||
data = v
|
||||
case string:
|
||||
data = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("Failed to unmarshal JsonMap value: %q.", value)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
data = []byte("{}")
|
||||
}
|
||||
|
||||
return json.Unmarshal(data, m)
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestJsonMapMarshalJSON(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json types.JsonMap
|
||||
expected string
|
||||
}{
|
||||
{nil, "{}"},
|
||||
{types.JsonMap{}, `{}`},
|
||||
{types.JsonMap{"test1": 123, "test2": "lorem"}, `{"test1":123,"test2":"lorem"}`},
|
||||
{types.JsonMap{"test": []int{1, 2, 3}}, `{"test":[1,2,3]}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.json.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
if string(result) != s.expected {
|
||||
t.Errorf("(%d) Expected %s, got %s", i, s.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonMapValue(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json types.JsonMap
|
||||
expected driver.Value
|
||||
}{
|
||||
{nil, nil},
|
||||
{types.JsonMap{}, `{}`},
|
||||
{types.JsonMap{"test1": 123, "test2": "lorem"}, `{"test1":123,"test2":"lorem"}`},
|
||||
{types.JsonMap{"test": []int{1, 2, 3}}, `{"test":[1,2,3]}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.json.Value()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
if result != s.expected {
|
||||
t.Errorf("(%d) Expected %s, got %v", i, s.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonArrayMapScan(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expectError bool
|
||||
expectJson string
|
||||
}{
|
||||
{``, false, `{}`},
|
||||
{nil, false, `{}`},
|
||||
{[]byte{}, false, `{}`},
|
||||
{`{}`, false, `{}`},
|
||||
{123, true, `{}`},
|
||||
{`""`, true, `{}`},
|
||||
{`invalid_json`, true, `{}`},
|
||||
{`"test"`, true, `{}`},
|
||||
{`1,2,3`, true, `{}`},
|
||||
{`{"test": 1`, true, `{}`},
|
||||
{`{"test": 1}`, false, `{"test":1}`},
|
||||
{[]byte(`{"test": 1}`), false, `{"test":1}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
arr := types.JsonMap{}
|
||||
scanErr := arr.Scan(s.value)
|
||||
|
||||
hasErr := scanErr != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, scanErr)
|
||||
continue
|
||||
}
|
||||
|
||||
result, _ := arr.MarshalJSON()
|
||||
|
||||
if string(result) != s.expectJson {
|
||||
t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// JsonRaw defines a json value type that is safe for db read/write.
|
||||
type JsonRaw []byte
|
||||
|
||||
// ParseJsonRaw creates a new JsonRaw instance from the provided value
|
||||
// (could be JsonRaw, int, float, string, []byte, etc.).
|
||||
func ParseJsonRaw(value any) (JsonRaw, error) {
|
||||
result := JsonRaw{}
|
||||
err := result.Scan(value)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// String returns the current JsonRaw instance as a json encoded string.
|
||||
func (j JsonRaw) String() string {
|
||||
return string(j)
|
||||
}
|
||||
|
||||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
func (j JsonRaw) MarshalJSON() ([]byte, error) {
|
||||
if len(j) == 0 {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the [json.Unmarshaler] interface.
|
||||
func (j *JsonRaw) UnmarshalJSON(b []byte) error {
|
||||
if j == nil {
|
||||
return errors.New("JsonRaw: UnmarshalJSON on nil pointer")
|
||||
}
|
||||
|
||||
*j = append((*j)[0:0], b...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the [driver.Valuer] interface.
|
||||
func (j JsonRaw) Value() (driver.Value, error) {
|
||||
if len(j) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return j.String(), nil
|
||||
}
|
||||
|
||||
// Scan implements [sql.Scanner] interface to scan the provided value
|
||||
// into the current JsonRaw instance.
|
||||
func (j *JsonRaw) Scan(value interface{}) error {
|
||||
var data []byte
|
||||
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
// no cast is needed
|
||||
case []byte:
|
||||
if len(v) != 0 {
|
||||
data = v
|
||||
}
|
||||
case string:
|
||||
if v != "" {
|
||||
data = []byte(v)
|
||||
}
|
||||
case JsonRaw:
|
||||
if len(v) != 0 {
|
||||
data = []byte(v)
|
||||
}
|
||||
default:
|
||||
bytes, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = bytes
|
||||
}
|
||||
|
||||
return j.UnmarshalJSON(data)
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestParseJsonRaw(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expectError bool
|
||||
expectJson string
|
||||
}{
|
||||
{nil, false, `null`},
|
||||
{``, false, `null`},
|
||||
{[]byte{}, false, `null`},
|
||||
{types.JsonRaw{}, false, `null`},
|
||||
{`{}`, false, `{}`},
|
||||
{`[]`, false, `[]`},
|
||||
{123, false, `123`},
|
||||
{`""`, false, `""`},
|
||||
{`test`, false, `test`},
|
||||
{`{"invalid"`, false, `{"invalid"`}, // treated as a byte casted string
|
||||
{`{"test":1}`, false, `{"test":1}`},
|
||||
{[]byte(`[1,2,3]`), false, `[1,2,3]`},
|
||||
{[]int{1, 2, 3}, false, `[1,2,3]`},
|
||||
{map[string]int{"test": 1}, false, `{"test":1}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
raw, parseErr := types.ParseJsonRaw(s.value)
|
||||
hasErr := parseErr != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, parseErr)
|
||||
continue
|
||||
}
|
||||
|
||||
result, _ := raw.MarshalJSON()
|
||||
|
||||
if string(result) != s.expectJson {
|
||||
t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonRawString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json types.JsonRaw
|
||||
expected string
|
||||
}{
|
||||
{nil, ``},
|
||||
{types.JsonRaw{}, ``},
|
||||
{types.JsonRaw([]byte(`123`)), `123`},
|
||||
{types.JsonRaw(`{"demo":123}`), `{"demo":123}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := s.json.String()
|
||||
if result != s.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonRawMarshalJSON(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json types.JsonRaw
|
||||
expected string
|
||||
}{
|
||||
{nil, `null`},
|
||||
{types.JsonRaw{}, `null`},
|
||||
{types.JsonRaw([]byte(`123`)), `123`},
|
||||
{types.JsonRaw(`{"demo":123}`), `{"demo":123}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.json.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if string(result) != s.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonRawUnmarshalJSON(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json []byte
|
||||
expectString string
|
||||
}{
|
||||
{nil, ""},
|
||||
{[]byte{0, 1, 2}, "\x00\x01\x02"},
|
||||
{[]byte("123"), "123"},
|
||||
{[]byte("test"), "test"},
|
||||
{[]byte(`{"test":123}`), `{"test":123}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
raw := types.JsonRaw{}
|
||||
err := raw.UnmarshalJSON(s.json)
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.String() != s.expectString {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, s.expectString, raw.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonRawValue(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
json types.JsonRaw
|
||||
expected driver.Value
|
||||
}{
|
||||
{nil, nil},
|
||||
{types.JsonRaw{}, nil},
|
||||
{types.JsonRaw(``), nil},
|
||||
{types.JsonRaw(`test`), `test`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.json.Value()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
if result != s.expected {
|
||||
t.Errorf("(%d) Expected %s, got %v", i, s.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonRawScan(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value any
|
||||
expectError bool
|
||||
expectJson string
|
||||
}{
|
||||
{nil, false, `null`},
|
||||
{``, false, `null`},
|
||||
{[]byte{}, false, `null`},
|
||||
{types.JsonRaw{}, false, `null`},
|
||||
{types.JsonRaw(`test`), false, `test`},
|
||||
{`{}`, false, `{}`},
|
||||
{`[]`, false, `[]`},
|
||||
{123, false, `123`},
|
||||
{`""`, false, `""`},
|
||||
{`test`, false, `test`},
|
||||
{`{"invalid"`, false, `{"invalid"`}, // treated as a byte casted string
|
||||
{`{"test":1}`, false, `{"test":1}`},
|
||||
{[]byte(`[1,2,3]`), false, `[1,2,3]`},
|
||||
{[]int{1, 2, 3}, false, `[1,2,3]`},
|
||||
{map[string]int{"test": 1}, false, `{"test":1}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
raw := types.JsonRaw{}
|
||||
scanErr := raw.Scan(s.value)
|
||||
hasErr := scanErr != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, scanErr)
|
||||
continue
|
||||
}
|
||||
|
||||
result, _ := raw.MarshalJSON()
|
||||
|
||||
if string(result) != s.expectJson {
|
||||
t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user