merge v0.23.0-rc changes
This commit is contained in:
@@ -3,6 +3,7 @@ package archive
|
||||
import (
|
||||
"archive/zip"
|
||||
"compress/flate"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
@@ -23,24 +24,21 @@ func Create(src string, dest string, skipPaths ...string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer zf.Close()
|
||||
|
||||
zw := zip.NewWriter(zf)
|
||||
defer zw.Close()
|
||||
|
||||
// register a custom Deflate compressor
|
||||
zw.RegisterCompressor(zip.Deflate, func(out io.Writer) (io.WriteCloser, error) {
|
||||
return flate.NewWriter(out, flate.BestSpeed)
|
||||
})
|
||||
|
||||
if err := zipAddFS(zw, os.DirFS(src), skipPaths...); err != nil {
|
||||
err = zipAddFS(zw, os.DirFS(src), skipPaths...)
|
||||
if err != nil {
|
||||
// try to cleanup at least the created zip file
|
||||
os.Remove(dest)
|
||||
|
||||
return err
|
||||
return errors.Join(err, zw.Close(), zf.Close(), os.Remove(dest))
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.Join(zw.Close(), zf.Close())
|
||||
}
|
||||
|
||||
// note remove after similar method is added in the std lib (https://github.com/golang/go/issues/54898)
|
||||
|
||||
+14
-10
@@ -18,6 +18,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameApple] = wrapFactory(NewAppleProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Apple)(nil)
|
||||
|
||||
// NameApple is the unique name of the Apple provider.
|
||||
@@ -27,23 +31,23 @@ const NameApple string = "apple"
|
||||
//
|
||||
// [OIDC differences]: https://bitbucket.org/openid/connect/src/master/How-Sign-in-with-Apple-differs-from-OpenID-Connect.md
|
||||
type Apple struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
|
||||
jwksUrl string
|
||||
jwksURL string
|
||||
}
|
||||
|
||||
// NewAppleProvider creates a new Apple provider instance with some defaults.
|
||||
func NewAppleProvider() *Apple {
|
||||
return &Apple{
|
||||
baseProvider: &baseProvider{
|
||||
BaseProvider: BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Apple",
|
||||
pkce: true,
|
||||
scopes: []string{"name", "email"},
|
||||
authUrl: "https://appleid.apple.com/auth/authorize",
|
||||
tokenUrl: "https://appleid.apple.com/auth/token",
|
||||
authURL: "https://appleid.apple.com/auth/authorize",
|
||||
tokenURL: "https://appleid.apple.com/auth/token",
|
||||
},
|
||||
jwksUrl: "https://appleid.apple.com/auth/keys",
|
||||
jwksURL: "https://appleid.apple.com/auth/keys",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +55,7 @@ func NewAppleProvider() *Apple {
|
||||
//
|
||||
// API reference: https://developer.apple.com/documentation/sign_in_with_apple/tokenresponse.
|
||||
func (p *Apple) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -98,11 +102,11 @@ func (p *Apple) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// FetchRawUserData implements Provider.FetchRawUserData interface.
|
||||
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface.
|
||||
//
|
||||
// Apple doesn't have a UserInfo endpoint and claims about users
|
||||
// are instead included in the "id_token" (https://openid.net/specs/openid-connect-core-1_0.html#id_tokenExample)
|
||||
func (p *Apple) FetchRawUserData(token *oauth2.Token) ([]byte, error) {
|
||||
func (p *Apple) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
|
||||
idToken, _ := token.Extra("id_token").(string)
|
||||
|
||||
claims, err := p.parseAndVerifyIdToken(idToken)
|
||||
@@ -209,7 +213,7 @@ type jwk struct {
|
||||
}
|
||||
|
||||
func (p *Apple) fetchJWK(kid string) (*jwk, error) {
|
||||
req, err := http.NewRequestWithContext(p.ctx, "GET", p.jwksUrl, nil)
|
||||
req, err := http.NewRequestWithContext(p.ctx, "GET", p.jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+73
-87
@@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
@@ -9,17 +10,22 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// AuthUser defines a standardized oauth2 user data structure.
|
||||
type AuthUser struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatarUrl"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
Expiry types.DateTime `json:"expiry"`
|
||||
RawUser map[string]any `json:"rawUser"`
|
||||
// ProviderFactoryFunc defines a function for initializing a new OAuth2 provider.
|
||||
type ProviderFactoryFunc func() Provider
|
||||
|
||||
// Providers defines a map with all of the available OAuth2 providers.
|
||||
//
|
||||
// To register a new provider append a new entry in the map.
|
||||
var Providers = map[string]ProviderFactoryFunc{}
|
||||
|
||||
// NewProviderByName returns a new preconfigured provider instance by its name identifier.
|
||||
func NewProviderByName(name string) (Provider, error) {
|
||||
factory, ok := Providers[name]
|
||||
if !ok {
|
||||
return nil, errors.New("missing provider " + name)
|
||||
}
|
||||
|
||||
return factory(), nil
|
||||
}
|
||||
|
||||
// Provider defines a common interface for an OAuth2 client.
|
||||
@@ -61,104 +67,84 @@ type Provider interface {
|
||||
// SetClientSecret sets the provider client's app secret.
|
||||
SetClientSecret(secret string)
|
||||
|
||||
// RedirectUrl returns the end address to redirect the user
|
||||
// RedirectURL returns the end address to redirect the user
|
||||
// going through the OAuth flow.
|
||||
RedirectUrl() string
|
||||
RedirectURL() string
|
||||
|
||||
// SetRedirectUrl sets the provider's RedirectUrl.
|
||||
SetRedirectUrl(url string)
|
||||
// SetRedirectURL sets the provider's RedirectURL.
|
||||
SetRedirectURL(url string)
|
||||
|
||||
// AuthUrl returns the provider's authorization service url.
|
||||
AuthUrl() string
|
||||
// AuthURL returns the provider's authorization service url.
|
||||
AuthURL() string
|
||||
|
||||
// SetAuthUrl sets the provider's AuthUrl.
|
||||
SetAuthUrl(url string)
|
||||
// SetAuthURL sets the provider's AuthURL.
|
||||
SetAuthURL(url string)
|
||||
|
||||
// TokenUrl returns the provider's token exchange service url.
|
||||
TokenUrl() string
|
||||
// TokenURL returns the provider's token exchange service url.
|
||||
TokenURL() string
|
||||
|
||||
// SetTokenUrl sets the provider's TokenUrl.
|
||||
SetTokenUrl(url string)
|
||||
// SetTokenURL sets the provider's TokenURL.
|
||||
SetTokenURL(url string)
|
||||
|
||||
// UserApiUrl returns the provider's user info api url.
|
||||
UserApiUrl() string
|
||||
// UserInfoURL returns the provider's user info api url.
|
||||
UserInfoURL() string
|
||||
|
||||
// SetUserApiUrl sets the provider's UserApiUrl.
|
||||
SetUserApiUrl(url string)
|
||||
// SetUserInfoURL sets the provider's UserInfoURL.
|
||||
SetUserInfoURL(url string)
|
||||
|
||||
// Client returns an http client using the provided token.
|
||||
Client(token *oauth2.Token) *http.Client
|
||||
|
||||
// BuildAuthUrl returns a URL to the provider's consent page
|
||||
// BuildAuthURL returns a URL to the provider's consent page
|
||||
// that asks for permissions for the required scopes explicitly.
|
||||
BuildAuthUrl(state string, opts ...oauth2.AuthCodeOption) string
|
||||
BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||
|
||||
// FetchToken converts an authorization code to token.
|
||||
FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||
|
||||
// FetchRawUserData requests and marshalizes into `result` the
|
||||
// FetchRawUserInfo requests and marshalizes into `result` the
|
||||
// the OAuth user api response.
|
||||
FetchRawUserData(token *oauth2.Token) ([]byte, error)
|
||||
FetchRawUserInfo(token *oauth2.Token) ([]byte, error)
|
||||
|
||||
// FetchAuthUser is similar to FetchRawUserData, but normalizes and
|
||||
// FetchAuthUser is similar to FetchRawUserInfo, but normalizes and
|
||||
// marshalizes the user api response into a standardized AuthUser struct.
|
||||
FetchAuthUser(token *oauth2.Token) (user *AuthUser, err error)
|
||||
}
|
||||
|
||||
// NewProviderByName returns a new preconfigured provider instance by its name identifier.
|
||||
func NewProviderByName(name string) (Provider, error) {
|
||||
switch name {
|
||||
case NameGoogle:
|
||||
return NewGoogleProvider(), nil
|
||||
case NameFacebook:
|
||||
return NewFacebookProvider(), nil
|
||||
case NameGithub:
|
||||
return NewGithubProvider(), nil
|
||||
case NameGitlab:
|
||||
return NewGitlabProvider(), nil
|
||||
case NameDiscord:
|
||||
return NewDiscordProvider(), nil
|
||||
case NameTwitter:
|
||||
return NewTwitterProvider(), nil
|
||||
case NameMicrosoft:
|
||||
return NewMicrosoftProvider(), nil
|
||||
case NameSpotify:
|
||||
return NewSpotifyProvider(), nil
|
||||
case NameKakao:
|
||||
return NewKakaoProvider(), nil
|
||||
case NameTwitch:
|
||||
return NewTwitchProvider(), nil
|
||||
case NameStrava:
|
||||
return NewStravaProvider(), nil
|
||||
case NameGitee:
|
||||
return NewGiteeProvider(), nil
|
||||
case NameLivechat:
|
||||
return NewLivechatProvider(), nil
|
||||
case NameGitea:
|
||||
return NewGiteaProvider(), nil
|
||||
case NameOIDC:
|
||||
return NewOIDCProvider(), nil
|
||||
case NameOIDC + "2":
|
||||
return NewOIDCProvider(), nil
|
||||
case NameOIDC + "3":
|
||||
return NewOIDCProvider(), nil
|
||||
case NameApple:
|
||||
return NewAppleProvider(), nil
|
||||
case NameInstagram:
|
||||
return NewInstagramProvider(), nil
|
||||
case NameVK:
|
||||
return NewVKProvider(), nil
|
||||
case NameYandex:
|
||||
return NewYandexProvider(), nil
|
||||
case NamePatreon:
|
||||
return NewPatreonProvider(), nil
|
||||
case NameMailcow:
|
||||
return NewMailcowProvider(), nil
|
||||
case NameBitbucket:
|
||||
return NewBitbucketProvider(), nil
|
||||
case NamePlanningcenter:
|
||||
return NewPlanningcenterProvider(), nil
|
||||
default:
|
||||
return nil, errors.New("Missing provider " + name)
|
||||
// wrapFactory is a helper that wraps a Provider specific factory
|
||||
// function and returns its result as Provider interface.
|
||||
func wrapFactory[T Provider](factory func() T) ProviderFactoryFunc {
|
||||
return func() Provider {
|
||||
return factory()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthUser defines a standardized OAuth2 user data structure.
|
||||
type AuthUser struct {
|
||||
Expiry types.DateTime `json:"expiry"`
|
||||
RawUser map[string]any `json:"rawUser"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL string `json:"avatarURL"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
|
||||
// @todo
|
||||
// deprecated: use AvatarURL instead
|
||||
// AvatarUrl will be removed after dropping v0.22 support
|
||||
AvatarUrl string `json:"avatarUrl"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
//
|
||||
// @todo remove after dropping v0.22 support
|
||||
func (au AuthUser) MarshalJSON() ([]byte, error) {
|
||||
type alias AuthUser // prevent recursion
|
||||
|
||||
au2 := alias(au)
|
||||
au2.AvatarURL = au.AvatarURL // ensure that the legacy field is populated
|
||||
|
||||
return json.Marshal(au2)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,14 @@ import (
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
)
|
||||
|
||||
func TestProvidersCount(t *testing.T) {
|
||||
expected := 25
|
||||
|
||||
if total := len(auth.Providers); total != expected {
|
||||
t.Fatalf("Expected %d providers, got %d", expected, total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProviderByName(t *testing.T) {
|
||||
var err error
|
||||
var p auth.Provider
|
||||
|
||||
+57
-57
@@ -9,147 +9,147 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// baseProvider defines common fields and methods used by OAuth2 client providers.
|
||||
type baseProvider struct {
|
||||
// BaseProvider defines common fields and methods used by OAuth2 client providers.
|
||||
type BaseProvider struct {
|
||||
ctx context.Context
|
||||
clientId string
|
||||
clientSecret string
|
||||
displayName string
|
||||
redirectUrl string
|
||||
authUrl string
|
||||
tokenUrl string
|
||||
userApiUrl string
|
||||
redirectURL string
|
||||
authURL string
|
||||
tokenURL string
|
||||
userInfoURL string
|
||||
scopes []string
|
||||
pkce bool
|
||||
}
|
||||
|
||||
// Context implements Provider.Context() interface method.
|
||||
func (p *baseProvider) Context() context.Context {
|
||||
func (p *BaseProvider) Context() context.Context {
|
||||
return p.ctx
|
||||
}
|
||||
|
||||
// SetContext implements Provider.SetContext() interface method.
|
||||
func (p *baseProvider) SetContext(ctx context.Context) {
|
||||
func (p *BaseProvider) SetContext(ctx context.Context) {
|
||||
p.ctx = ctx
|
||||
}
|
||||
|
||||
// PKCE implements Provider.PKCE() interface method.
|
||||
func (p *baseProvider) PKCE() bool {
|
||||
func (p *BaseProvider) PKCE() bool {
|
||||
return p.pkce
|
||||
}
|
||||
|
||||
// SetPKCE implements Provider.SetPKCE() interface method.
|
||||
func (p *baseProvider) SetPKCE(enable bool) {
|
||||
func (p *BaseProvider) SetPKCE(enable bool) {
|
||||
p.pkce = enable
|
||||
}
|
||||
|
||||
// DisplayName implements Provider.DisplayName() interface method.
|
||||
func (p *baseProvider) DisplayName() string {
|
||||
func (p *BaseProvider) DisplayName() string {
|
||||
return p.displayName
|
||||
}
|
||||
|
||||
// SetDisplayName implements Provider.SetDisplayName() interface method.
|
||||
func (p *baseProvider) SetDisplayName(displayName string) {
|
||||
func (p *BaseProvider) SetDisplayName(displayName string) {
|
||||
p.displayName = displayName
|
||||
}
|
||||
|
||||
// Scopes implements Provider.Scopes() interface method.
|
||||
func (p *baseProvider) Scopes() []string {
|
||||
func (p *BaseProvider) Scopes() []string {
|
||||
return p.scopes
|
||||
}
|
||||
|
||||
// SetScopes implements Provider.SetScopes() interface method.
|
||||
func (p *baseProvider) SetScopes(scopes []string) {
|
||||
func (p *BaseProvider) SetScopes(scopes []string) {
|
||||
p.scopes = scopes
|
||||
}
|
||||
|
||||
// ClientId implements Provider.ClientId() interface method.
|
||||
func (p *baseProvider) ClientId() string {
|
||||
func (p *BaseProvider) ClientId() string {
|
||||
return p.clientId
|
||||
}
|
||||
|
||||
// SetClientId implements Provider.SetClientId() interface method.
|
||||
func (p *baseProvider) SetClientId(clientId string) {
|
||||
func (p *BaseProvider) SetClientId(clientId string) {
|
||||
p.clientId = clientId
|
||||
}
|
||||
|
||||
// ClientSecret implements Provider.ClientSecret() interface method.
|
||||
func (p *baseProvider) ClientSecret() string {
|
||||
func (p *BaseProvider) ClientSecret() string {
|
||||
return p.clientSecret
|
||||
}
|
||||
|
||||
// SetClientSecret implements Provider.SetClientSecret() interface method.
|
||||
func (p *baseProvider) SetClientSecret(secret string) {
|
||||
func (p *BaseProvider) SetClientSecret(secret string) {
|
||||
p.clientSecret = secret
|
||||
}
|
||||
|
||||
// RedirectUrl implements Provider.RedirectUrl() interface method.
|
||||
func (p *baseProvider) RedirectUrl() string {
|
||||
return p.redirectUrl
|
||||
// RedirectURL implements Provider.RedirectURL() interface method.
|
||||
func (p *BaseProvider) RedirectURL() string {
|
||||
return p.redirectURL
|
||||
}
|
||||
|
||||
// SetRedirectUrl implements Provider.SetRedirectUrl() interface method.
|
||||
func (p *baseProvider) SetRedirectUrl(url string) {
|
||||
p.redirectUrl = url
|
||||
// SetRedirectURL implements Provider.SetRedirectURL() interface method.
|
||||
func (p *BaseProvider) SetRedirectURL(url string) {
|
||||
p.redirectURL = url
|
||||
}
|
||||
|
||||
// AuthUrl implements Provider.AuthUrl() interface method.
|
||||
func (p *baseProvider) AuthUrl() string {
|
||||
return p.authUrl
|
||||
// AuthURL implements Provider.AuthURL() interface method.
|
||||
func (p *BaseProvider) AuthURL() string {
|
||||
return p.authURL
|
||||
}
|
||||
|
||||
// SetAuthUrl implements Provider.SetAuthUrl() interface method.
|
||||
func (p *baseProvider) SetAuthUrl(url string) {
|
||||
p.authUrl = url
|
||||
// SetAuthURL implements Provider.SetAuthURL() interface method.
|
||||
func (p *BaseProvider) SetAuthURL(url string) {
|
||||
p.authURL = url
|
||||
}
|
||||
|
||||
// TokenUrl implements Provider.TokenUrl() interface method.
|
||||
func (p *baseProvider) TokenUrl() string {
|
||||
return p.tokenUrl
|
||||
// TokenURL implements Provider.TokenURL() interface method.
|
||||
func (p *BaseProvider) TokenURL() string {
|
||||
return p.tokenURL
|
||||
}
|
||||
|
||||
// SetTokenUrl implements Provider.SetTokenUrl() interface method.
|
||||
func (p *baseProvider) SetTokenUrl(url string) {
|
||||
p.tokenUrl = url
|
||||
// SetTokenURL implements Provider.SetTokenURL() interface method.
|
||||
func (p *BaseProvider) SetTokenURL(url string) {
|
||||
p.tokenURL = url
|
||||
}
|
||||
|
||||
// UserApiUrl implements Provider.UserApiUrl() interface method.
|
||||
func (p *baseProvider) UserApiUrl() string {
|
||||
return p.userApiUrl
|
||||
// UserInfoURL implements Provider.UserInfoURL() interface method.
|
||||
func (p *BaseProvider) UserInfoURL() string {
|
||||
return p.userInfoURL
|
||||
}
|
||||
|
||||
// SetUserApiUrl implements Provider.SetUserApiUrl() interface method.
|
||||
func (p *baseProvider) SetUserApiUrl(url string) {
|
||||
p.userApiUrl = url
|
||||
// SetUserInfoURL implements Provider.SetUserInfoURL() interface method.
|
||||
func (p *BaseProvider) SetUserInfoURL(url string) {
|
||||
p.userInfoURL = url
|
||||
}
|
||||
|
||||
// BuildAuthUrl implements Provider.BuildAuthUrl() interface method.
|
||||
func (p *baseProvider) BuildAuthUrl(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
// BuildAuthURL implements Provider.BuildAuthURL() interface method.
|
||||
func (p *BaseProvider) BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
return p.oauth2Config().AuthCodeURL(state, opts...)
|
||||
}
|
||||
|
||||
// FetchToken implements Provider.FetchToken() interface method.
|
||||
func (p *baseProvider) FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
func (p *BaseProvider) FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return p.oauth2Config().Exchange(p.ctx, code, opts...)
|
||||
}
|
||||
|
||||
// Client implements Provider.Client() interface method.
|
||||
func (p *baseProvider) Client(token *oauth2.Token) *http.Client {
|
||||
func (p *BaseProvider) Client(token *oauth2.Token) *http.Client {
|
||||
return p.oauth2Config().Client(p.ctx, token)
|
||||
}
|
||||
|
||||
// FetchRawUserData implements Provider.FetchRawUserData() interface method.
|
||||
func (p *baseProvider) FetchRawUserData(token *oauth2.Token) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userApiUrl, nil)
|
||||
// FetchRawUserInfo implements Provider.FetchRawUserInfo() interface method.
|
||||
func (p *BaseProvider) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.sendRawUserDataRequest(req, token)
|
||||
return p.sendRawUserInfoRequest(req, token)
|
||||
}
|
||||
|
||||
// sendRawUserDataRequest sends the specified user data request and return its raw response body.
|
||||
func (p *baseProvider) sendRawUserDataRequest(req *http.Request, token *oauth2.Token) ([]byte, error) {
|
||||
// sendRawUserInfoRequest sends the specified user info request and return its raw response body.
|
||||
func (p *BaseProvider) sendRawUserInfoRequest(req *http.Request, token *oauth2.Token) ([]byte, error) {
|
||||
client := p.Client(token)
|
||||
|
||||
res, err := client.Do(req)
|
||||
@@ -167,7 +167,7 @@ func (p *baseProvider) sendRawUserDataRequest(req *http.Request, token *oauth2.T
|
||||
if res.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to fetch OAuth2 user profile via %s (%d):\n%s",
|
||||
p.userApiUrl,
|
||||
p.userInfoURL,
|
||||
res.StatusCode,
|
||||
string(result),
|
||||
)
|
||||
@@ -177,15 +177,15 @@ func (p *baseProvider) sendRawUserDataRequest(req *http.Request, token *oauth2.T
|
||||
}
|
||||
|
||||
// oauth2Config constructs a oauth2.Config instance based on the provider settings.
|
||||
func (p *baseProvider) oauth2Config() *oauth2.Config {
|
||||
func (p *BaseProvider) oauth2Config() *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
RedirectURL: p.redirectUrl,
|
||||
RedirectURL: p.redirectURL,
|
||||
ClientID: p.clientId,
|
||||
ClientSecret: p.clientSecret,
|
||||
Scopes: p.scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: p.authUrl,
|
||||
TokenURL: p.tokenUrl,
|
||||
AuthURL: p.authURL,
|
||||
TokenURL: p.tokenURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.Scopes()
|
||||
if before != nil {
|
||||
@@ -24,7 +24,7 @@ func TestContext(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDisplayName(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.DisplayName()
|
||||
if before != "" {
|
||||
@@ -40,7 +40,7 @@ func TestDisplayName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPKCE(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.PKCE()
|
||||
if before != false {
|
||||
@@ -56,7 +56,7 @@ func TestPKCE(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestScopes(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.Scopes()
|
||||
if len(before) != 0 {
|
||||
@@ -72,7 +72,7 @@ func TestScopes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientId(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.ClientId()
|
||||
if before != "" {
|
||||
@@ -88,7 +88,7 @@ func TestClientId(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientSecret(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.ClientSecret()
|
||||
if before != "" {
|
||||
@@ -103,82 +103,82 @@ func TestClientSecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
func TestRedirectURL(t *testing.T) {
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.RedirectUrl()
|
||||
before := b.RedirectURL()
|
||||
if before != "" {
|
||||
t.Fatalf("Expected RedirectUrl to be empty, got %v", before)
|
||||
t.Fatalf("Expected RedirectURL to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetRedirectUrl("test")
|
||||
b.SetRedirectURL("test")
|
||||
|
||||
after := b.RedirectUrl()
|
||||
after := b.RedirectURL()
|
||||
if after != "test" {
|
||||
t.Fatalf("Expected RedirectUrl to be 'test', got %v", after)
|
||||
t.Fatalf("Expected RedirectURL to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
func TestAuthURL(t *testing.T) {
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.AuthUrl()
|
||||
before := b.AuthURL()
|
||||
if before != "" {
|
||||
t.Fatalf("Expected authUrl to be empty, got %v", before)
|
||||
t.Fatalf("Expected authURL to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetAuthUrl("test")
|
||||
b.SetAuthURL("test")
|
||||
|
||||
after := b.AuthUrl()
|
||||
after := b.AuthURL()
|
||||
if after != "test" {
|
||||
t.Fatalf("Expected authUrl to be 'test', got %v", after)
|
||||
t.Fatalf("Expected authURL to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
func TestTokenURL(t *testing.T) {
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.TokenUrl()
|
||||
before := b.TokenURL()
|
||||
if before != "" {
|
||||
t.Fatalf("Expected tokenUrl to be empty, got %v", before)
|
||||
t.Fatalf("Expected tokenURL to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetTokenUrl("test")
|
||||
b.SetTokenURL("test")
|
||||
|
||||
after := b.TokenUrl()
|
||||
after := b.TokenURL()
|
||||
if after != "test" {
|
||||
t.Fatalf("Expected tokenUrl to be 'test', got %v", after)
|
||||
t.Fatalf("Expected tokenURL to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserApiUrl(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
func TestUserInfoURL(t *testing.T) {
|
||||
b := BaseProvider{}
|
||||
|
||||
before := b.UserApiUrl()
|
||||
before := b.UserInfoURL()
|
||||
if before != "" {
|
||||
t.Fatalf("Expected userApiUrl to be empty, got %v", before)
|
||||
t.Fatalf("Expected userInfoURL to be empty, got %v", before)
|
||||
}
|
||||
|
||||
b.SetUserApiUrl("test")
|
||||
b.SetUserInfoURL("test")
|
||||
|
||||
after := b.UserApiUrl()
|
||||
after := b.UserInfoURL()
|
||||
if after != "test" {
|
||||
t.Fatalf("Expected userApiUrl to be 'test', got %v", after)
|
||||
t.Fatalf("Expected userInfoURL to be 'test', got %v", after)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthUrl(t *testing.T) {
|
||||
b := baseProvider{
|
||||
authUrl: "authUrl_test",
|
||||
tokenUrl: "tokenUrl_test",
|
||||
redirectUrl: "redirectUrl_test",
|
||||
func TestBuildAuthURL(t *testing.T) {
|
||||
b := BaseProvider{
|
||||
authURL: "authURL_test",
|
||||
tokenURL: "tokenURL_test",
|
||||
redirectURL: "redirectURL_test",
|
||||
clientId: "clientId_test",
|
||||
clientSecret: "clientSecret_test",
|
||||
scopes: []string{"test_scope"},
|
||||
}
|
||||
|
||||
expected := "authUrl_test?access_type=offline&client_id=clientId_test&prompt=consent&redirect_uri=redirectUrl_test&response_type=code&scope=test_scope&state=state_test"
|
||||
result := b.BuildAuthUrl("state_test", oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
expected := "authURL_test?access_type=offline&client_id=clientId_test&prompt=consent&redirect_uri=redirectURL_test&response_type=code&scope=test_scope&state=state_test"
|
||||
result := b.BuildAuthURL("state_test", oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected auth url %q, got %q", expected, result)
|
||||
@@ -186,7 +186,7 @@ func TestBuildAuthUrl(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
b := baseProvider{}
|
||||
b := BaseProvider{}
|
||||
|
||||
result := b.Client(&oauth2.Token{})
|
||||
if result == nil {
|
||||
@@ -195,10 +195,10 @@ func TestClient(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOauth2Config(t *testing.T) {
|
||||
b := baseProvider{
|
||||
authUrl: "authUrl_test",
|
||||
tokenUrl: "tokenUrl_test",
|
||||
redirectUrl: "redirectUrl_test",
|
||||
b := BaseProvider{
|
||||
authURL: "authURL_test",
|
||||
tokenURL: "tokenURL_test",
|
||||
redirectURL: "redirectURL_test",
|
||||
clientId: "clientId_test",
|
||||
clientSecret: "clientSecret_test",
|
||||
scopes: []string{"test"},
|
||||
@@ -206,8 +206,8 @@ func TestOauth2Config(t *testing.T) {
|
||||
|
||||
result := b.oauth2Config()
|
||||
|
||||
if result.RedirectURL != b.RedirectUrl() {
|
||||
t.Errorf("Expected redirectUrl %s, got %s", b.RedirectUrl(), result.RedirectURL)
|
||||
if result.RedirectURL != b.RedirectURL() {
|
||||
t.Errorf("Expected redirectURL %s, got %s", b.RedirectURL(), result.RedirectURL)
|
||||
}
|
||||
|
||||
if result.ClientID != b.ClientId() {
|
||||
@@ -218,12 +218,12 @@ func TestOauth2Config(t *testing.T) {
|
||||
t.Errorf("Expected clientSecret %s, got %s", b.ClientSecret(), result.ClientSecret)
|
||||
}
|
||||
|
||||
if result.Endpoint.AuthURL != b.AuthUrl() {
|
||||
t.Errorf("Expected authUrl %s, got %s", b.AuthUrl(), result.Endpoint.AuthURL)
|
||||
if result.Endpoint.AuthURL != b.AuthURL() {
|
||||
t.Errorf("Expected authURL %s, got %s", b.AuthURL(), result.Endpoint.AuthURL)
|
||||
}
|
||||
|
||||
if result.Endpoint.TokenURL != b.TokenUrl() {
|
||||
t.Errorf("Expected authUrl %s, got %s", b.TokenUrl(), result.Endpoint.TokenURL)
|
||||
if result.Endpoint.TokenURL != b.TokenURL() {
|
||||
t.Errorf("Expected authURL %s, got %s", b.TokenURL(), result.Endpoint.TokenURL)
|
||||
}
|
||||
|
||||
if len(result.Scopes) != len(b.Scopes()) || result.Scopes[0] != b.Scopes()[0] {
|
||||
|
||||
+12
-8
@@ -10,6 +10,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameBitbucket] = wrapFactory(NewBitbucketProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Bitbucket)(nil)
|
||||
|
||||
// NameBitbucket is the unique name of the Bitbucket provider.
|
||||
@@ -17,19 +21,19 @@ const NameBitbucket = "bitbucket"
|
||||
|
||||
// Bitbucket is an auth provider for Bitbucket.
|
||||
type Bitbucket struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewBitbucketProvider creates a new Bitbucket provider instance with some defaults.
|
||||
func NewBitbucketProvider() *Bitbucket {
|
||||
return &Bitbucket{&baseProvider{
|
||||
return &Bitbucket{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Bitbucket",
|
||||
pkce: false,
|
||||
scopes: []string{"account"},
|
||||
authUrl: "https://bitbucket.org/site/oauth2/authorize",
|
||||
tokenUrl: "https://bitbucket.org/site/oauth2/access_token",
|
||||
userApiUrl: "https://api.bitbucket.org/2.0/user",
|
||||
authURL: "https://bitbucket.org/site/oauth2/authorize",
|
||||
tokenURL: "https://bitbucket.org/site/oauth2/access_token",
|
||||
userInfoURL: "https://api.bitbucket.org/2.0/user",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -37,7 +41,7 @@ func NewBitbucketProvider() *Bitbucket {
|
||||
//
|
||||
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-get
|
||||
func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -76,7 +80,7 @@ func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name: extracted.DisplayName,
|
||||
Username: extracted.Username,
|
||||
Email: email,
|
||||
AvatarUrl: extracted.Links.Avatar.Href,
|
||||
AvatarURL: extracted.Links.Avatar.Href,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
@@ -95,7 +99,7 @@ func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
//
|
||||
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-emails-get
|
||||
func (p *Bitbucket) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
|
||||
response, err := p.Client(token).Get(p.userApiUrl + "/emails")
|
||||
response, err := p.Client(token).Get(p.userInfoURL + "/emails")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
+12
-8
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameDiscord] = wrapFactory(NewDiscordProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Discord)(nil)
|
||||
|
||||
// NameDiscord is the unique name of the Discord provider.
|
||||
@@ -16,21 +20,21 @@ const NameDiscord string = "discord"
|
||||
|
||||
// Discord allows authentication via Discord OAuth2.
|
||||
type Discord struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewDiscordProvider creates a new Discord provider instance with some defaults.
|
||||
func NewDiscordProvider() *Discord {
|
||||
// https://discord.com/developers/docs/topics/oauth2
|
||||
// https://discord.com/developers/docs/resources/user#get-current-user
|
||||
return &Discord{&baseProvider{
|
||||
return &Discord{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Discord",
|
||||
pkce: true,
|
||||
scopes: []string{"identify", "email"},
|
||||
authUrl: "https://discord.com/api/oauth2/authorize",
|
||||
tokenUrl: "https://discord.com/api/oauth2/token",
|
||||
userApiUrl: "https://discord.com/api/users/@me",
|
||||
authURL: "https://discord.com/api/oauth2/authorize",
|
||||
tokenURL: "https://discord.com/api/oauth2/token",
|
||||
userInfoURL: "https://discord.com/api/users/@me",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func NewDiscordProvider() *Discord {
|
||||
//
|
||||
// API reference: https://discord.com/developers/docs/resources/user#user-object
|
||||
func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +66,7 @@ func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
|
||||
// Build a full avatar URL using the avatar hash provided in the API response
|
||||
// https://discord.com/developers/docs/reference#image-formatting
|
||||
avatarUrl := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", extracted.Id, extracted.Avatar)
|
||||
avatarURL := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", extracted.Id, extracted.Avatar)
|
||||
|
||||
// Concatenate the user's username and discriminator into a single username string
|
||||
username := fmt.Sprintf("%s#%s", extracted.Username, extracted.Discriminator)
|
||||
@@ -71,7 +75,7 @@ func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id: extracted.Id,
|
||||
Name: username,
|
||||
Username: extracted.Username,
|
||||
AvatarUrl: avatarUrl,
|
||||
AvatarURL: avatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+11
-7
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2/facebook"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameFacebook] = wrapFactory(NewFacebookProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Facebook)(nil)
|
||||
|
||||
// NameFacebook is the unique name of the Facebook provider.
|
||||
@@ -16,19 +20,19 @@ const NameFacebook string = "facebook"
|
||||
|
||||
// Facebook allows authentication via Facebook OAuth2.
|
||||
type Facebook struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewFacebookProvider creates new Facebook provider instance with some defaults.
|
||||
func NewFacebookProvider() *Facebook {
|
||||
return &Facebook{&baseProvider{
|
||||
return &Facebook{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Facebook",
|
||||
pkce: true,
|
||||
scopes: []string{"email"},
|
||||
authUrl: facebook.Endpoint.AuthURL,
|
||||
tokenUrl: facebook.Endpoint.TokenURL,
|
||||
userApiUrl: "https://graph.facebook.com/me?fields=name,email,picture.type(large)",
|
||||
authURL: facebook.Endpoint.AuthURL,
|
||||
tokenURL: facebook.Endpoint.TokenURL,
|
||||
userInfoURL: "https://graph.facebook.com/me?fields=name,email,picture.type(large)",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -36,7 +40,7 @@ func NewFacebookProvider() *Facebook {
|
||||
//
|
||||
// API reference: https://developers.facebook.com/docs/graph-api/reference/user/
|
||||
func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +66,7 @@ func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.Picture.Data.Url,
|
||||
AvatarURL: extracted.Picture.Data.Url,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+12
-8
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameGitea] = wrapFactory(NewGiteaProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Gitea)(nil)
|
||||
|
||||
// NameGitea is the unique name of the Gitea provider.
|
||||
@@ -16,19 +20,19 @@ const NameGitea string = "gitea"
|
||||
|
||||
// Gitea allows authentication via Gitea OAuth2.
|
||||
type Gitea struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewGiteaProvider creates new Gitea provider instance with some defaults.
|
||||
func NewGiteaProvider() *Gitea {
|
||||
return &Gitea{&baseProvider{
|
||||
return &Gitea{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Gitea",
|
||||
pkce: true,
|
||||
scopes: []string{"read:user", "user:email"},
|
||||
authUrl: "https://gitea.com/login/oauth/authorize",
|
||||
tokenUrl: "https://gitea.com/login/oauth/access_token",
|
||||
userApiUrl: "https://gitea.com/api/v1/user",
|
||||
authURL: "https://gitea.com/login/oauth/authorize",
|
||||
tokenURL: "https://gitea.com/login/oauth/access_token",
|
||||
userInfoURL: "https://gitea.com/api/v1/user",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -36,7 +40,7 @@ func NewGiteaProvider() *Gitea {
|
||||
//
|
||||
// API reference: https://try.gitea.io/api/swagger#/user/userGetCurrent
|
||||
func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -51,7 +55,7 @@ func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name string `json:"full_name"`
|
||||
Username string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &extracted); err != nil {
|
||||
return nil, err
|
||||
@@ -62,7 +66,7 @@ func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Username,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
AvatarURL: extracted.AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+12
-8
@@ -11,6 +11,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameGitee] = wrapFactory(NewGiteeProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Gitee)(nil)
|
||||
|
||||
// NameGitee is the unique name of the Gitee provider.
|
||||
@@ -18,19 +22,19 @@ const NameGitee string = "gitee"
|
||||
|
||||
// Gitee allows authentication via Gitee OAuth2.
|
||||
type Gitee struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewGiteeProvider creates new Gitee provider instance with some defaults.
|
||||
func NewGiteeProvider() *Gitee {
|
||||
return &Gitee{&baseProvider{
|
||||
return &Gitee{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Gitee",
|
||||
pkce: true,
|
||||
scopes: []string{"user_info", "emails"},
|
||||
authUrl: "https://gitee.com/oauth/authorize",
|
||||
tokenUrl: "https://gitee.com/oauth/token",
|
||||
userApiUrl: "https://gitee.com/api/v5/user",
|
||||
authURL: "https://gitee.com/oauth/authorize",
|
||||
tokenURL: "https://gitee.com/oauth/token",
|
||||
userInfoURL: "https://gitee.com/api/v5/user",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func NewGiteeProvider() *Gitee {
|
||||
//
|
||||
// API reference: https://gitee.com/api/v5/swagger#/getV5User
|
||||
func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,7 +57,7 @@ func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &extracted); err != nil {
|
||||
return nil, err
|
||||
@@ -63,7 +67,7 @@ func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Login,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
AvatarURL: extracted.AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+13
-9
@@ -11,6 +11,10 @@ import (
|
||||
"golang.org/x/oauth2/github"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameGithub] = wrapFactory(NewGithubProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Github)(nil)
|
||||
|
||||
// NameGithub is the unique name of the Github provider.
|
||||
@@ -18,19 +22,19 @@ const NameGithub string = "github"
|
||||
|
||||
// Github allows authentication via Github OAuth2.
|
||||
type Github struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewGithubProvider creates new Github provider instance with some defaults.
|
||||
func NewGithubProvider() *Github {
|
||||
return &Github{&baseProvider{
|
||||
return &Github{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "GitHub",
|
||||
pkce: true, // technically is not supported yet but it is safe as the PKCE params are just ignored
|
||||
scopes: []string{"read:user", "user:email"},
|
||||
authUrl: github.Endpoint.AuthURL,
|
||||
tokenUrl: github.Endpoint.TokenURL,
|
||||
userApiUrl: "https://api.github.com/user",
|
||||
authURL: github.Endpoint.AuthURL,
|
||||
tokenURL: github.Endpoint.TokenURL,
|
||||
userInfoURL: "https://api.github.com/user",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func NewGithubProvider() *Github {
|
||||
//
|
||||
// API reference: https://docs.github.com/en/rest/reference/users#get-the-authenticated-user
|
||||
func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,7 +57,7 @@ func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &extracted); err != nil {
|
||||
return nil, err
|
||||
@@ -64,7 +68,7 @@ func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Login,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
AvatarURL: extracted.AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
@@ -95,7 +99,7 @@ func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
func (p *Github) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
|
||||
client := p.Client(token)
|
||||
|
||||
response, err := client.Get(p.userApiUrl + "/emails")
|
||||
response, err := client.Get(p.userInfoURL + "/emails")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
+12
-8
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameGitlab] = wrapFactory(NewGitlabProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Gitlab)(nil)
|
||||
|
||||
// NameGitlab is the unique name of the Gitlab provider.
|
||||
@@ -16,19 +20,19 @@ const NameGitlab string = "gitlab"
|
||||
|
||||
// Gitlab allows authentication via Gitlab OAuth2.
|
||||
type Gitlab struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewGitlabProvider creates new Gitlab provider instance with some defaults.
|
||||
func NewGitlabProvider() *Gitlab {
|
||||
return &Gitlab{&baseProvider{
|
||||
return &Gitlab{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "GitLab",
|
||||
pkce: true,
|
||||
scopes: []string{"read_user"},
|
||||
authUrl: "https://gitlab.com/oauth/authorize",
|
||||
tokenUrl: "https://gitlab.com/oauth/token",
|
||||
userApiUrl: "https://gitlab.com/api/v4/user",
|
||||
authURL: "https://gitlab.com/oauth/authorize",
|
||||
tokenURL: "https://gitlab.com/oauth/token",
|
||||
userInfoURL: "https://gitlab.com/api/v4/user",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -36,7 +40,7 @@ func NewGitlabProvider() *Gitlab {
|
||||
//
|
||||
// API reference: https://docs.gitlab.com/ee/api/users.html#for-admin
|
||||
func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -51,7 +55,7 @@ func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &extracted); err != nil {
|
||||
return nil, err
|
||||
@@ -62,7 +66,7 @@ func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Username,
|
||||
Email: extracted.Email,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
AvatarURL: extracted.AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+11
-7
@@ -8,6 +8,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameGoogle] = wrapFactory(NewGoogleProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Google)(nil)
|
||||
|
||||
// NameGoogle is the unique name of the Google provider.
|
||||
@@ -15,12 +19,12 @@ const NameGoogle string = "google"
|
||||
|
||||
// Google allows authentication via Google OAuth2.
|
||||
type Google struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewGoogleProvider creates new Google provider instance with some defaults.
|
||||
func NewGoogleProvider() *Google {
|
||||
return &Google{&baseProvider{
|
||||
return &Google{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Google",
|
||||
pkce: true,
|
||||
@@ -28,15 +32,15 @@ func NewGoogleProvider() *Google {
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
},
|
||||
authUrl: "https://accounts.google.com/o/oauth2/auth",
|
||||
tokenUrl: "https://accounts.google.com/o/oauth2/token",
|
||||
userApiUrl: "https://www.googleapis.com/oauth2/v1/userinfo",
|
||||
authURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
tokenURL: "https://accounts.google.com/o/oauth2/token",
|
||||
userInfoURL: "https://www.googleapis.com/oauth2/v1/userinfo",
|
||||
}}
|
||||
}
|
||||
|
||||
// FetchAuthUser returns an AuthUser instance based the Google's user api.
|
||||
func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -60,7 +64,7 @@ func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
user := &AuthUser{
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
AvatarUrl: extracted.Picture,
|
||||
AvatarURL: extracted.Picture,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+10
-6
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2/instagram"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameInstagram] = wrapFactory(NewInstagramProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Instagram)(nil)
|
||||
|
||||
// NameInstagram is the unique name of the Instagram provider.
|
||||
@@ -16,19 +20,19 @@ const NameInstagram string = "instagram"
|
||||
|
||||
// Instagram allows authentication via Instagram OAuth2.
|
||||
type Instagram struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewInstagramProvider creates new Instagram provider instance with some defaults.
|
||||
func NewInstagramProvider() *Instagram {
|
||||
return &Instagram{&baseProvider{
|
||||
return &Instagram{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Instagram",
|
||||
pkce: true,
|
||||
scopes: []string{"user_profile"},
|
||||
authUrl: instagram.Endpoint.AuthURL,
|
||||
tokenUrl: instagram.Endpoint.TokenURL,
|
||||
userApiUrl: "https://graph.instagram.com/me?fields=id,username,account_type",
|
||||
authURL: instagram.Endpoint.AuthURL,
|
||||
tokenURL: instagram.Endpoint.TokenURL,
|
||||
userInfoURL: "https://graph.instagram.com/me?fields=id,username,account_type",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -36,7 +40,7 @@ func NewInstagramProvider() *Instagram {
|
||||
//
|
||||
// API reference: https://developers.facebook.com/docs/instagram-basic-display-api/reference/user#fields
|
||||
func (p *Instagram) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+12
-8
@@ -10,6 +10,10 @@ import (
|
||||
"golang.org/x/oauth2/kakao"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameKakao] = wrapFactory(NewKakaoProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Kakao)(nil)
|
||||
|
||||
// NameKakao is the unique name of the Kakao provider.
|
||||
@@ -17,19 +21,19 @@ const NameKakao string = "kakao"
|
||||
|
||||
// Kakao allows authentication via Kakao OAuth2.
|
||||
type Kakao struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewKakaoProvider creates a new Kakao provider instance with some defaults.
|
||||
func NewKakaoProvider() *Kakao {
|
||||
return &Kakao{&baseProvider{
|
||||
return &Kakao{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Kakao",
|
||||
pkce: true,
|
||||
scopes: []string{"account_email", "profile_nickname", "profile_image"},
|
||||
authUrl: kakao.Endpoint.AuthURL,
|
||||
tokenUrl: kakao.Endpoint.TokenURL,
|
||||
userApiUrl: "https://kapi.kakao.com/v2/user/me",
|
||||
authURL: kakao.Endpoint.AuthURL,
|
||||
tokenURL: kakao.Endpoint.TokenURL,
|
||||
userInfoURL: "https://kapi.kakao.com/v2/user/me",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -37,7 +41,7 @@ func NewKakaoProvider() *Kakao {
|
||||
//
|
||||
// API reference: https://developers.kakao.com/docs/latest/en/kakaologin/rest-api#req-user-info-response
|
||||
func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -51,7 +55,7 @@ func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id int `json:"id"`
|
||||
Profile struct {
|
||||
Nickname string `json:"nickname"`
|
||||
ImageUrl string `json:"profile_image"`
|
||||
ImageURL string `json:"profile_image"`
|
||||
} `json:"properties"`
|
||||
KakaoAccount struct {
|
||||
Email string `json:"email"`
|
||||
@@ -66,7 +70,7 @@ func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
user := &AuthUser{
|
||||
Id: strconv.Itoa(extracted.Id),
|
||||
Username: extracted.Profile.Nickname,
|
||||
AvatarUrl: extracted.Profile.ImageUrl,
|
||||
AvatarURL: extracted.Profile.ImageURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+12
-8
@@ -8,6 +8,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameLivechat] = wrapFactory(NewLivechatProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Livechat)(nil)
|
||||
|
||||
// NameLivechat is the unique name of the Livechat provider.
|
||||
@@ -15,19 +19,19 @@ const NameLivechat = "livechat"
|
||||
|
||||
// Livechat allows authentication via Livechat OAuth2.
|
||||
type Livechat struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewLivechatProvider creates new Livechat provider instance with some defaults.
|
||||
func NewLivechatProvider() *Livechat {
|
||||
return &Livechat{&baseProvider{
|
||||
return &Livechat{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "LiveChat",
|
||||
pkce: true,
|
||||
scopes: []string{}, // default scopes are specified from the provider dashboard
|
||||
authUrl: "https://accounts.livechat.com/",
|
||||
tokenUrl: "https://accounts.livechat.com/token",
|
||||
userApiUrl: "https://accounts.livechat.com/v2/accounts/me",
|
||||
authURL: "https://accounts.livechat.com/",
|
||||
tokenURL: "https://accounts.livechat.com/token",
|
||||
userInfoURL: "https://accounts.livechat.com/v2/accounts/me",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -35,7 +39,7 @@ func NewLivechatProvider() *Livechat {
|
||||
//
|
||||
// API reference: https://developers.livechat.com/docs/authorization
|
||||
func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,7 +54,7 @@ func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &extracted); err != nil {
|
||||
return nil, err
|
||||
@@ -59,7 +63,7 @@ func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
user := &AuthUser{
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
AvatarUrl: extracted.AvatarUrl,
|
||||
AvatarURL: extracted.AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
@@ -10,6 +10,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameMailcow] = wrapFactory(NewMailcowProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Mailcow)(nil)
|
||||
|
||||
// NameMailcow is the unique name of the mailcow provider.
|
||||
@@ -17,12 +21,12 @@ const NameMailcow string = "mailcow"
|
||||
|
||||
// Mailcow allows authentication via mailcow OAuth2.
|
||||
type Mailcow struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewMailcowProvider creates a new mailcow provider instance with some defaults.
|
||||
func NewMailcowProvider() *Mailcow {
|
||||
return &Mailcow{&baseProvider{
|
||||
return &Mailcow{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "mailcow",
|
||||
pkce: true,
|
||||
@@ -34,7 +38,7 @@ func NewMailcowProvider() *Mailcow {
|
||||
//
|
||||
// API reference: https://github.com/mailcow/mailcow-dockerized/blob/master/data/web/oauth/profile.php
|
||||
func (p *Mailcow) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+10
-6
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2/microsoft"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameMicrosoft] = wrapFactory(NewMicrosoftProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Microsoft)(nil)
|
||||
|
||||
// NameMicrosoft is the unique name of the Microsoft provider.
|
||||
@@ -16,20 +20,20 @@ const NameMicrosoft string = "microsoft"
|
||||
|
||||
// Microsoft allows authentication via AzureADEndpoint OAuth2.
|
||||
type Microsoft struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewMicrosoftProvider creates new Microsoft AD provider instance with some defaults.
|
||||
func NewMicrosoftProvider() *Microsoft {
|
||||
endpoints := microsoft.AzureADEndpoint("")
|
||||
return &Microsoft{&baseProvider{
|
||||
return &Microsoft{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Microsoft",
|
||||
pkce: true,
|
||||
scopes: []string{"User.Read"},
|
||||
authUrl: endpoints.AuthURL,
|
||||
tokenUrl: endpoints.TokenURL,
|
||||
userApiUrl: "https://graph.microsoft.com/v1.0/me",
|
||||
authURL: endpoints.AuthURL,
|
||||
tokenURL: endpoints.TokenURL,
|
||||
userInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func NewMicrosoftProvider() *Microsoft {
|
||||
// API reference: https://learn.microsoft.com/en-us/azure/active-directory/develop/userinfo
|
||||
// Graph explorer: https://developer.microsoft.com/en-us/graph/graph-explorer
|
||||
func (p *Microsoft) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+12
-4
@@ -8,6 +8,12 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameOIDC] = wrapFactory(NewOIDCProvider)
|
||||
Providers[NameOIDC+"2"] = wrapFactory(NewOIDCProvider)
|
||||
Providers[NameOIDC+"3"] = wrapFactory(NewOIDCProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*OIDC)(nil)
|
||||
|
||||
// NameOIDC is the unique name of the OpenID Connect (OIDC) provider.
|
||||
@@ -15,12 +21,12 @@ const NameOIDC string = "oidc"
|
||||
|
||||
// OIDC allows authentication via OpenID Connect (OIDC) OAuth2 provider.
|
||||
type OIDC struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewOIDCProvider creates new OpenID Connect (OIDC) provider instance with some defaults.
|
||||
func NewOIDCProvider() *OIDC {
|
||||
return &OIDC{&baseProvider{
|
||||
return &OIDC{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "OIDC",
|
||||
pkce: true,
|
||||
@@ -35,8 +41,10 @@ func NewOIDCProvider() *OIDC {
|
||||
// FetchAuthUser returns an AuthUser instance based the provider's user api.
|
||||
//
|
||||
// API reference: https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
|
||||
//
|
||||
// @todo consider adding support for reading the user data from the id_token.
|
||||
func (p *OIDC) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +70,7 @@ func (p *OIDC) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id: extracted.Id,
|
||||
Name: extracted.Name,
|
||||
Username: extracted.Username,
|
||||
AvatarUrl: extracted.Picture,
|
||||
AvatarURL: extracted.Picture,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+12
-8
@@ -8,6 +8,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NamePatreon] = wrapFactory(NewPatreonProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Patreon)(nil)
|
||||
|
||||
// NamePatreon is the unique name of the Patreon provider.
|
||||
@@ -15,19 +19,19 @@ const NamePatreon string = "patreon"
|
||||
|
||||
// Patreon allows authentication via Patreon OAuth2.
|
||||
type Patreon struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewPatreonProvider creates new Patreon provider instance with some defaults.
|
||||
func NewPatreonProvider() *Patreon {
|
||||
return &Patreon{&baseProvider{
|
||||
return &Patreon{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Patreon",
|
||||
pkce: true,
|
||||
scopes: []string{"identity", "identity[email]"},
|
||||
authUrl: "https://www.patreon.com/oauth2/authorize",
|
||||
tokenUrl: "https://www.patreon.com/api/oauth2/token",
|
||||
userApiUrl: "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=full_name,email,vanity,image_url,is_email_verified",
|
||||
authURL: "https://www.patreon.com/oauth2/authorize",
|
||||
tokenURL: "https://www.patreon.com/api/oauth2/token",
|
||||
userInfoURL: "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=full_name,email,vanity,image_url,is_email_verified",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -37,7 +41,7 @@ func NewPatreonProvider() *Patreon {
|
||||
// https://docs.patreon.com/#get-api-oauth2-v2-identity
|
||||
// https://docs.patreon.com/#user-v2
|
||||
func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -54,7 +58,7 @@ func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"full_name"`
|
||||
Username string `json:"vanity"`
|
||||
AvatarUrl string `json:"image_url"`
|
||||
AvatarURL string `json:"image_url"`
|
||||
IsEmailVerified bool `json:"is_email_verified"`
|
||||
} `json:"attributes"`
|
||||
} `json:"data"`
|
||||
@@ -67,7 +71,7 @@ func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id: extracted.Data.Id,
|
||||
Username: extracted.Data.Attributes.Username,
|
||||
Name: extracted.Data.Attributes.Name,
|
||||
AvatarUrl: extracted.Data.Attributes.AvatarUrl,
|
||||
AvatarURL: extracted.Data.Attributes.AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NamePlanningcenter] = wrapFactory(NewPlanningcenterProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Planningcenter)(nil)
|
||||
|
||||
// NamePlanningcenter is the unique name of the Planningcenter provider.
|
||||
@@ -16,19 +20,19 @@ const NamePlanningcenter string = "planningcenter"
|
||||
|
||||
// Planningcenter allows authentication via Planningcenter OAuth2.
|
||||
type Planningcenter struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewPlanningcenterProvider creates a new Planningcenter provider instance with some defaults.
|
||||
func NewPlanningcenterProvider() *Planningcenter {
|
||||
return &Planningcenter{&baseProvider{
|
||||
return &Planningcenter{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Planning Center",
|
||||
pkce: true,
|
||||
scopes: []string{"people"},
|
||||
authUrl: "https://api.planningcenteronline.com/oauth/authorize",
|
||||
tokenUrl: "https://api.planningcenteronline.com/oauth/token",
|
||||
userApiUrl: "https://api.planningcenteronline.com/people/v2/me",
|
||||
authURL: "https://api.planningcenteronline.com/oauth/authorize",
|
||||
tokenURL: "https://api.planningcenteronline.com/oauth/token",
|
||||
userInfoURL: "https://api.planningcenteronline.com/people/v2/me",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -36,7 +40,7 @@ func NewPlanningcenterProvider() *Planningcenter {
|
||||
//
|
||||
// API reference: https://developer.planning.center/docs/#/overview/authentication
|
||||
func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -52,7 +56,7 @@ func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Attributes struct {
|
||||
Status string `json:"status"`
|
||||
Name string `json:"name"`
|
||||
AvatarUrl string `json:"avatar"`
|
||||
AvatarURL string `json:"avatar"`
|
||||
// don't map the email because users can have multiple assigned
|
||||
// and it's not clear if they are verified
|
||||
}
|
||||
@@ -69,7 +73,7 @@ func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
user := &AuthUser{
|
||||
Id: extracted.Data.Id,
|
||||
Name: extracted.Data.Attributes.Name,
|
||||
AvatarUrl: extracted.Data.Attributes.AvatarUrl,
|
||||
AvatarURL: extracted.Data.Attributes.AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+12
-8
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2/spotify"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameSpotify] = wrapFactory(NewSpotifyProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Spotify)(nil)
|
||||
|
||||
// NameSpotify is the unique name of the Spotify provider.
|
||||
@@ -16,12 +20,12 @@ const NameSpotify string = "spotify"
|
||||
|
||||
// Spotify allows authentication via Spotify OAuth2.
|
||||
type Spotify struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewSpotifyProvider creates a new Spotify provider instance with some defaults.
|
||||
func NewSpotifyProvider() *Spotify {
|
||||
return &Spotify{&baseProvider{
|
||||
return &Spotify{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Spotify",
|
||||
pkce: true,
|
||||
@@ -30,9 +34,9 @@ func NewSpotifyProvider() *Spotify {
|
||||
// currently Spotify doesn't return information whether the email is verified or not
|
||||
// "user-read-email",
|
||||
},
|
||||
authUrl: spotify.Endpoint.AuthURL,
|
||||
tokenUrl: spotify.Endpoint.TokenURL,
|
||||
userApiUrl: "https://api.spotify.com/v1/me",
|
||||
authURL: spotify.Endpoint.AuthURL,
|
||||
tokenURL: spotify.Endpoint.TokenURL,
|
||||
userInfoURL: "https://api.spotify.com/v1/me",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -40,7 +44,7 @@ func NewSpotifyProvider() *Spotify {
|
||||
//
|
||||
// API reference: https://developer.spotify.com/documentation/web-api/reference/#/operations/get-current-users-profile
|
||||
func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -54,7 +58,7 @@ func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"display_name"`
|
||||
Images []struct {
|
||||
Url string `json:"url"`
|
||||
URL string `json:"url"`
|
||||
} `json:"images"`
|
||||
// don't map the email because per the official docs
|
||||
// the email field is "unverified" and there is no proof
|
||||
@@ -76,7 +80,7 @@ func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
user.Expiry, _ = types.ParseDateTime(token.Expiry)
|
||||
|
||||
if len(extracted.Images) > 0 {
|
||||
user.AvatarUrl = extracted.Images[0].Url
|
||||
user.AvatarURL = extracted.Images[0].URL
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
+12
-8
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameStrava] = wrapFactory(NewStravaProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Strava)(nil)
|
||||
|
||||
// NameStrava is the unique name of the Strava provider.
|
||||
@@ -16,21 +20,21 @@ const NameStrava string = "strava"
|
||||
|
||||
// Strava allows authentication via Strava OAuth2.
|
||||
type Strava struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewStravaProvider creates new Strava provider instance with some defaults.
|
||||
func NewStravaProvider() *Strava {
|
||||
return &Strava{&baseProvider{
|
||||
return &Strava{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Strava",
|
||||
pkce: true,
|
||||
scopes: []string{
|
||||
"profile:read_all",
|
||||
},
|
||||
authUrl: "https://www.strava.com/oauth/authorize",
|
||||
tokenUrl: "https://www.strava.com/api/v3/oauth/token",
|
||||
userApiUrl: "https://www.strava.com/api/v3/athlete",
|
||||
authURL: "https://www.strava.com/oauth/authorize",
|
||||
tokenURL: "https://www.strava.com/api/v3/oauth/token",
|
||||
userInfoURL: "https://www.strava.com/api/v3/athlete",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func NewStravaProvider() *Strava {
|
||||
//
|
||||
// API reference: https://developers.strava.com/docs/authentication/
|
||||
func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,7 +57,7 @@ func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
FirstName string `json:"firstname"`
|
||||
LastName string `json:"lastname"`
|
||||
Username string `json:"username"`
|
||||
ProfileImageUrl string `json:"profile"`
|
||||
ProfileImageURL string `json:"profile"`
|
||||
|
||||
// At the time of writing, Strava OAuth2 doesn't support returning the user email address
|
||||
// Email string `json:"email"`
|
||||
@@ -65,7 +69,7 @@ func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
user := &AuthUser{
|
||||
Name: extracted.FirstName + " " + extracted.LastName,
|
||||
Username: extracted.Username,
|
||||
AvatarUrl: extracted.ProfileImageUrl,
|
||||
AvatarURL: extracted.ProfileImageURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+16
-12
@@ -11,6 +11,10 @@ import (
|
||||
"golang.org/x/oauth2/twitch"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameTwitch] = wrapFactory(NewTwitchProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Twitch)(nil)
|
||||
|
||||
// NameTwitch is the unique name of the Twitch provider.
|
||||
@@ -18,19 +22,19 @@ const NameTwitch string = "twitch"
|
||||
|
||||
// Twitch allows authentication via Twitch OAuth2.
|
||||
type Twitch struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewTwitchProvider creates new Twitch provider instance with some defaults.
|
||||
func NewTwitchProvider() *Twitch {
|
||||
return &Twitch{&baseProvider{
|
||||
return &Twitch{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Twitch",
|
||||
pkce: true,
|
||||
scopes: []string{"user:read:email"},
|
||||
authUrl: twitch.Endpoint.AuthURL,
|
||||
tokenUrl: twitch.Endpoint.TokenURL,
|
||||
userApiUrl: "https://api.twitch.tv/helix/users",
|
||||
authURL: twitch.Endpoint.AuthURL,
|
||||
tokenURL: twitch.Endpoint.TokenURL,
|
||||
userInfoURL: "https://api.twitch.tv/helix/users",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func NewTwitchProvider() *Twitch {
|
||||
//
|
||||
// API reference: https://dev.twitch.tv/docs/api/reference#get-users
|
||||
func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -54,7 +58,7 @@ func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Login string `json:"login"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Email string `json:"email"`
|
||||
ProfileImageUrl string `json:"profile_image_url"`
|
||||
ProfileImageURL string `json:"profile_image_url"`
|
||||
} `json:"data"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &extracted); err != nil {
|
||||
@@ -70,7 +74,7 @@ func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Name: extracted.Data[0].DisplayName,
|
||||
Username: extracted.Data[0].Login,
|
||||
Email: extracted.Data[0].Email,
|
||||
AvatarUrl: extracted.Data[0].ProfileImageUrl,
|
||||
AvatarURL: extracted.Data[0].ProfileImageURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
@@ -81,16 +85,16 @@ func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// FetchRawUserData implements Provider.FetchRawUserData interface.
|
||||
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface.
|
||||
//
|
||||
// This differ from baseProvider because Twitch requires the `Client-Id` header.
|
||||
func (p *Twitch) FetchRawUserData(token *oauth2.Token) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", p.userApiUrl, nil)
|
||||
func (p *Twitch) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", p.userInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Client-Id", p.clientId)
|
||||
|
||||
return p.sendRawUserDataRequest(req, token)
|
||||
return p.sendRawUserInfoRequest(req, token)
|
||||
}
|
||||
|
||||
+12
-8
@@ -8,6 +8,10 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameTwitter] = wrapFactory(NewTwitterProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Twitter)(nil)
|
||||
|
||||
// NameTwitter is the unique name of the Twitter provider.
|
||||
@@ -15,12 +19,12 @@ const NameTwitter string = "twitter"
|
||||
|
||||
// Twitter allows authentication via Twitter OAuth2.
|
||||
type Twitter struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewTwitterProvider creates new Twitter provider instance with some defaults.
|
||||
func NewTwitterProvider() *Twitter {
|
||||
return &Twitter{&baseProvider{
|
||||
return &Twitter{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Twitter",
|
||||
pkce: true,
|
||||
@@ -31,9 +35,9 @@ func NewTwitterProvider() *Twitter {
|
||||
// (see https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me)
|
||||
"tweet.read",
|
||||
},
|
||||
authUrl: "https://twitter.com/i/oauth2/authorize",
|
||||
tokenUrl: "https://api.twitter.com/2/oauth2/token",
|
||||
userApiUrl: "https://api.twitter.com/2/users/me?user.fields=id,name,username,profile_image_url",
|
||||
authURL: "https://twitter.com/i/oauth2/authorize",
|
||||
tokenURL: "https://api.twitter.com/2/oauth2/token",
|
||||
userInfoURL: "https://api.twitter.com/2/users/me?user.fields=id,name,username,profile_image_url",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -41,7 +45,7 @@ func NewTwitterProvider() *Twitter {
|
||||
//
|
||||
// API reference: https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me
|
||||
func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -56,7 +60,7 @@ func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
ProfileImageUrl string `json:"profile_image_url"`
|
||||
ProfileImageURL string `json:"profile_image_url"`
|
||||
|
||||
// NB! At the time of writing, Twitter OAuth2 doesn't support returning the user email address
|
||||
// (see https://twittercommunity.com/t/which-api-to-get-user-after-oauth2-authorization/162417/33)
|
||||
@@ -71,7 +75,7 @@ func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id: extracted.Data.Id,
|
||||
Name: extracted.Data.Name,
|
||||
Username: extracted.Data.Username,
|
||||
AvatarUrl: extracted.Data.ProfileImageUrl,
|
||||
AvatarURL: extracted.Data.ProfileImageURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+12
-8
@@ -13,6 +13,10 @@ import (
|
||||
"golang.org/x/oauth2/vk"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameVK] = wrapFactory(NewVKProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*VK)(nil)
|
||||
|
||||
// NameVK is the unique name of the VK provider.
|
||||
@@ -20,21 +24,21 @@ const NameVK string = "vk"
|
||||
|
||||
// VK allows authentication via VK OAuth2.
|
||||
type VK struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewVKProvider creates new VK provider instance with some defaults.
|
||||
//
|
||||
// Docs: https://dev.vk.com/api/oauth-parameters
|
||||
func NewVKProvider() *VK {
|
||||
return &VK{&baseProvider{
|
||||
return &VK{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "ВКонтакте",
|
||||
pkce: false, // VK currently doesn't support PKCE and throws an error if PKCE params are send
|
||||
scopes: []string{"email"},
|
||||
authUrl: vk.Endpoint.AuthURL,
|
||||
tokenUrl: vk.Endpoint.TokenURL,
|
||||
userApiUrl: "https://api.vk.com/method/users.get?fields=photo_max,screen_name&v=5.131",
|
||||
authURL: vk.Endpoint.AuthURL,
|
||||
tokenURL: vk.Endpoint.TokenURL,
|
||||
userInfoURL: "https://api.vk.com/method/users.get?fields=photo_max,screen_name&v=5.131",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -42,7 +46,7 @@ func NewVKProvider() *VK {
|
||||
//
|
||||
// API reference: https://dev.vk.com/method/users.get
|
||||
func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -58,7 +62,7 @@ func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
Username string `json:"screen_name"`
|
||||
AvatarUrl string `json:"photo_max"`
|
||||
AvatarURL string `json:"photo_max"`
|
||||
} `json:"response"`
|
||||
}{}
|
||||
|
||||
@@ -74,7 +78,7 @@ func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
Id: strconv.Itoa(extracted.Response[0].Id),
|
||||
Name: strings.TrimSpace(extracted.Response[0].FirstName + " " + extracted.Response[0].LastName),
|
||||
Username: extracted.Response[0].Username,
|
||||
AvatarUrl: extracted.Response[0].AvatarUrl,
|
||||
AvatarURL: extracted.Response[0].AvatarURL,
|
||||
RawUser: rawUser,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
+11
-7
@@ -9,6 +9,10 @@ import (
|
||||
"golang.org/x/oauth2/yandex"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Providers[NameYandex] = wrapFactory(NewYandexProvider)
|
||||
}
|
||||
|
||||
var _ Provider = (*Yandex)(nil)
|
||||
|
||||
// NameYandex is the unique name of the Yandex provider.
|
||||
@@ -16,21 +20,21 @@ const NameYandex string = "yandex"
|
||||
|
||||
// Yandex allows authentication via Yandex OAuth2.
|
||||
type Yandex struct {
|
||||
*baseProvider
|
||||
BaseProvider
|
||||
}
|
||||
|
||||
// NewYandexProvider creates new Yandex provider instance with some defaults.
|
||||
//
|
||||
// Docs: https://yandex.ru/dev/id/doc/en/
|
||||
func NewYandexProvider() *Yandex {
|
||||
return &Yandex{&baseProvider{
|
||||
return &Yandex{BaseProvider{
|
||||
ctx: context.Background(),
|
||||
displayName: "Yandex",
|
||||
pkce: true,
|
||||
scopes: []string{"login:email", "login:avatar", "login:info"},
|
||||
authUrl: yandex.Endpoint.AuthURL,
|
||||
tokenUrl: yandex.Endpoint.TokenURL,
|
||||
userApiUrl: "https://login.yandex.ru/info",
|
||||
authURL: yandex.Endpoint.AuthURL,
|
||||
tokenURL: yandex.Endpoint.TokenURL,
|
||||
userInfoURL: "https://login.yandex.ru/info",
|
||||
}}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func NewYandexProvider() *Yandex {
|
||||
//
|
||||
// API reference: https://yandex.ru/dev/id/doc/en/user-information#response-format
|
||||
func (p *Yandex) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
data, err := p.FetchRawUserData(token)
|
||||
data, err := p.FetchRawUserInfo(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -73,7 +77,7 @@ func (p *Yandex) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
|
||||
user.Expiry, _ = types.ParseDateTime(token.Expiry)
|
||||
|
||||
if !extracted.IsAvatarEmpty {
|
||||
user.AvatarUrl = "https://avatars.yandex.net/get-yapic/" + extracted.AvatarId + "/islands-200"
|
||||
user.AvatarURL = "https://avatars.yandex.net/get-yapic/" + extracted.AvatarId + "/islands-200"
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
+24
-25
@@ -26,10 +26,9 @@ type Cron struct {
|
||||
ticker *time.Ticker
|
||||
startTimer *time.Timer
|
||||
jobs map[string]*job
|
||||
interval time.Duration
|
||||
tickerDone chan bool
|
||||
|
||||
sync.RWMutex
|
||||
interval time.Duration
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// New create a new Cron struct with default tick interval of 1 minute
|
||||
@@ -50,10 +49,10 @@ func New() *Cron {
|
||||
// (it usually should be >= 1 minute).
|
||||
func (c *Cron) SetInterval(d time.Duration) {
|
||||
// update interval
|
||||
c.Lock()
|
||||
c.mux.Lock()
|
||||
wasStarted := c.ticker != nil
|
||||
c.interval = d
|
||||
c.Unlock()
|
||||
c.mux.Unlock()
|
||||
|
||||
// restart the ticker
|
||||
if wasStarted {
|
||||
@@ -63,8 +62,8 @@ func (c *Cron) SetInterval(d time.Duration) {
|
||||
|
||||
// SetTimezone changes the current cron tick timezone.
|
||||
func (c *Cron) SetTimezone(l *time.Location) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.timezone = l
|
||||
}
|
||||
@@ -88,8 +87,8 @@ func (c *Cron) Add(jobId string, cronExpr string, run func()) error {
|
||||
return errors.New("failed to add new cron job: run must be non-nil function")
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
schedule, err := NewSchedule(cronExpr)
|
||||
if err != nil {
|
||||
@@ -106,24 +105,24 @@ func (c *Cron) Add(jobId string, cronExpr string, run func()) error {
|
||||
|
||||
// Remove removes a single cron job by its id.
|
||||
func (c *Cron) Remove(jobId string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
delete(c.jobs, jobId)
|
||||
}
|
||||
|
||||
// RemoveAll removes all registered cron jobs.
|
||||
func (c *Cron) RemoveAll() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.jobs = map[string]*job{}
|
||||
}
|
||||
|
||||
// Total returns the current total number of registered cron jobs.
|
||||
func (c *Cron) Total() int {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
return len(c.jobs)
|
||||
}
|
||||
@@ -132,8 +131,8 @@ func (c *Cron) Total() int {
|
||||
//
|
||||
// You can resume the ticker by calling Start().
|
||||
func (c *Cron) Stop() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.startTimer != nil {
|
||||
c.startTimer.Stop()
|
||||
@@ -160,11 +159,11 @@ func (c *Cron) Start() {
|
||||
next := now.Add(c.interval).Truncate(c.interval)
|
||||
delay := next.Sub(now)
|
||||
|
||||
c.Lock()
|
||||
c.mux.Lock()
|
||||
c.startTimer = time.AfterFunc(delay, func() {
|
||||
c.Lock()
|
||||
c.mux.Lock()
|
||||
c.ticker = time.NewTicker(c.interval)
|
||||
c.Unlock()
|
||||
c.mux.Unlock()
|
||||
|
||||
// run immediately at 00
|
||||
c.runDue(time.Now())
|
||||
@@ -181,21 +180,21 @@ func (c *Cron) Start() {
|
||||
}
|
||||
}()
|
||||
})
|
||||
c.Unlock()
|
||||
c.mux.Unlock()
|
||||
}
|
||||
|
||||
// HasStarted checks whether the current Cron ticker has been started.
|
||||
func (c *Cron) HasStarted() bool {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
return c.ticker != nil
|
||||
}
|
||||
|
||||
// runDue runs all registered jobs that are scheduled for the provided time.
|
||||
func (c *Cron) runDue(t time.Time) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
moment := NewMoment(t.In(c.timezone))
|
||||
|
||||
|
||||
+29
-24
@@ -2,6 +2,7 @@ package cron_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -252,26 +253,28 @@ func TestNewSchedule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
schedule, err := cron.NewSchedule(s.cronExpr)
|
||||
t.Run(s.cronExpr, func(t *testing.T) {
|
||||
schedule, err := cron.NewSchedule(s.cronExpr)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("[%s] Expected hasErr to be %v, got %v (%v)", s.cronExpr, s.expectError, hasErr, err)
|
||||
}
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(schedule)
|
||||
if err != nil {
|
||||
t.Fatalf("[%s] Failed to marshalize the result schedule: %v", s.cronExpr, err)
|
||||
}
|
||||
encodedStr := string(encoded)
|
||||
encoded, err := json.Marshal(schedule)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshalize the result schedule: %v", err)
|
||||
}
|
||||
encodedStr := string(encoded)
|
||||
|
||||
if encodedStr != s.expectSchedule {
|
||||
t.Fatalf("[%s] Expected \n%s, \ngot \n%s", s.cronExpr, s.expectSchedule, encodedStr)
|
||||
}
|
||||
if encodedStr != s.expectSchedule {
|
||||
t.Fatalf("Expected \n%s, \ngot \n%s", s.expectSchedule, encodedStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -390,15 +393,17 @@ func TestScheduleIsDue(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
schedule, err := cron.NewSchedule(s.cronExpr)
|
||||
if err != nil {
|
||||
t.Fatalf("[%d-%s] Unexpected cron error: %v", i, s.cronExpr, err)
|
||||
}
|
||||
t.Run(fmt.Sprintf("%d-%s", i, s.cronExpr), func(t *testing.T) {
|
||||
schedule, err := cron.NewSchedule(s.cronExpr)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected cron error: %v", err)
|
||||
}
|
||||
|
||||
result := schedule.IsDue(s.moment)
|
||||
result := schedule.IsDue(s.moment)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("[%d-%s] Expected %v, got %v", i, s.cronExpr, s.expected, result)
|
||||
}
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,31 +5,31 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JsonEach returns JSON_EACH SQLite string expression with
|
||||
// JSONEach returns JSON_EACH SQLite string expression with
|
||||
// some normalizations for non-json columns.
|
||||
func JsonEach(column string) string {
|
||||
func JSONEach(column string) string {
|
||||
return fmt.Sprintf(
|
||||
`json_each(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE json_array([[%s]]) END)`,
|
||||
column, column, column,
|
||||
)
|
||||
}
|
||||
|
||||
// JsonArrayLength returns JSON_ARRAY_LENGTH SQLite string expression
|
||||
// JSONArrayLength returns JSON_ARRAY_LENGTH SQLite string expression
|
||||
// with some normalizations for non-json columns.
|
||||
//
|
||||
// It works with both json and non-json column values.
|
||||
//
|
||||
// Returns 0 for empty string or NULL column values.
|
||||
func JsonArrayLength(column string) string {
|
||||
func JSONArrayLength(column string) string {
|
||||
return fmt.Sprintf(
|
||||
`json_array_length(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE (CASE WHEN [[%s]] = '' OR [[%s]] IS NULL THEN json_array() ELSE json_array([[%s]]) END) END)`,
|
||||
column, column, column, column, column,
|
||||
)
|
||||
}
|
||||
|
||||
// JsonExtract returns a JSON_EXTRACT SQLite string expression with
|
||||
// JSONExtract returns a JSON_EXTRACT SQLite string expression with
|
||||
// some normalizations for non-json columns.
|
||||
func JsonExtract(column string, path string) string {
|
||||
func JSONExtract(column string, path string) string {
|
||||
// prefix the path with dot if it is not starting with array notation
|
||||
if path != "" && !strings.HasPrefix(path, "[") {
|
||||
path = "." + path
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
)
|
||||
|
||||
func TestJsonEach(t *testing.T) {
|
||||
result := dbutils.JsonEach("a.b")
|
||||
func TestJSONEach(t *testing.T) {
|
||||
result := dbutils.JSONEach("a.b")
|
||||
|
||||
expected := "json_each(CASE WHEN json_valid([[a.b]]) THEN [[a.b]] ELSE json_array([[a.b]]) END)"
|
||||
|
||||
@@ -16,8 +16,8 @@ func TestJsonEach(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonArrayLength(t *testing.T) {
|
||||
result := dbutils.JsonArrayLength("a.b")
|
||||
func TestJSONArrayLength(t *testing.T) {
|
||||
result := dbutils.JSONArrayLength("a.b")
|
||||
|
||||
expected := "json_array_length(CASE WHEN json_valid([[a.b]]) THEN [[a.b]] ELSE (CASE WHEN [[a.b]] = '' OR [[a.b]] IS NULL THEN json_array() ELSE json_array([[a.b]]) END) END)"
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestJsonArrayLength(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonExtract(t *testing.T) {
|
||||
func TestJSONExtract(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
column string
|
||||
@@ -55,12 +55,11 @@ func TestJsonExtract(t *testing.T) {
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
result := dbutils.JsonExtract(s.column, s.path)
|
||||
result := dbutils.JSONExtract(s.column, s.path)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -27,10 +27,20 @@ type FileReader interface {
|
||||
//
|
||||
// The file could be from a local path, multipart/form-data header, etc.
|
||||
type File struct {
|
||||
Reader FileReader
|
||||
Name string
|
||||
OriginalName string
|
||||
Size int64
|
||||
Reader FileReader `form:"-" json:"-" xml:"-"`
|
||||
Name string `form:"name" json:"name" xml:"name"`
|
||||
OriginalName string `form:"originalName" json:"originalName" xml:"originalName"`
|
||||
Size int64 `form:"size" json:"size" xml:"size"`
|
||||
}
|
||||
|
||||
// AsMap implements [core.mapExtractor] and returns a value suitable
|
||||
// to be used in an API rule expression.
|
||||
func (f *File) AsMap() map[string]any {
|
||||
return map[string]any{
|
||||
"name": f.Name,
|
||||
"originalName": f.OriginalName,
|
||||
"size": f.Size,
|
||||
}
|
||||
}
|
||||
|
||||
// NewFileFromPath creates a new File instance from the provided local file path.
|
||||
@@ -79,7 +89,7 @@ func NewFileFromMultipart(mh *multipart.FileHeader) (*File, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// NewFileFromUrl creates a new File from the provided url by
|
||||
// NewFileFromURL creates a new File from the provided url by
|
||||
// downloading the resource and load it as BytesReader.
|
||||
//
|
||||
// Example
|
||||
@@ -87,8 +97,8 @@ func NewFileFromMultipart(mh *multipart.FileHeader) (*File, error) {
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// file, err := filesystem.NewFileFromUrl(ctx, "https://example.com/image.png")
|
||||
func NewFileFromUrl(ctx context.Context, url string) (*File, error) {
|
||||
// file, err := filesystem.NewFileFromURL(ctx, "https://example.com/image.png")
|
||||
func NewFileFromURL(ctx context.Context, url string) (*File, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -168,6 +178,8 @@ func (r *bytesReadSeekCloser) Close() error {
|
||||
|
||||
var extInvalidCharsRegex = regexp.MustCompile(`[^\w\.\*\-\+\=\#]+`)
|
||||
|
||||
const randomAlphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
func normalizeName(fr FileReader, name string) string {
|
||||
// extension
|
||||
// ---
|
||||
@@ -187,7 +199,7 @@ func normalizeName(fr FileReader, name string) string {
|
||||
cleanName := inflector.Snakecase(strings.TrimSuffix(name, originalExt))
|
||||
if length := len(cleanName); length < 3 {
|
||||
// the name is too short so we concatenate an additional random part
|
||||
cleanName += security.RandomString(10)
|
||||
cleanName += security.RandomStringWithAlphabet(10, randomAlphabet)
|
||||
} else if length > 100 {
|
||||
// keep only the first 100 characters (it is multibyte safe after Snakecase)
|
||||
cleanName = cleanName[:100]
|
||||
@@ -196,7 +208,7 @@ func normalizeName(fr FileReader, name string) string {
|
||||
return fmt.Sprintf(
|
||||
"%s_%s%s",
|
||||
cleanName,
|
||||
security.RandomString(10), // ensure that there is always a random part
|
||||
security.RandomStringWithAlphabet(10, randomAlphabet), // ensure that there is always a random part
|
||||
cleanExt,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -12,11 +12,35 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
)
|
||||
|
||||
func TestFileAsMap(t *testing.T) {
|
||||
file, err := filesystem.NewFileFromBytes([]byte("test"), "test123.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result := file.AsMap()
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("Expected map with %d keys, got\n%v", 3, result)
|
||||
}
|
||||
|
||||
if result["size"] != int64(4) {
|
||||
t.Fatalf("Expected size %d, got %#v", 4, result["size"])
|
||||
}
|
||||
|
||||
if str, ok := result["name"].(string); !ok || !strings.HasPrefix(str, "test123") {
|
||||
t.Fatalf("Expected name to have prefix %q, got %#v", "test123", result["name"])
|
||||
}
|
||||
|
||||
if result["originalName"] != "test123.txt" {
|
||||
t.Fatalf("Expected originalName %q, got %#v", "test123.txt", result["originalName"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileFromPath(t *testing.T) {
|
||||
testDir := createTestDir(t)
|
||||
defer os.RemoveAll(testDir)
|
||||
@@ -83,7 +107,7 @@ func TestNewFileFromMultipart(t *testing.T) {
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("", "/", formData)
|
||||
req.Header.Set(echo.HeaderContentType, mp.FormDataContentType())
|
||||
req.Header.Set("Content-Type", mp.FormDataContentType())
|
||||
req.ParseMultipartForm(32 << 20)
|
||||
|
||||
_, mh, err := req.FormFile("test")
|
||||
@@ -115,7 +139,7 @@ func TestNewFileFromMultipart(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileFromUrlTimeout(t *testing.T) {
|
||||
func TestNewFileFromURLTimeout(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/error" {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
@@ -129,7 +153,7 @@ func TestNewFileFromUrlTimeout(t *testing.T) {
|
||||
{
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
f, err := filesystem.NewFileFromUrl(ctx, srv.URL+"/cancel")
|
||||
f, err := filesystem.NewFileFromURL(ctx, srv.URL+"/cancel")
|
||||
if err == nil {
|
||||
t.Fatal("[ctx_cancel] Expected error, got nil")
|
||||
}
|
||||
@@ -140,7 +164,7 @@ func TestNewFileFromUrlTimeout(t *testing.T) {
|
||||
|
||||
// error response
|
||||
{
|
||||
f, err := filesystem.NewFileFromUrl(context.Background(), srv.URL+"/error")
|
||||
f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/error")
|
||||
if err == nil {
|
||||
t.Fatal("[error_status] Expected error, got nil")
|
||||
}
|
||||
@@ -154,7 +178,7 @@ func TestNewFileFromUrlTimeout(t *testing.T) {
|
||||
originalName := "image_! noext"
|
||||
normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".txt")
|
||||
|
||||
f, err := filesystem.NewFileFromUrl(context.Background(), srv.URL+"/"+originalName)
|
||||
f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/"+originalName)
|
||||
if err != nil {
|
||||
t.Fatalf("[valid] Unexpected error %v", err)
|
||||
}
|
||||
|
||||
@@ -20,13 +20,17 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3lite"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"gocloud.dev/blob"
|
||||
"gocloud.dev/blob/fileblob"
|
||||
"gocloud.dev/gcerrors"
|
||||
)
|
||||
|
||||
var gcpIgnoreHeaders = []string{"Accept-Encoding"}
|
||||
|
||||
var ErrNotFound = errors.New("blob not found")
|
||||
|
||||
type System struct {
|
||||
ctx context.Context
|
||||
bucket *blob.Bucket
|
||||
@@ -47,25 +51,23 @@ func NewS3(
|
||||
|
||||
cred := credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
cfg, err := config.LoadDefaultConfig(
|
||||
ctx,
|
||||
config.WithCredentialsProvider(cred),
|
||||
config.WithRegion(region),
|
||||
config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
|
||||
// ensure that the endpoint has url scheme for
|
||||
// backward compatibility with v1 of the aws sdk
|
||||
prefixedEndpoint := endpoint
|
||||
if !strings.Contains(endpoint, "://") {
|
||||
prefixedEndpoint = "https://" + endpoint
|
||||
}
|
||||
|
||||
return aws.Endpoint{URL: prefixedEndpoint, SigningRegion: region}, nil
|
||||
})),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
|
||||
// ensure that the endpoint has url scheme for
|
||||
// backward compatibility with v1 of the aws sdk
|
||||
if !strings.Contains(endpoint, "://") {
|
||||
endpoint = "https://" + endpoint
|
||||
}
|
||||
o.BaseEndpoint = aws.String(endpoint)
|
||||
|
||||
o.UsePathStyle = s3ForcePathStyle
|
||||
|
||||
// Google Cloud Storage alters the Accept-Encoding header,
|
||||
@@ -76,7 +78,7 @@ func NewS3(
|
||||
}
|
||||
})
|
||||
|
||||
bucket, err := OpenBucketV2(ctx, client, bucketName, nil)
|
||||
bucket, err := s3lite.OpenBucketV2(ctx, client, bucketName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -116,32 +118,59 @@ func (s *System) Close() error {
|
||||
}
|
||||
|
||||
// Exists checks if file with fileKey path exists or not.
|
||||
//
|
||||
// If the file doesn't exist returns false and ErrNotFound.
|
||||
func (s *System) Exists(fileKey string) (bool, error) {
|
||||
return s.bucket.Exists(s.ctx, fileKey)
|
||||
exists, err := s.bucket.Exists(s.ctx, fileKey)
|
||||
|
||||
if gcerrors.Code(err) == gcerrors.NotFound {
|
||||
err = ErrNotFound
|
||||
}
|
||||
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Attributes returns the attributes for the file with fileKey path.
|
||||
//
|
||||
// If the file doesn't exist it returns ErrNotFound.
|
||||
func (s *System) Attributes(fileKey string) (*blob.Attributes, error) {
|
||||
return s.bucket.Attributes(s.ctx, fileKey)
|
||||
attrs, err := s.bucket.Attributes(s.ctx, fileKey)
|
||||
|
||||
if gcerrors.Code(err) == gcerrors.NotFound {
|
||||
err = ErrNotFound
|
||||
}
|
||||
|
||||
return attrs, err
|
||||
}
|
||||
|
||||
// GetFile returns a file content reader for the given fileKey.
|
||||
//
|
||||
// NB! Make sure to call `Close()` after you are done working with it.
|
||||
// NB! Make sure to call Close() on the file after you are done working with it.
|
||||
//
|
||||
// If the file doesn't exist returns ErrNotFound.
|
||||
func (s *System) GetFile(fileKey string) (*blob.Reader, error) {
|
||||
br, err := s.bucket.NewReader(s.ctx, fileKey, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
if gcerrors.Code(err) == gcerrors.NotFound {
|
||||
err = ErrNotFound
|
||||
}
|
||||
|
||||
return br, nil
|
||||
return br, err
|
||||
}
|
||||
|
||||
// Copy copies the file stored at srcKey to dstKey.
|
||||
//
|
||||
// If srcKey file doesn't exist, it returns ErrNotFound.
|
||||
//
|
||||
// If dstKey file already exists, it is overwritten.
|
||||
func (s *System) Copy(srcKey, dstKey string) error {
|
||||
return s.bucket.Copy(s.ctx, dstKey, srcKey, nil)
|
||||
err := s.bucket.Copy(s.ctx, dstKey, srcKey, nil)
|
||||
|
||||
if gcerrors.Code(err) == gcerrors.NotFound {
|
||||
err = ErrNotFound
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// List returns a flat list with info for all files under the specified prefix.
|
||||
@@ -178,14 +207,13 @@ func (s *System) Upload(content []byte, fileKey string) error {
|
||||
}
|
||||
|
||||
if _, err := w.Write(content); err != nil {
|
||||
w.Close()
|
||||
return err
|
||||
return errors.Join(err, w.Close())
|
||||
}
|
||||
|
||||
return w.Close()
|
||||
}
|
||||
|
||||
// UploadFile uploads the provided multipart file to the fileKey location.
|
||||
// UploadFile uploads the provided File to the fileKey location.
|
||||
func (s *System) UploadFile(file *File, fileKey string) error {
|
||||
f, err := file.Reader.Open()
|
||||
if err != nil {
|
||||
@@ -270,8 +298,16 @@ func (s *System) UploadMultipart(fh *multipart.FileHeader, fileKey string) error
|
||||
}
|
||||
|
||||
// Delete deletes stored file at fileKey location.
|
||||
//
|
||||
// If the file doesn't exist returns ErrNotFound.
|
||||
func (s *System) Delete(fileKey string) error {
|
||||
return s.bucket.Delete(s.ctx, fileKey)
|
||||
err := s.bucket.Delete(s.ctx, fileKey)
|
||||
|
||||
if gcerrors.Code(err) == gcerrors.NotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeletePrefix deletes everything starting with the specified prefix.
|
||||
@@ -345,6 +381,26 @@ func (s *System) DeletePrefix(prefix string) []error {
|
||||
return failed
|
||||
}
|
||||
|
||||
// Checks if the provided dir prefix doesn't have any files.
|
||||
//
|
||||
// A trailing slash will be appended to a non-empty dir string argument
|
||||
// to ensure that the checked prefix is a "directory".
|
||||
//
|
||||
// Returns "false" in case the has at least one file, otherwise - "true".
|
||||
func (s *System) IsEmptyDir(dir string) bool {
|
||||
if dir != "" && !strings.HasSuffix(dir, "/") {
|
||||
dir += "/"
|
||||
}
|
||||
|
||||
iter := s.bucket.List(&blob.ListOptions{
|
||||
Prefix: dir,
|
||||
})
|
||||
|
||||
_, err := iter.Next(s.ctx)
|
||||
|
||||
return err == io.EOF
|
||||
}
|
||||
|
||||
var inlineServeContentTypes = []string{
|
||||
// image
|
||||
"image/png", "image/jpg", "image/jpeg", "image/gif", "image/webp", "image/x-icon", "image/bmp",
|
||||
@@ -371,8 +427,11 @@ const forceAttachmentParam = "download"
|
||||
//
|
||||
// If the `download` query parameter is used the file will be always served for
|
||||
// download no matter of its type (aka. with "Content-Disposition: attachment").
|
||||
//
|
||||
// Internally this method uses [http.ServeContent] so Range requests,
|
||||
// If-Match, If-Unmodified-Since, etc. headers are handled transparently.
|
||||
func (s *System) Serve(res http.ResponseWriter, req *http.Request, fileKey string, name string) error {
|
||||
br, readErr := s.bucket.NewReader(s.ctx, fileKey, nil)
|
||||
br, readErr := s.GetFile(fileKey)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
@@ -444,7 +503,7 @@ func (s *System) CreateThumb(originalKey string, thumbKey, thumbSize string) err
|
||||
}
|
||||
|
||||
// fetch the original
|
||||
r, readErr := s.bucket.NewReader(s.ctx, originalKey, nil)
|
||||
r, readErr := s.GetFile(originalKey)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
|
||||
+182
-111
@@ -2,6 +2,7 @@ package filesystem_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"image"
|
||||
"image/png"
|
||||
"mime/multipart"
|
||||
@@ -19,11 +20,11 @@ func TestFileSystemExists(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
file string
|
||||
@@ -35,12 +36,18 @@ func TestFileSystemExists(t *testing.T) {
|
||||
{"image.png", true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
exists, _ := fs.Exists(scenario.file)
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.file, func(t *testing.T) {
|
||||
exists, err := fsys.Exists(s.file)
|
||||
|
||||
if exists != scenario.exists {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.exists, exists)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if exists != s.exists {
|
||||
t.Fatalf("Expected exists %v, got %v", s.exists, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,11 +55,11 @@ func TestFileSystemAttributes(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
file string
|
||||
@@ -65,20 +72,24 @@ func TestFileSystemAttributes(t *testing.T) {
|
||||
{"image.png", false, "image/png"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
attr, err := fs.Attributes(scenario.file)
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.file, func(t *testing.T) {
|
||||
attr, err := fsys.Attributes(s.file)
|
||||
|
||||
if err == nil && scenario.expectError {
|
||||
t.Errorf("(%d) Expected error, got nil", i)
|
||||
}
|
||||
hasErr := err != nil
|
||||
|
||||
if err != nil && !scenario.expectError {
|
||||
t.Errorf("(%d) Expected nil, got error, %v", i, err)
|
||||
}
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
|
||||
if err == nil && attr.ContentType != scenario.expectContentType {
|
||||
t.Errorf("(%d) Expected attr.ContentType to be %q, got %q", i, scenario.expectContentType, attr.ContentType)
|
||||
}
|
||||
if hasErr && !errors.Is(err, filesystem.ErrNotFound) {
|
||||
t.Fatalf("Expected ErrNotFound err, got %q", err)
|
||||
}
|
||||
|
||||
if !hasErr && attr.ContentType != s.expectContentType {
|
||||
t.Fatalf("Expected attr.ContentType to be %q, got %q", s.expectContentType, attr.ContentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,17 +97,17 @@ func TestFileSystemDelete(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
if err := fs.Delete("missing.txt"); err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
if err := fsys.Delete("missing.txt"); err == nil || !errors.Is(err, filesystem.ErrNotFound) {
|
||||
t.Fatalf("Expected ErrNotFound error, got %v", err)
|
||||
}
|
||||
|
||||
if err := fs.Delete("image.png"); err != nil {
|
||||
if err := fsys.Delete("image.png"); err != nil {
|
||||
t.Fatalf("Expected nil, got error %v", err)
|
||||
}
|
||||
}
|
||||
@@ -105,29 +116,29 @@ func TestFileSystemDeletePrefixWithoutTrailingSlash(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
if errs := fs.DeletePrefix(""); len(errs) == 0 {
|
||||
if errs := fsys.DeletePrefix(""); len(errs) == 0 {
|
||||
t.Fatal("Expected error, got nil", errs)
|
||||
}
|
||||
|
||||
if errs := fs.DeletePrefix("missing"); len(errs) != 0 {
|
||||
if errs := fsys.DeletePrefix("missing"); len(errs) != 0 {
|
||||
t.Fatalf("Not existing prefix shouldn't error, got %v", errs)
|
||||
}
|
||||
|
||||
if errs := fs.DeletePrefix("test"); len(errs) != 0 {
|
||||
if errs := fsys.DeletePrefix("test"); len(errs) != 0 {
|
||||
t.Fatalf("Expected nil, got errors %v", errs)
|
||||
}
|
||||
|
||||
// ensure that the test/* files are deleted
|
||||
if exists, _ := fs.Exists("test/sub1.txt"); exists {
|
||||
if exists, _ := fsys.Exists("test/sub1.txt"); exists {
|
||||
t.Fatalf("Expected test/sub1.txt to be deleted")
|
||||
}
|
||||
if exists, _ := fs.Exists("test/sub2.txt"); exists {
|
||||
if exists, _ := fsys.Exists("test/sub2.txt"); exists {
|
||||
t.Fatalf("Expected test/sub2.txt to be deleted")
|
||||
}
|
||||
|
||||
@@ -141,25 +152,25 @@ func TestFileSystemDeletePrefixWithTrailingSlash(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
if errs := fs.DeletePrefix("missing/"); len(errs) != 0 {
|
||||
if errs := fsys.DeletePrefix("missing/"); len(errs) != 0 {
|
||||
t.Fatalf("Not existing prefix shouldn't error, got %v", errs)
|
||||
}
|
||||
|
||||
if errs := fs.DeletePrefix("test/"); len(errs) != 0 {
|
||||
if errs := fsys.DeletePrefix("test/"); len(errs) != 0 {
|
||||
t.Fatalf("Expected nil, got errors %v", errs)
|
||||
}
|
||||
|
||||
// ensure that the test/* files are deleted
|
||||
if exists, _ := fs.Exists("test/sub1.txt"); exists {
|
||||
if exists, _ := fsys.Exists("test/sub1.txt"); exists {
|
||||
t.Fatalf("Expected test/sub1.txt to be deleted")
|
||||
}
|
||||
if exists, _ := fs.Exists("test/sub2.txt"); exists {
|
||||
if exists, _ := fsys.Exists("test/sub2.txt"); exists {
|
||||
t.Fatalf("Expected test/sub2.txt to be deleted")
|
||||
}
|
||||
|
||||
@@ -169,6 +180,41 @@ func TestFileSystemDeletePrefixWithTrailingSlash(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemIsEmptyDir(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
dir string
|
||||
expected bool
|
||||
}{
|
||||
{"", false}, // special case that shouldn't be suffixed with delimiter to search for any files within the bucket
|
||||
{"/", true},
|
||||
{"missing", true},
|
||||
{"missing/", true},
|
||||
{"test", false},
|
||||
{"test/", false},
|
||||
{"empty", true},
|
||||
{"empty/", true},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.dir, func(t *testing.T) {
|
||||
result := fsys.IsEmptyDir(s.dir)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemUploadMultipart(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
@@ -193,24 +239,24 @@ func TestFileSystemUploadMultipart(t *testing.T) {
|
||||
defer file.Close()
|
||||
// ---
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
fileKey := "newdir/newkey.txt"
|
||||
|
||||
uploadErr := fs.UploadMultipart(fh, fileKey)
|
||||
uploadErr := fsys.UploadMultipart(fh, fileKey)
|
||||
if uploadErr != nil {
|
||||
t.Fatal(uploadErr)
|
||||
}
|
||||
|
||||
if exists, _ := fs.Exists(fileKey); !exists {
|
||||
if exists, _ := fsys.Exists(fileKey); !exists {
|
||||
t.Fatalf("Expected %q to exist", fileKey)
|
||||
}
|
||||
|
||||
attrs, err := fs.Attributes(fileKey)
|
||||
attrs, err := fsys.Attributes(fileKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch file attributes: %v", err)
|
||||
}
|
||||
@@ -223,11 +269,11 @@ func TestFileSystemUploadFile(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
fileKey := "newdir/newkey.txt"
|
||||
|
||||
@@ -238,16 +284,16 @@ func TestFileSystemUploadFile(t *testing.T) {
|
||||
|
||||
file.OriginalName = "test.txt"
|
||||
|
||||
uploadErr := fs.UploadFile(file, fileKey)
|
||||
uploadErr := fsys.UploadFile(file, fileKey)
|
||||
if uploadErr != nil {
|
||||
t.Fatal(uploadErr)
|
||||
}
|
||||
|
||||
if exists, _ := fs.Exists(fileKey); !exists {
|
||||
if exists, _ := fsys.Exists(fileKey); !exists {
|
||||
t.Fatalf("Expected %q to exist", fileKey)
|
||||
}
|
||||
|
||||
attrs, err := fs.Attributes(fileKey)
|
||||
attrs, err := fsys.Attributes(fileKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch file attributes: %v", err)
|
||||
}
|
||||
@@ -260,20 +306,20 @@ func TestFileSystemUpload(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
fileKey := "newdir/newkey.txt"
|
||||
|
||||
uploadErr := fs.Upload([]byte("demo"), fileKey)
|
||||
uploadErr := fsys.Upload([]byte("demo"), fileKey)
|
||||
if uploadErr != nil {
|
||||
t.Fatal(uploadErr)
|
||||
}
|
||||
|
||||
if exists, _ := fs.Exists(fileKey); !exists {
|
||||
if exists, _ := fsys.Exists(fileKey); !exists {
|
||||
t.Fatalf("Expected %s to exist", fileKey)
|
||||
}
|
||||
}
|
||||
@@ -282,11 +328,11 @@ func TestFileSystemServe(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
csp := "default-src 'none'; media-src 'self'; style-src 'unsafe-inline'; sandbox"
|
||||
cacheControl := "max-age=2592000, stale-while-revalidate=86400"
|
||||
@@ -409,39 +455,41 @@ func TestFileSystemServe(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
t.Run(s.path, func(t *testing.T) {
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
query := req.URL.Query()
|
||||
for k, v := range s.query {
|
||||
query.Set(k, v)
|
||||
}
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
for k, v := range s.headers {
|
||||
res.Header().Set(k, v)
|
||||
}
|
||||
|
||||
err := fs.Serve(res, req, s.path, s.name)
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%s) Expected hasError %v, got %v (%v)", s.path, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if s.expectError {
|
||||
continue
|
||||
}
|
||||
|
||||
result := res.Result()
|
||||
|
||||
for hName, hValue := range s.expectHeaders {
|
||||
v := result.Header.Get(hName)
|
||||
if v != hValue {
|
||||
t.Errorf("(%s) Expected value %q for header %q, got %q", s.path, hValue, hName, v)
|
||||
query := req.URL.Query()
|
||||
for k, v := range s.query {
|
||||
query.Set(k, v)
|
||||
}
|
||||
}
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
for k, v := range s.headers {
|
||||
res.Header().Set(k, v)
|
||||
}
|
||||
|
||||
err := fsys.Serve(res, req, s.path, s.name)
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasError %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if s.expectError {
|
||||
return
|
||||
}
|
||||
|
||||
result := res.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
for hName, hValue := range s.expectHeaders {
|
||||
v := result.Header.Get(hName)
|
||||
if v != hValue {
|
||||
t.Errorf("Expected value %q for header %q, got %q", hValue, hName, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -449,20 +497,38 @@ func TestFileSystemGetFile(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
f, err := fs.GetFile("image.png")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
scenarios := []struct {
|
||||
file string
|
||||
expectError bool
|
||||
}{
|
||||
{"missing.png", true},
|
||||
{"image.png", false},
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if f == nil {
|
||||
t.Fatal("File is supposed to be found")
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.file, func(t *testing.T) {
|
||||
f, err := fsys.GetFile(s.file)
|
||||
|
||||
hasErr := err != nil
|
||||
|
||||
if !hasErr {
|
||||
defer f.Close()
|
||||
}
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
|
||||
if hasErr && !errors.Is(err, filesystem.ErrNotFound) {
|
||||
t.Fatalf("Expected ErrNotFound error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -470,25 +536,26 @@ func TestFileSystemCopy(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
src := "image.png"
|
||||
dst := "image.png_copy"
|
||||
|
||||
// copy missing file
|
||||
if err := fs.Copy(dst, src); err == nil {
|
||||
if err := fsys.Copy(dst, src); err == nil {
|
||||
t.Fatalf("Expected to fail copying %q to %q, got nil", dst, src)
|
||||
}
|
||||
|
||||
// copy existing file
|
||||
if err := fs.Copy(src, dst); err != nil {
|
||||
if err := fsys.Copy(src, dst); err != nil {
|
||||
t.Fatalf("Failed to copy %q to %q: %v", src, dst, err)
|
||||
}
|
||||
f, err := fs.GetFile(dst)
|
||||
f, err := fsys.GetFile(dst)
|
||||
//nolint
|
||||
defer f.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Missing copied file %q: %v", dst, err)
|
||||
@@ -502,11 +569,11 @@ func TestFileSystemList(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
prefix string
|
||||
@@ -537,7 +604,7 @@ func TestFileSystemList(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
objs, err := fs.List(s.prefix)
|
||||
objs, err := fsys.List(s.prefix)
|
||||
if err != nil {
|
||||
t.Fatalf("[%s] %v", s.prefix, err)
|
||||
}
|
||||
@@ -563,17 +630,17 @@ func TestFileSystemServeSingleRange(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Add("Range", "bytes=0-20")
|
||||
|
||||
if err := fs.Serve(res, req, "image.png", "image.png"); err != nil {
|
||||
if err := fsys.Serve(res, req, "image.png", "image.png"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -597,17 +664,17 @@ func TestFileSystemServeMultiRange(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Add("Range", "bytes=0-20, 25-30")
|
||||
|
||||
if err := fs.Serve(res, req, "image.png", "image.png"); err != nil {
|
||||
if err := fsys.Serve(res, req, "image.png", "image.png"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -626,11 +693,11 @@ func TestFileSystemCreateThumb(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
fsys, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
defer fsys.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
file string
|
||||
@@ -651,7 +718,7 @@ func TestFileSystemCreateThumb(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
err := fs.CreateThumb(scenario.file, scenario.thumb, "100x100")
|
||||
err := fsys.CreateThumb(scenario.file, scenario.thumb, "100x100")
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
@@ -663,7 +730,7 @@ func TestFileSystemCreateThumb(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
if exists, _ := fs.Exists(scenario.thumb); !exists {
|
||||
if exists, _ := fsys.Exists(scenario.thumb); !exists {
|
||||
t.Errorf("(%d) Couldn't find %q thumb", i, scenario.thumb)
|
||||
}
|
||||
}
|
||||
@@ -677,6 +744,10 @@ func createTestDir(t *testing.T) string {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "empty"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "test"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@
|
||||
// (V1) *s3.PutObjectInput; (V2) *s3v2.PutObjectInput, when Options.Method == http.MethodPut, or
|
||||
// (V1) *s3.DeleteObjectInput; (V2) [not supported] when Options.Method == http.MethodDelete
|
||||
|
||||
package filesystem
|
||||
package s3lite
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -82,7 +82,6 @@ import (
|
||||
"strings"
|
||||
|
||||
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsv2cfg "github.com/aws/aws-sdk-go-v2/config"
|
||||
s3managerv2 "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||
s3v2 "github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
typesv2 "github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
@@ -244,116 +243,8 @@ func URLUnescape(s string) string {
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// UseV2 returns true iff the URL parameters indicate that the provider
|
||||
// should use the AWS SDK v2.
|
||||
//
|
||||
// "awssdk=v1" will force V1.
|
||||
// "awssdk=v2" will force V2.
|
||||
// No "awssdk" parameter (or any other value) will return the default, currently V1.
|
||||
// Note that the default may change in the future.
|
||||
func UseV2(q url.Values) bool {
|
||||
if values, ok := q["awssdk"]; ok {
|
||||
if values[0] == "v2" || values[0] == "V2" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NewDefaultV2Config returns a aws.Config for AWS SDK v2, using the default options.
|
||||
func NewDefaultV2Config(ctx context.Context) (awsv2.Config, error) {
|
||||
return awsv2cfg.LoadDefaultConfig(ctx)
|
||||
}
|
||||
|
||||
// V2ConfigFromURLParams returns an aws.Config for AWS SDK v2 initialized based on the URL
|
||||
// parameters in q. It is intended to be used by URLOpeners for AWS services if
|
||||
// UseV2 returns true.
|
||||
//
|
||||
// https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#Config
|
||||
//
|
||||
// It returns an error if q contains any unknown query parameters; callers
|
||||
// should remove any query parameters they know about from q before calling
|
||||
// V2ConfigFromURLParams.
|
||||
//
|
||||
// The following query options are supported:
|
||||
// - region: The AWS region for requests; sets WithRegion.
|
||||
// - profile: The shared config profile to use; sets SharedConfigProfile.
|
||||
// - endpoint: The AWS service endpoint to send HTTP request.
|
||||
func V2ConfigFromURLParams(ctx context.Context, q url.Values) (awsv2.Config, error) {
|
||||
var opts []func(*awsv2cfg.LoadOptions) error
|
||||
for param, values := range q {
|
||||
value := values[0]
|
||||
switch param {
|
||||
case "region":
|
||||
opts = append(opts, awsv2cfg.WithRegion(value))
|
||||
case "endpoint":
|
||||
customResolver := awsv2.EndpointResolverWithOptionsFunc(
|
||||
func(service, region string, options ...interface{}) (awsv2.Endpoint, error) {
|
||||
return awsv2.Endpoint{
|
||||
PartitionID: "aws",
|
||||
URL: value,
|
||||
SigningRegion: region,
|
||||
}, nil
|
||||
})
|
||||
opts = append(opts, awsv2cfg.WithEndpointResolverWithOptions(customResolver))
|
||||
case "profile":
|
||||
opts = append(opts, awsv2cfg.WithSharedConfigProfile(value))
|
||||
case "awssdk":
|
||||
// ignore, should be handled before this
|
||||
default:
|
||||
return awsv2.Config{}, fmt.Errorf("unknown query parameter %q", param)
|
||||
}
|
||||
}
|
||||
return awsv2cfg.LoadDefaultConfig(ctx, opts...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const defaultPageSize = 1000
|
||||
|
||||
func init() {
|
||||
blob.DefaultURLMux().RegisterBucket(Scheme, new(urlSessionOpener))
|
||||
}
|
||||
|
||||
type urlSessionOpener struct{}
|
||||
|
||||
func (o *urlSessionOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
|
||||
opener := &URLOpener{UseV2: true}
|
||||
return opener.OpenBucketURL(ctx, u)
|
||||
}
|
||||
|
||||
// Scheme is the URL scheme s3blob registers its URLOpener under on
|
||||
// blob.DefaultMux.
|
||||
const Scheme = "s3"
|
||||
|
||||
// URLOpener opens S3 URLs like "s3://mybucket".
|
||||
//
|
||||
// The URL host is used as the bucket name.
|
||||
//
|
||||
// Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2,
|
||||
// or anything else to accept the default.
|
||||
//
|
||||
// For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters
|
||||
// for overriding the aws.Session from the URL.
|
||||
// For V2, see gocloud.dev/aws/V2ConfigFromURLParams.
|
||||
type URLOpener struct {
|
||||
// UseV2 indicates whether the AWS SDK V2 should be used.
|
||||
UseV2 bool
|
||||
|
||||
// Options specifies the options to pass to OpenBucket.
|
||||
Options Options
|
||||
}
|
||||
|
||||
// OpenBucketURL opens a blob.Bucket based on u.
|
||||
func (o *URLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
|
||||
cfg, err := V2ConfigFromURLParams(ctx, u.Query())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open bucket %v: %v", u, err)
|
||||
}
|
||||
clientV2 := s3v2.NewFromConfig(cfg)
|
||||
return OpenBucketV2(ctx, clientV2, u.Host, &o.Options)
|
||||
}
|
||||
|
||||
// Options sets options for constructing a *blob.Bucket backed by fileblob.
|
||||
type Options struct {
|
||||
// UseLegacyList forces the use of ListObjects instead of ListObjectsV2.
|
||||
@@ -676,64 +567,6 @@ func (b *bucket) listObjectsV2(ctx context.Context, in *s3v2.ListObjectsV2Input,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// func (b *bucket) listObjects(ctx context.Context, in *s3.ListObjectsV2Input, opts *driver.ListOptions) (*s3.ListObjectsV2Output, error) {
|
||||
// if !b.useLegacyList {
|
||||
// if opts.BeforeList != nil {
|
||||
// asFunc := func(i interface{}) bool {
|
||||
// if p, ok := i.(**s3.ListObjectsV2Input); ok {
|
||||
// *p = in
|
||||
// return true
|
||||
// }
|
||||
// return false
|
||||
// }
|
||||
// if err := opts.BeforeList(asFunc); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// }
|
||||
// return b.client.ListObjectsV2WithContext(ctx, in)
|
||||
// }
|
||||
|
||||
// // Use the legacy ListObjects request.
|
||||
// legacyIn := &s3.ListObjectsInput{
|
||||
// Bucket: in.Bucket,
|
||||
// Delimiter: in.Delimiter,
|
||||
// EncodingType: in.EncodingType,
|
||||
// Marker: in.ContinuationToken,
|
||||
// MaxKeys: in.MaxKeys,
|
||||
// Prefix: in.Prefix,
|
||||
// RequestPayer: in.RequestPayer,
|
||||
// }
|
||||
// if opts.BeforeList != nil {
|
||||
// asFunc := func(i interface{}) bool {
|
||||
// p, ok := i.(**s3.ListObjectsInput)
|
||||
// if !ok {
|
||||
// return false
|
||||
// }
|
||||
// *p = legacyIn
|
||||
// return true
|
||||
// }
|
||||
// if err := opts.BeforeList(asFunc); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// }
|
||||
// legacyResp, err := b.client.ListObjectsWithContext(ctx, legacyIn)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// var nextContinuationToken *string
|
||||
// if legacyResp.NextMarker != nil {
|
||||
// nextContinuationToken = legacyResp.NextMarker
|
||||
// } else if awsv2.ToBool(legacyResp.IsTruncated) {
|
||||
// nextContinuationToken = awsv2.String(awsv2.ToString(legacyResp.Contents[len(legacyResp.Contents)-1].Key))
|
||||
// }
|
||||
// return &s3.ListObjectsV2Output{
|
||||
// CommonPrefixes: legacyResp.CommonPrefixes,
|
||||
// Contents: legacyResp.Contents,
|
||||
// NextContinuationToken: nextContinuationToken,
|
||||
// }, nil
|
||||
// }
|
||||
|
||||
// As implements driver.As.
|
||||
func (b *bucket) As(i interface{}) bool {
|
||||
p, ok := i.(**s3v2.Client)
|
||||
@@ -0,0 +1,45 @@
|
||||
package hook
|
||||
|
||||
// Resolver defines a common interface for a Hook event (see [Event]).
|
||||
type Resolver interface {
|
||||
// Next triggers the next handler in the hook's chain (if any).
|
||||
Next() error
|
||||
|
||||
// note: kept only for the generic interface; may get removed in the future
|
||||
nextFunc() func() error
|
||||
setNextFunc(f func() error)
|
||||
}
|
||||
|
||||
var _ Resolver = (*Event)(nil)
|
||||
|
||||
// Event implements [Resolver] and it is intended to be used as a base
|
||||
// Hook event that you can embed in your custom typed event structs.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type CustomEvent struct {
|
||||
// hook.Event
|
||||
//
|
||||
// SomeField int
|
||||
// }
|
||||
type Event struct {
|
||||
next func() error
|
||||
}
|
||||
|
||||
// Next calls the next hook handler.
|
||||
func (e *Event) Next() error {
|
||||
if e.next != nil {
|
||||
return e.next()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// nextFunc returns the function that Next calls.
|
||||
func (e *Event) nextFunc() func() error {
|
||||
return e.next
|
||||
}
|
||||
|
||||
// setNextFunc sets the function that Next calls.
|
||||
func (e *Event) setNextFunc(f func() error) {
|
||||
e.next = f
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package hook
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEventNext(t *testing.T) {
|
||||
calls := 0
|
||||
|
||||
e := Event{}
|
||||
|
||||
if e.nextFunc() != nil {
|
||||
t.Fatalf("Expected nextFunc to be nil")
|
||||
}
|
||||
|
||||
e.setNextFunc(func() error {
|
||||
calls++
|
||||
return nil
|
||||
})
|
||||
|
||||
if e.nextFunc() == nil {
|
||||
t.Fatalf("Expected nextFunc to be non-nil")
|
||||
}
|
||||
|
||||
e.Next()
|
||||
e.Next()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("Expected %d calls, got %d", 2, calls)
|
||||
}
|
||||
}
|
||||
+138
-85
@@ -1,126 +1,179 @@
|
||||
package hook
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
var StopPropagation = errors.New("Event hook propagation stopped")
|
||||
// HandlerFunc defines a hook handler function.
|
||||
type HandlerFunc[T Resolver] func(e T) error
|
||||
|
||||
// Handler defines a hook handler function.
|
||||
type Handler[T any] func(e T) error
|
||||
// Handler defines a single Hook handler.
|
||||
// Multiple handlers can share the same id.
|
||||
// If Id is not explicitly set it will be autogenerated by Hook.Add and Hook.AddHandler.
|
||||
type Handler[T Resolver] struct {
|
||||
// Func defines the handler function to execute.
|
||||
//
|
||||
// Note that users need to call e.Next() in order to proceed with
|
||||
// the execution of the hook chain.
|
||||
Func HandlerFunc[T]
|
||||
|
||||
// handlerPair defines a pair of string id and Handler.
|
||||
type handlerPair[T any] struct {
|
||||
id string
|
||||
handler Handler[T]
|
||||
// Id is the unique identifier of the handler.
|
||||
//
|
||||
// It could be used later to remove the handler from a hook via [Hook.Remove].
|
||||
//
|
||||
// If missing, an autogenerated value will be assigned when adding
|
||||
// the handler to a hook.
|
||||
Id string
|
||||
|
||||
// Priority allows changing the default exec priority of the handler within a hook.
|
||||
//
|
||||
// If 0, the handler will be executed in the same order it was registered.
|
||||
Priority int
|
||||
}
|
||||
|
||||
// Hook defines a concurrent safe structure for handling event hooks
|
||||
// (aka. callbacks propagation).
|
||||
type Hook[T any] struct {
|
||||
mux sync.RWMutex
|
||||
handlers []*handlerPair[T]
|
||||
}
|
||||
|
||||
// PreAdd registers a new handler to the hook by prepending it to the existing queue.
|
||||
// Hook defines a generic concurrent safe structure for managing event hooks.
|
||||
//
|
||||
// Returns an autogenerated hook id that could be used later to remove the hook with Hook.Remove(id).
|
||||
func (h *Hook[T]) PreAdd(fn Handler[T]) string {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
id := generateHookId()
|
||||
|
||||
// minimize allocations by shifting the slice
|
||||
h.handlers = append(h.handlers, nil)
|
||||
copy(h.handlers[1:], h.handlers)
|
||||
h.handlers[0] = &handlerPair[T]{id, fn}
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
// Add registers a new handler to the hook by appending it to the existing queue.
|
||||
// When using custom a event it must embed the base [hook.Event].
|
||||
//
|
||||
// Returns an autogenerated hook id that could be used later to remove the hook with Hook.Remove(id).
|
||||
func (h *Hook[T]) Add(fn Handler[T]) string {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
|
||||
id := generateHookId()
|
||||
|
||||
h.handlers = append(h.handlers, &handlerPair[T]{id, fn})
|
||||
|
||||
return id
|
||||
// Example:
|
||||
//
|
||||
// type CustomEvent struct {
|
||||
// hook.Event
|
||||
// SomeField int
|
||||
// }
|
||||
//
|
||||
// h := Hook[*CustomEvent]{}
|
||||
//
|
||||
// h.BindFunc(func(e *CustomEvent) error {
|
||||
// println(e.SomeField)
|
||||
//
|
||||
// return e.Next()
|
||||
// })
|
||||
//
|
||||
// h.Trigger(&CustomEvent{ SomeField: 123 })
|
||||
type Hook[T Resolver] struct {
|
||||
handlers []*Handler[T]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Remove removes a single hook handler by its id.
|
||||
func (h *Hook[T]) Remove(id string) {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
// Bind registers the provided handler to the current hooks queue.
|
||||
//
|
||||
// If handler.Id is empty it is updated with autogenerated value.
|
||||
//
|
||||
// If a handler from the current hook list has Id matching handler.Id
|
||||
// then the old handler is replaced with the new one.
|
||||
func (h *Hook[T]) Bind(handler *Handler[T]) string {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
var exists bool
|
||||
|
||||
if handler.Id == "" {
|
||||
handler.Id = generateHookId()
|
||||
|
||||
// ensure that it doesn't exist
|
||||
DUPLICATE_CHECK:
|
||||
for _, existing := range h.handlers {
|
||||
if existing.Id == handler.Id {
|
||||
handler.Id = generateHookId()
|
||||
goto DUPLICATE_CHECK
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// replace existing
|
||||
for i, existing := range h.handlers {
|
||||
if existing.Id == handler.Id {
|
||||
h.handlers[i] = handler
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// append new
|
||||
if !exists {
|
||||
h.handlers = append(h.handlers, handler)
|
||||
}
|
||||
|
||||
// sort handlers by Priority, preserving the original order of equal items
|
||||
sort.SliceStable(h.handlers, func(i, j int) bool {
|
||||
return h.handlers[i].Priority < h.handlers[j].Priority
|
||||
})
|
||||
|
||||
return handler.Id
|
||||
}
|
||||
|
||||
// BindFunc is similar to Bind but registers a new handler from just the provided function.
|
||||
//
|
||||
// The registered handler is added with a default 0 priority and the id will be autogenerated.
|
||||
//
|
||||
// If you want to register a handler with custom priority or id use the [Hook.Bind] method.
|
||||
func (h *Hook[T]) BindFunc(fn HandlerFunc[T]) string {
|
||||
return h.Bind(&Handler[T]{Func: fn})
|
||||
}
|
||||
|
||||
// Unbind removes a single hook handler by its id.
|
||||
func (h *Hook[T]) Unbind(id string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for i := len(h.handlers) - 1; i >= 0; i-- {
|
||||
if h.handlers[i].id == id {
|
||||
if h.handlers[i].Id == id {
|
||||
h.handlers = append(h.handlers[:i], h.handlers[i+1:]...)
|
||||
return
|
||||
break // for now stop on the first occurrence since we don't allow handlers with duplicated ids
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveAll removes all registered handlers.
|
||||
func (h *Hook[T]) RemoveAll() {
|
||||
h.mux.Lock()
|
||||
defer h.mux.Unlock()
|
||||
// UnbindAll removes all registered handlers.
|
||||
func (h *Hook[T]) UnbindAll() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.handlers = nil
|
||||
}
|
||||
|
||||
// Length returns to total number of registered hook handlers.
|
||||
func (h *Hook[T]) Length() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
return len(h.handlers)
|
||||
}
|
||||
|
||||
// Trigger executes all registered hook handlers one by one
|
||||
// with the specified `data` as an argument.
|
||||
// with the specified event as an argument.
|
||||
//
|
||||
// Optionally, this method allows also to register additional one off
|
||||
// handlers that will be temporary appended to the handlers queue.
|
||||
//
|
||||
// The execution stops when:
|
||||
// - hook.StopPropagation is returned in one of the handlers
|
||||
// - any non-nil error is returned in one of the handlers
|
||||
func (h *Hook[T]) Trigger(data T, oneOffHandlers ...Handler[T]) error {
|
||||
h.mux.RLock()
|
||||
// NB! Each hook handler must call event.Next() in order the hook chain to proceed.
|
||||
func (h *Hook[T]) Trigger(event T, oneOffHandlers ...HandlerFunc[T]) error {
|
||||
h.mu.RLock()
|
||||
handlers := make([]HandlerFunc[T], 0, len(h.handlers)+len(oneOffHandlers))
|
||||
for _, handler := range h.handlers {
|
||||
handlers = append(handlers, handler.Func)
|
||||
}
|
||||
handlers = append(handlers, oneOffHandlers...)
|
||||
h.mu.RUnlock()
|
||||
|
||||
handlers := make([]*handlerPair[T], 0, len(h.handlers)+len(oneOffHandlers))
|
||||
handlers = append(handlers, h.handlers...)
|
||||
event.setNextFunc(nil) // reset in case the event is being reused
|
||||
|
||||
// append the one off handlers
|
||||
for i, oneOff := range oneOffHandlers {
|
||||
handlers = append(handlers, &handlerPair[T]{
|
||||
id: fmt.Sprintf("@%d", i),
|
||||
handler: oneOff,
|
||||
for i := len(handlers) - 1; i >= 0; i-- {
|
||||
i := i
|
||||
old := event.nextFunc()
|
||||
event.setNextFunc(func() error {
|
||||
event.setNextFunc(old)
|
||||
return handlers[i](event)
|
||||
})
|
||||
}
|
||||
|
||||
// unlock is not deferred to avoid deadlocks in case Trigger
|
||||
// is called recursively by the handlers
|
||||
h.mux.RUnlock()
|
||||
|
||||
for _, item := range handlers {
|
||||
err := item.handler(data)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if errors.Is(err, StopPropagation) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return event.Next()
|
||||
}
|
||||
|
||||
func generateHookId() string {
|
||||
return security.PseudorandomString(8)
|
||||
return security.PseudorandomString(20)
|
||||
}
|
||||
|
||||
+106
-124
@@ -5,175 +5,157 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHookAddAndPreAdd(t *testing.T) {
|
||||
h := Hook[int]{}
|
||||
func TestHookAddHandlerAndAdd(t *testing.T) {
|
||||
calls := ""
|
||||
|
||||
if total := len(h.handlers); total != 0 {
|
||||
t.Fatalf("Expected no handlers, found %d", total)
|
||||
h := Hook[*Event]{}
|
||||
|
||||
h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
|
||||
h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
|
||||
h3Id := h.BindFunc(func(e *Event) error { calls += "3"; return e.Next() })
|
||||
h.Bind(&Handler[*Event]{
|
||||
Id: h3Id, // should replace 3
|
||||
Func: func(e *Event) error { calls += "3'"; return e.Next() },
|
||||
})
|
||||
h.Bind(&Handler[*Event]{
|
||||
Func: func(e *Event) error { calls += "4"; return e.Next() },
|
||||
Priority: -2,
|
||||
})
|
||||
h.Bind(&Handler[*Event]{
|
||||
Func: func(e *Event) error { calls += "5"; return e.Next() },
|
||||
Priority: -1,
|
||||
})
|
||||
h.Bind(&Handler[*Event]{
|
||||
Func: func(e *Event) error { calls += "6"; return e.Next() },
|
||||
})
|
||||
h.Bind(&Handler[*Event]{
|
||||
Func: func(e *Event) error { calls += "7"; e.Next(); return errors.New("test") }, // error shouldn't stop the chain
|
||||
})
|
||||
|
||||
h.Trigger(
|
||||
&Event{},
|
||||
func(e *Event) error { calls += "8"; return e.Next() },
|
||||
func(e *Event) error { calls += "9"; return nil }, // skip next
|
||||
func(e *Event) error { calls += "10"; return e.Next() },
|
||||
)
|
||||
|
||||
if total := len(h.handlers); total != 7 {
|
||||
t.Fatalf("Expected %d handlers, found %d", 7, total)
|
||||
}
|
||||
|
||||
triggerSequence := ""
|
||||
expectedCalls := "45123'6789"
|
||||
|
||||
f1 := func(data int) error { triggerSequence += "f1"; return nil }
|
||||
f2 := func(data int) error { triggerSequence += "f2"; return nil }
|
||||
f3 := func(data int) error { triggerSequence += "f3"; return nil }
|
||||
f4 := func(data int) error { triggerSequence += "f4"; return nil }
|
||||
|
||||
h.Add(f1)
|
||||
h.Add(f2)
|
||||
h.PreAdd(f3)
|
||||
h.PreAdd(f4)
|
||||
h.Trigger(1)
|
||||
|
||||
if total := len(h.handlers); total != 4 {
|
||||
t.Fatalf("Expected %d handlers, found %d", 4, total)
|
||||
}
|
||||
|
||||
expectedTriggerSequence := "f4f3f1f2"
|
||||
|
||||
if triggerSequence != expectedTriggerSequence {
|
||||
t.Fatalf("Expected trigger sequence %s, got %s", expectedTriggerSequence, triggerSequence)
|
||||
if calls != expectedCalls {
|
||||
t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRemove(t *testing.T) {
|
||||
h := Hook[int]{}
|
||||
func TestHookLength(t *testing.T) {
|
||||
h := Hook[*Event]{}
|
||||
|
||||
h1Called := false
|
||||
h2Called := false
|
||||
if l := h.Length(); l != 0 {
|
||||
t.Fatalf("Expected 0 hook handlers, got %d", l)
|
||||
}
|
||||
|
||||
id1 := h.Add(func(data int) error { h1Called = true; return nil })
|
||||
h.Add(func(data int) error { h2Called = true; return nil })
|
||||
h.BindFunc(func(e *Event) error { return e.Next() })
|
||||
h.BindFunc(func(e *Event) error { return e.Next() })
|
||||
|
||||
h.Remove("missing") // should do nothing and not panic
|
||||
if l := h.Length(); l != 2 {
|
||||
t.Fatalf("Expected 2 hook handlers, got %d", l)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookUnbind(t *testing.T) {
|
||||
h := Hook[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
id1 := h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
|
||||
h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
|
||||
h.Bind(&Handler[*Event]{
|
||||
Func: func(e *Event) error { calls += "3"; return e.Next() },
|
||||
})
|
||||
|
||||
h.Unbind("missing") // should do nothing and not panic
|
||||
|
||||
if total := len(h.handlers); total != 3 {
|
||||
t.Fatalf("Expected %d handlers, got %d", 3, total)
|
||||
}
|
||||
|
||||
h.Unbind(id1)
|
||||
|
||||
if total := len(h.handlers); total != 2 {
|
||||
t.Fatalf("Expected %d handlers, got %d", 2, total)
|
||||
}
|
||||
|
||||
h.Remove(id1)
|
||||
|
||||
if total := len(h.handlers); total != 1 {
|
||||
t.Fatalf("Expected %d handlers, got %d", 1, total)
|
||||
}
|
||||
|
||||
if err := h.Trigger(1); err != nil {
|
||||
err := h.Trigger(&Event{}, func(e *Event) error { calls += "4"; return e.Next() })
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if h1Called {
|
||||
t.Fatalf("Expected hook 1 to be removed and not called")
|
||||
}
|
||||
expectedCalls := "234"
|
||||
|
||||
if !h2Called {
|
||||
t.Fatalf("Expected hook 2 to be called")
|
||||
if calls != expectedCalls {
|
||||
t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRemoveAll(t *testing.T) {
|
||||
h := Hook[int]{}
|
||||
func TestHookUnbindAll(t *testing.T) {
|
||||
h := Hook[*Event]{}
|
||||
|
||||
h.RemoveAll() // should do nothing and not panic
|
||||
h.UnbindAll() // should do nothing and not panic
|
||||
|
||||
h.Add(func(data int) error { return nil })
|
||||
h.Add(func(data int) error { return nil })
|
||||
h.BindFunc(func(e *Event) error { return nil })
|
||||
h.BindFunc(func(e *Event) error { return nil })
|
||||
|
||||
if total := len(h.handlers); total != 2 {
|
||||
t.Fatalf("Expected 2 handlers before RemoveAll, found %d", total)
|
||||
t.Fatalf("Expected %d handlers before UnbindAll, found %d", 2, total)
|
||||
}
|
||||
|
||||
h.RemoveAll()
|
||||
h.UnbindAll()
|
||||
|
||||
if total := len(h.handlers); total != 0 {
|
||||
t.Fatalf("Expected no handlers after RemoveAll, found %d", total)
|
||||
t.Fatalf("Expected no handlers after UnbindAll, found %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookTrigger(t *testing.T) {
|
||||
err1 := errors.New("demo")
|
||||
err2 := errors.New("demo")
|
||||
func TestHookTriggerErrorPropagation(t *testing.T) {
|
||||
err := errors.New("test")
|
||||
|
||||
scenarios := []struct {
|
||||
handlers []Handler[int]
|
||||
name string
|
||||
handlers []HandlerFunc[*Event]
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
[]Handler[int]{
|
||||
func(data int) error { return nil },
|
||||
func(data int) error { return nil },
|
||||
"without error",
|
||||
[]HandlerFunc[*Event]{
|
||||
func(e *Event) error { return e.Next() },
|
||||
func(e *Event) error { return e.Next() },
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
[]Handler[int]{
|
||||
func(data int) error { return nil },
|
||||
func(data int) error { return err1 },
|
||||
func(data int) error { return err2 },
|
||||
"with error",
|
||||
[]HandlerFunc[*Event]{
|
||||
func(e *Event) error { return e.Next() },
|
||||
func(e *Event) error { e.Next(); return err },
|
||||
func(e *Event) error { return e.Next() },
|
||||
},
|
||||
err1,
|
||||
err,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
h := Hook[int]{}
|
||||
for _, handler := range scenario.handlers {
|
||||
h.Add(handler)
|
||||
}
|
||||
result := h.Trigger(1)
|
||||
if result != scenario.expectedError {
|
||||
t.Fatalf("(%d) Expected %v, got %v", i, scenario.expectedError, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookTriggerStopPropagation(t *testing.T) {
|
||||
called1 := false
|
||||
f1 := func(data int) error { called1 = true; return nil }
|
||||
|
||||
called2 := false
|
||||
f2 := func(data int) error { called2 = true; return nil }
|
||||
|
||||
called3 := false
|
||||
f3 := func(data int) error { called3 = true; return nil }
|
||||
|
||||
called4 := false
|
||||
f4 := func(data int) error { called4 = true; return StopPropagation }
|
||||
|
||||
called5 := false
|
||||
f5 := func(data int) error { called5 = true; return nil }
|
||||
|
||||
called6 := false
|
||||
f6 := func(data int) error { called6 = true; return nil }
|
||||
|
||||
h := Hook[int]{}
|
||||
h.Add(f1)
|
||||
h.Add(f2)
|
||||
|
||||
result := h.Trigger(123, f3, f4, f5, f6)
|
||||
|
||||
if result != nil {
|
||||
t.Fatalf("Expected nil after StopPropagation, got %v", result)
|
||||
}
|
||||
|
||||
// ensure that the trigger handler were not persisted
|
||||
if total := len(h.handlers); total != 2 {
|
||||
t.Fatalf("Expected 2 handlers, found %d", total)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
called bool
|
||||
expected bool
|
||||
}{
|
||||
{called1, true},
|
||||
{called2, true},
|
||||
{called3, true},
|
||||
{called4, true}, // StopPropagation
|
||||
{called5, false},
|
||||
{called6, false},
|
||||
}
|
||||
for i, scenario := range scenarios {
|
||||
if scenario.called != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, scenario.called)
|
||||
}
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
h := Hook[*Event]{}
|
||||
for _, handler := range s.handlers {
|
||||
h.BindFunc(handler)
|
||||
}
|
||||
result := h.Trigger(&Event{})
|
||||
if result != s.expectedError {
|
||||
t.Fatalf("Expected %v, got %v", s.expectedError, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+23
-13
@@ -7,6 +7,8 @@ import (
|
||||
// Tagger defines an interface for event data structs that support tags/groups/categories/etc.
|
||||
// Usually used together with TaggedHook.
|
||||
type Tagger interface {
|
||||
Resolver
|
||||
|
||||
Tags() []string
|
||||
}
|
||||
|
||||
@@ -33,12 +35,14 @@ type TaggedHook[T Tagger] struct {
|
||||
|
||||
// CanTriggerOn checks if the current TaggedHook can be triggered with
|
||||
// the provided event data tags.
|
||||
func (h *TaggedHook[T]) CanTriggerOn(tags []string) bool {
|
||||
//
|
||||
// It returns always true if the hook doens't have any tags.
|
||||
func (h *TaggedHook[T]) CanTriggerOn(tagsToCheck []string) bool {
|
||||
if len(h.tags) == 0 {
|
||||
return true // match all
|
||||
}
|
||||
|
||||
for _, t := range tags {
|
||||
for _, t := range tagsToCheck {
|
||||
if list.ExistInSlice(t, h.tags) {
|
||||
return true
|
||||
}
|
||||
@@ -47,28 +51,34 @@ func (h *TaggedHook[T]) CanTriggerOn(tags []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// PreAdd registers a new handler to the hook by prepending it to the existing queue.
|
||||
// Bind registers the provided handler to the current hooks queue.
|
||||
//
|
||||
// The fn handler will be called only if the event data tags satisfy h.CanTriggerOn.
|
||||
func (h *TaggedHook[T]) PreAdd(fn Handler[T]) string {
|
||||
return h.mainHook.PreAdd(func(e T) error {
|
||||
// It is similar to [Hook.Bind] with the difference that the handler
|
||||
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
|
||||
func (h *TaggedHook[T]) Bind(handler *Handler[T]) string {
|
||||
fn := handler.Func
|
||||
|
||||
handler.Func = func(e T) error {
|
||||
if h.CanTriggerOn(e.Tags()) {
|
||||
return fn(e)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
return h.mainHook.Bind(handler)
|
||||
}
|
||||
|
||||
// Add registers a new handler to the hook by appending it to the existing queue.
|
||||
// BindFunc registers a new handler with the specified function.
|
||||
//
|
||||
// The fn handler will be called only if the event data tags satisfy h.CanTriggerOn.
|
||||
func (h *TaggedHook[T]) Add(fn Handler[T]) string {
|
||||
return h.mainHook.Add(func(e T) error {
|
||||
// It is similar to [Hook.Bind] with the difference that the handler
|
||||
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
|
||||
func (h *TaggedHook[T]) BindFunc(fn HandlerFunc[T]) string {
|
||||
return h.mainHook.BindFunc(func(e T) error {
|
||||
if h.CanTriggerOn(e.Tags()) {
|
||||
return fn(e)
|
||||
}
|
||||
|
||||
return nil
|
||||
return e.Next()
|
||||
})
|
||||
}
|
||||
|
||||
+43
-28
@@ -1,69 +1,84 @@
|
||||
package hook
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockTagsData struct {
|
||||
type mockTagsEvent struct {
|
||||
Event
|
||||
tags []string
|
||||
}
|
||||
|
||||
func (m mockTagsData) Tags() []string {
|
||||
func (m mockTagsEvent) Tags() []string {
|
||||
return m.tags
|
||||
}
|
||||
|
||||
func TestTaggedHook(t *testing.T) {
|
||||
triggerSequence := ""
|
||||
calls := ""
|
||||
|
||||
base := &Hook[mockTagsData]{}
|
||||
base.Add(func(data mockTagsData) error { triggerSequence += "f0"; return nil })
|
||||
base := &Hook[*mockTagsEvent]{}
|
||||
base.BindFunc(func(e *mockTagsEvent) error { calls += "f0"; return e.Next() })
|
||||
|
||||
hA := NewTaggedHook(base)
|
||||
hA.Add(func(data mockTagsData) error { triggerSequence += "a1"; return nil })
|
||||
hA.PreAdd(func(data mockTagsData) error { triggerSequence += "a2"; return nil })
|
||||
hA.BindFunc(func(e *mockTagsEvent) error { calls += "a1"; return e.Next() })
|
||||
hA.Bind(&Handler[*mockTagsEvent]{
|
||||
Func: func(e *mockTagsEvent) error { calls += "a2"; return e.Next() },
|
||||
Priority: -1,
|
||||
})
|
||||
|
||||
hB := NewTaggedHook(base, "b1", "b2")
|
||||
hB.Add(func(data mockTagsData) error { triggerSequence += "b1"; return nil })
|
||||
hB.PreAdd(func(data mockTagsData) error { triggerSequence += "b2"; return nil })
|
||||
hB.BindFunc(func(e *mockTagsEvent) error { calls += "b1"; return e.Next() })
|
||||
hB.Bind(&Handler[*mockTagsEvent]{
|
||||
Func: func(e *mockTagsEvent) error { calls += "b2"; return e.Next() },
|
||||
Priority: -2,
|
||||
})
|
||||
|
||||
hC := NewTaggedHook(base, "c1", "c2")
|
||||
hC.Add(func(data mockTagsData) error { triggerSequence += "c1"; return nil })
|
||||
hC.PreAdd(func(data mockTagsData) error { triggerSequence += "c2"; return nil })
|
||||
hC.BindFunc(func(e *mockTagsEvent) error { calls += "c1"; return e.Next() })
|
||||
hC.Bind(&Handler[*mockTagsEvent]{
|
||||
Func: func(e *mockTagsEvent) error { calls += "c2"; return e.Next() },
|
||||
Priority: -3,
|
||||
})
|
||||
|
||||
scenarios := []struct {
|
||||
data mockTagsData
|
||||
expectedSequence string
|
||||
event *mockTagsEvent
|
||||
expectedCalls string
|
||||
}{
|
||||
{
|
||||
mockTagsData{},
|
||||
&mockTagsEvent{},
|
||||
"a2f0a1",
|
||||
},
|
||||
{
|
||||
mockTagsData{[]string{"missing"}},
|
||||
&mockTagsEvent{tags: []string{"missing"}},
|
||||
"a2f0a1",
|
||||
},
|
||||
{
|
||||
mockTagsData{[]string{"b2"}},
|
||||
&mockTagsEvent{tags: []string{"b2"}},
|
||||
"b2a2f0a1b1",
|
||||
},
|
||||
{
|
||||
mockTagsData{[]string{"c1"}},
|
||||
&mockTagsEvent{tags: []string{"c1"}},
|
||||
"c2a2f0a1c1",
|
||||
},
|
||||
{
|
||||
mockTagsData{[]string{"b1", "c2"}},
|
||||
&mockTagsEvent{tags: []string{"b1", "c2"}},
|
||||
"c2b2a2f0a1b1c1",
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
triggerSequence = "" // reset
|
||||
for _, s := range scenarios {
|
||||
t.Run(strings.Join(s.event.tags, "_"), func(t *testing.T) {
|
||||
calls = "" // reset
|
||||
|
||||
err := hA.Trigger(s.data)
|
||||
if err != nil {
|
||||
t.Fatalf("[%d] Unexpected trigger error: %v", i, err)
|
||||
}
|
||||
err := base.Trigger(s.event)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected trigger error: %v", err)
|
||||
}
|
||||
|
||||
if triggerSequence != s.expectedSequence {
|
||||
t.Fatalf("[%d] Expected trigger sequence %s, got %s", i, s.expectedSequence, triggerSequence)
|
||||
}
|
||||
if calls != s.expectedCalls {
|
||||
t.Fatalf("Expected calls sequence %q, got %q", s.expectedCalls, calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package inflector_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
@@ -18,10 +19,13 @@ func TestUcFirst(t *testing.T) {
|
||||
{"test test2", "Test test2"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.UcFirst(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
|
||||
result := inflector.UcFirst(s.val)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,10 +46,13 @@ func TestColumnify(t *testing.T) {
|
||||
{"test1--test2", "test1--test2"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.Columnify(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
|
||||
result := inflector.Columnify(s.val)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,10 +74,13 @@ func TestSentenize(t *testing.T) {
|
||||
{"hello world?", "Hello world?"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.Sentenize(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
|
||||
result := inflector.Sentenize(s.val)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,21 +99,19 @@ func TestSanitize(t *testing.T) {
|
||||
{"abcABC", `[A-Z`, "", true}, // invalid pattern
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := inflector.Sanitize(scenario.val, scenario.pattern)
|
||||
hasErr := err != nil
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
|
||||
result, err := inflector.Sanitize(s.val, s.pattern)
|
||||
hasErr := err != nil
|
||||
|
||||
if scenario.expectErr != hasErr {
|
||||
if scenario.expectErr {
|
||||
t.Errorf("(%d) Expected error, got nil", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Didn't expect error, got", err)
|
||||
if s.expectErr != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectErr, hasErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,9 +134,12 @@ func TestSnakecase(t *testing.T) {
|
||||
{"testABR", "test_abr"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
if result := inflector.Snakecase(scenario.val); result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
|
||||
result := inflector.Snakecase(s.val)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+26
-3
@@ -39,7 +39,7 @@ func ExistInSlice[T comparable](item T, list []T) bool {
|
||||
// ExistInSliceWithRegex checks whether a string exists in a slice
|
||||
// either by direct match, or by a regular expression (eg. `^\w+$`).
|
||||
//
|
||||
// _Note: Only list items starting with '^' and ending with '$' are treated as regular expressions!_
|
||||
// Note: Only list items starting with '^' and ending with '$' are treated as regular expressions!
|
||||
func ExistInSliceWithRegex(str string, list []string) bool {
|
||||
for _, field := range list {
|
||||
isRegex := strings.HasPrefix(field, "^") && strings.HasSuffix(field, "$")
|
||||
@@ -64,7 +64,7 @@ func ExistInSliceWithRegex(str string, list []string) bool {
|
||||
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
|
||||
//
|
||||
// @todo consider replacing with TTL or LRU type cache
|
||||
cachedPatterns.SetIfLessThanLimit(field, pattern, 5000)
|
||||
cachedPatterns.SetIfLessThanLimit(field, pattern, 500)
|
||||
}
|
||||
|
||||
if pattern != nil && pattern.MatchString(str) {
|
||||
@@ -129,7 +129,7 @@ func ToUniqueStringSlice(value any) (result []string) {
|
||||
// just add the string as single array element
|
||||
result = append(result, val)
|
||||
}
|
||||
case json.Marshaler: // eg. JsonArray
|
||||
case json.Marshaler: // eg. JSONArray
|
||||
raw, _ := val.MarshalJSON()
|
||||
_ = json.Unmarshal(raw, &result)
|
||||
default:
|
||||
@@ -138,3 +138,26 @@ func ToUniqueStringSlice(value any) (result []string) {
|
||||
|
||||
return NonzeroUniques(result)
|
||||
}
|
||||
|
||||
// ToChunks splits list into chunks.
|
||||
//
|
||||
// Zero or negative chunkSize argument is normalized to 1.
|
||||
//
|
||||
// See https://go.dev/wiki/SliceTricks#batching-with-minimal-allocation.
|
||||
func ToChunks[T any](list []T, chunkSize int) [][]T {
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = 1
|
||||
}
|
||||
|
||||
chunks := make([][]T, 0, (len(list)+chunkSize-1)/chunkSize)
|
||||
|
||||
if len(list) == 0 {
|
||||
return chunks
|
||||
}
|
||||
|
||||
for chunkSize < len(list) {
|
||||
list, chunks = list[chunkSize:], append(chunks, list[0:chunkSize:chunkSize])
|
||||
}
|
||||
|
||||
return append(chunks, list)
|
||||
}
|
||||
|
||||
+111
-74
@@ -2,6 +2,7 @@ package list_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
@@ -42,18 +43,20 @@ func TestSubtractSliceString(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := list.SubtractSlice(s.base, s.subtract)
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
|
||||
result := list.SubtractSlice(s.base, s.subtract)
|
||||
|
||||
raw, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("(%d) Failed to serialize: %v", i, err)
|
||||
}
|
||||
raw, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize: %v", err)
|
||||
}
|
||||
|
||||
strResult := string(raw)
|
||||
strResult := string(raw)
|
||||
|
||||
if strResult != s.expected {
|
||||
t.Fatalf("(%d) Expected %v, got %v", i, s.expected, strResult)
|
||||
}
|
||||
if strResult != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, strResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,18 +94,20 @@ func TestSubtractSliceInt(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := list.SubtractSlice(s.base, s.subtract)
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
|
||||
result := list.SubtractSlice(s.base, s.subtract)
|
||||
|
||||
raw, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("(%d) Failed to serialize: %v", i, err)
|
||||
}
|
||||
raw, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize: %v", err)
|
||||
}
|
||||
|
||||
strResult := string(raw)
|
||||
strResult := string(raw)
|
||||
|
||||
if strResult != s.expected {
|
||||
t.Fatalf("(%d) Expected %v, got %v", i, s.expected, strResult)
|
||||
}
|
||||
if strResult != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, strResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,15 +125,13 @@ func TestExistInSliceString(t *testing.T) {
|
||||
{"test", []string{"1", "2", "test"}, true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ExistInSlice(scenario.item, scenario.list)
|
||||
if result != scenario.expected {
|
||||
if scenario.expected {
|
||||
t.Errorf("(%d) Expected to exist in the list", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Expected NOT to exist in the list", i)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
|
||||
result := list.ExistInSlice(s.item, s.list)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,15 +149,13 @@ func TestExistInSliceInt(t *testing.T) {
|
||||
{-1, []int{0, -1, -2, -3, -4}, true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ExistInSlice(scenario.item, scenario.list)
|
||||
if result != scenario.expected {
|
||||
if scenario.expected {
|
||||
t.Errorf("(%d) Expected to exist in the list", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Expected NOT to exist in the list", i)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%d", i, s.item), func(t *testing.T) {
|
||||
result := list.ExistInSlice(s.item, s.list)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,15 +178,13 @@ func TestExistInSliceWithRegex(t *testing.T) {
|
||||
{"!?@test", []string{`^\W+$`, "test"}, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ExistInSliceWithRegex(scenario.item, scenario.list)
|
||||
if result != scenario.expected {
|
||||
if scenario.expected {
|
||||
t.Errorf("(%d) Expected the string to exist in the list", i)
|
||||
} else {
|
||||
t.Errorf("(%d) Expected the string NOT to exist in the list", i)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
|
||||
result := list.ExistInSliceWithRegex(s.item, s.list)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,21 +195,23 @@ func TestToInterfaceSlice(t *testing.T) {
|
||||
{[]string{}},
|
||||
{[]string{""}},
|
||||
{[]string{"1", "test"}},
|
||||
{[]string{"test1", "test2", "test3"}},
|
||||
{[]string{"test1", "test1", "test2", "test3"}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ToInterfaceSlice(scenario.items)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
|
||||
result := list.ToInterfaceSlice(s.items)
|
||||
|
||||
if len(result) != len(scenario.items) {
|
||||
t.Errorf("(%d) Result list length doesn't match with the original list", i)
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != scenario.items[j] {
|
||||
t.Errorf("(%d:%d) Result list item should match with the original list item", i, j)
|
||||
if len(result) != len(s.items) {
|
||||
t.Fatalf("Expected length %d, got %d", len(s.items), len(result))
|
||||
}
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != s.items[j] {
|
||||
t.Fatalf("Result list item doesn't match with the original list item, got %v VS %v", v, s.items[j])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -225,18 +226,20 @@ func TestNonzeroUniquesString(t *testing.T) {
|
||||
{[]string{"test1", "", "test2", "Test2", "test1", "test3"}, []string{"test1", "test2", "Test2", "test3"}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.NonzeroUniques(scenario.items)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
|
||||
result := list.NonzeroUniques(s.items)
|
||||
|
||||
if len(result) != len(scenario.expected) {
|
||||
t.Errorf("(%d) Result list length doesn't match with the expected list", i)
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != scenario.expected[j] {
|
||||
t.Errorf("(%d:%d) Result list item should match with the expected list item", i, j)
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != s.expected[j] {
|
||||
t.Fatalf("Result list item doesn't match with the expected list item, got %v VS %v", v, s.expected[j])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,20 +257,54 @@ func TestToUniqueStringSlice(t *testing.T) {
|
||||
{[]any{0, 1, "test", ""}, []string{"0", "1", "test"}},
|
||||
{[]string{"test1", "test2", "test1"}, []string{"test1", "test2"}},
|
||||
{`["test1", "test2", "test2"]`, []string{"test1", "test2"}},
|
||||
{types.JsonArray[string]{"test1", "test2", "test1"}, []string{"test1", "test2"}},
|
||||
{types.JSONArray[string]{"test1", "test2", "test1"}, []string{"test1", "test2"}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := list.ToUniqueStringSlice(scenario.value)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.value), func(t *testing.T) {
|
||||
result := list.ToUniqueStringSlice(s.value)
|
||||
|
||||
if len(result) != len(scenario.expected) {
|
||||
t.Errorf("(%d) Result list length doesn't match with the expected list", i)
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != scenario.expected[j] {
|
||||
t.Errorf("(%d:%d) Result list item should match with the expected list item", i, j)
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
}
|
||||
|
||||
for j, v := range result {
|
||||
if v != s.expected[j] {
|
||||
t.Fatalf("Result list item doesn't match with the expected list item, got %v vs %v", v, s.expected[j])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToChunks(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
items []any
|
||||
chunkSize int
|
||||
expected string
|
||||
}{
|
||||
{nil, 2, "[]"},
|
||||
{[]any{}, 2, "[]"},
|
||||
{[]any{1, 2, 3, 4}, -1, "[[1],[2],[3],[4]]"},
|
||||
{[]any{1, 2, 3, 4}, 0, "[[1],[2],[3],[4]]"},
|
||||
{[]any{1, 2, 3, 4}, 2, "[[1,2],[3,4]]"},
|
||||
{[]any{1, 2, 3, 4, 5}, 2, "[[1,2],[3,4],[5]]"},
|
||||
{[]any{1, 2, 3, 4, 5}, 10, "[[1,2,3,4,5]]"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
|
||||
result := list.ToChunks(s.items, s.chunkSize)
|
||||
|
||||
raw, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
@@ -160,7 +162,6 @@ func (h *BatchHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
if err := h.resolveAttr(data, a); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -168,7 +169,7 @@ func (h *BatchHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
Time: r.Time,
|
||||
Level: r.Level,
|
||||
Message: r.Message,
|
||||
Data: types.JsonMap(data),
|
||||
Data: types.JSONMap[any](data),
|
||||
}
|
||||
|
||||
if h.options.BeforeAddFunc != nil && !h.options.BeforeAddFunc(ctx, log) {
|
||||
@@ -251,11 +252,23 @@ func (h *BatchHandler) resolveAttr(data map[string]any, attr slog.Attr) error {
|
||||
data[attr.Key] = groupData
|
||||
}
|
||||
default:
|
||||
v := attr.Value.Any()
|
||||
|
||||
if err, ok := v.(error); ok {
|
||||
data[attr.Key] = err.Error()
|
||||
} else {
|
||||
switch v := attr.Value.Any().(type) {
|
||||
case validation.Errors:
|
||||
data[attr.Key] = map[string]any{
|
||||
"data": v,
|
||||
"raw": v.Error(),
|
||||
}
|
||||
case error:
|
||||
var ve validation.Errors
|
||||
if errors.As(v, &ve) {
|
||||
data[attr.Key] = map[string]any{
|
||||
"data": ve,
|
||||
"raw": v.Error(),
|
||||
}
|
||||
} else {
|
||||
data[attr.Key] = v.Error()
|
||||
}
|
||||
default:
|
||||
data[attr.Key] = v
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,11 +181,8 @@ func TestBatchHandlerHandle(t *testing.T) {
|
||||
BeforeAddFunc: func(_ context.Context, log *Log) bool {
|
||||
beforeLogs = append(beforeLogs, log)
|
||||
|
||||
if log.Message == "test2" {
|
||||
return false // skip test2 log
|
||||
}
|
||||
|
||||
return true
|
||||
// skip test2 log
|
||||
return log.Message != "test2"
|
||||
},
|
||||
WriteFunc: func(_ context.Context, logs []*Log) error {
|
||||
writeLogs = logs
|
||||
|
||||
+1
-1
@@ -11,7 +11,7 @@ import (
|
||||
// preformatted JSON map.
|
||||
type Log struct {
|
||||
Time time.Time
|
||||
Data types.JSONMap[any]
|
||||
Message string
|
||||
Level slog.Level
|
||||
Data types.JsonMap
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHtml2Text(t *testing.T) {
|
||||
func TestHTML2Text(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
html string
|
||||
expected string
|
||||
|
||||
@@ -3,6 +3,8 @@ package mailer
|
||||
import (
|
||||
"io"
|
||||
"net/mail"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
// Message defines a generic email message struct.
|
||||
@@ -24,6 +26,16 @@ type Mailer interface {
|
||||
Send(message *Message) error
|
||||
}
|
||||
|
||||
// SendInterceptor is optional interface for registering mail send hooks.
|
||||
type SendInterceptor interface {
|
||||
OnSend() *hook.Hook[*SendEvent]
|
||||
}
|
||||
|
||||
type SendEvent struct {
|
||||
hook.Event
|
||||
Message *Message
|
||||
}
|
||||
|
||||
// addressesToStrings converts the provided address to a list of serialized RFC 5322 strings.
|
||||
//
|
||||
// To export only the email part of mail.Address, you can set withName to false.
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
var _ Mailer = (*Sendmail)(nil)
|
||||
@@ -16,10 +18,29 @@ var _ Mailer = (*Sendmail)(nil)
|
||||
//
|
||||
// This client is usually recommended only for development and testing.
|
||||
type Sendmail struct {
|
||||
onSend *hook.Hook[*SendEvent]
|
||||
}
|
||||
|
||||
// Send implements `mailer.Mailer` interface.
|
||||
// OnSend implements [mailer.SendInterceptor] interface.
|
||||
func (c *Sendmail) OnSend() *hook.Hook[*SendEvent] {
|
||||
if c.onSend == nil {
|
||||
c.onSend = &hook.Hook[*SendEvent]{}
|
||||
}
|
||||
return c.onSend
|
||||
}
|
||||
|
||||
// Send implements [mailer.Mailer] interface.
|
||||
func (c *Sendmail) Send(m *Message) error {
|
||||
if c.onSend != nil {
|
||||
return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
|
||||
return c.send(e.Message)
|
||||
})
|
||||
}
|
||||
|
||||
return c.send(m)
|
||||
}
|
||||
|
||||
func (c *Sendmail) send(m *Message) error {
|
||||
toAddresses := addressesToStrings(m.To, false)
|
||||
|
||||
headers := make(http.Header)
|
||||
@@ -74,5 +95,5 @@ func findSendmailPath() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("failed to locate a sendmail executable path")
|
||||
return "", errors.New("Failed to locate a sendmail executable path.")
|
||||
}
|
||||
|
||||
+32
-30
@@ -7,43 +7,27 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/domodwyer/mailyak/v3"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
var _ Mailer = (*SmtpClient)(nil)
|
||||
var _ Mailer = (*SMTPClient)(nil)
|
||||
|
||||
const (
|
||||
SmtpAuthPlain = "PLAIN"
|
||||
SmtpAuthLogin = "LOGIN"
|
||||
SMTPAuthPlain = "PLAIN"
|
||||
SMTPAuthLogin = "LOGIN"
|
||||
)
|
||||
|
||||
// Deprecated: Use directly the SmtpClient struct literal.
|
||||
//
|
||||
// NewSmtpClient creates new SmtpClient with the provided configuration.
|
||||
func NewSmtpClient(
|
||||
host string,
|
||||
port int,
|
||||
username string,
|
||||
password string,
|
||||
tls bool,
|
||||
) *SmtpClient {
|
||||
return &SmtpClient{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Tls: tls,
|
||||
}
|
||||
}
|
||||
|
||||
// SmtpClient defines a SMTP mail client structure that implements
|
||||
// SMTPClient defines a SMTP mail client structure that implements
|
||||
// `mailer.Mailer` interface.
|
||||
type SmtpClient struct {
|
||||
Host string
|
||||
type SMTPClient struct {
|
||||
onSend *hook.Hook[*SendEvent]
|
||||
|
||||
TLS bool
|
||||
Port int
|
||||
Host string
|
||||
Username string
|
||||
Password string
|
||||
Tls bool
|
||||
|
||||
// SMTP auth method to use
|
||||
// (if not explicitly set, defaults to "PLAIN")
|
||||
@@ -56,12 +40,30 @@ type SmtpClient struct {
|
||||
LocalName string
|
||||
}
|
||||
|
||||
// Send implements `mailer.Mailer` interface.
|
||||
func (c *SmtpClient) Send(m *Message) error {
|
||||
// OnSend implements [mailer.SendInterceptor] interface.
|
||||
func (c *SMTPClient) OnSend() *hook.Hook[*SendEvent] {
|
||||
if c.onSend == nil {
|
||||
c.onSend = &hook.Hook[*SendEvent]{}
|
||||
}
|
||||
return c.onSend
|
||||
}
|
||||
|
||||
// Send implements [mailer.Mailer] interface.
|
||||
func (c *SMTPClient) Send(m *Message) error {
|
||||
if c.onSend != nil {
|
||||
return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
|
||||
return c.send(e.Message)
|
||||
})
|
||||
}
|
||||
|
||||
return c.send(m)
|
||||
}
|
||||
|
||||
func (c *SMTPClient) send(m *Message) error {
|
||||
var smtpAuth smtp.Auth
|
||||
if c.Username != "" || c.Password != "" {
|
||||
switch c.AuthMethod {
|
||||
case SmtpAuthLogin:
|
||||
case SMTPAuthLogin:
|
||||
smtpAuth = &smtpLoginAuth{c.Username, c.Password}
|
||||
default:
|
||||
smtpAuth = smtp.PlainAuth("", c.Username, c.Password, c.Host)
|
||||
@@ -70,7 +72,7 @@ func (c *SmtpClient) Send(m *Message) error {
|
||||
|
||||
// create mail instance
|
||||
var yak *mailyak.MailYak
|
||||
if c.Tls {
|
||||
if c.TLS {
|
||||
var tlsErr error
|
||||
yak, tlsErr = mailyak.NewWithTLS(fmt.Sprintf("%s:%d", c.Host, c.Port), smtpAuth, nil)
|
||||
if tlsErr != nil {
|
||||
|
||||
+16
-14
@@ -56,24 +56,26 @@ func TestLoginAuthStart(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
method, resp, err := auth.Start(s.serverInfo)
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
method, resp, err := auth.Start(s.serverInfo)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("[%s] Expected hasErr %v, got %v", s.name, s.expectError, hasErr)
|
||||
}
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if len(resp) != 0 {
|
||||
t.Fatalf("[%s] Expected empty data response, got %v", s.name, resp)
|
||||
}
|
||||
if len(resp) != 0 {
|
||||
t.Fatalf("Expected empty data response, got %v", resp)
|
||||
}
|
||||
|
||||
if method != "LOGIN" {
|
||||
t.Fatalf("[%s] Expected LOGIN, got %v", s.name, method)
|
||||
}
|
||||
if method != "LOGIN" {
|
||||
t.Fatalf("Expected LOGIN, got %v", method)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
File string
|
||||
Up func(db dbx.Builder) error
|
||||
Down func(db dbx.Builder) error
|
||||
}
|
||||
|
||||
// MigrationsList defines a list with migration definitions
|
||||
type MigrationsList struct {
|
||||
list []*Migration
|
||||
}
|
||||
|
||||
// Item returns a single migration from the list by its index.
|
||||
func (l *MigrationsList) Item(index int) *Migration {
|
||||
return l.list[index]
|
||||
}
|
||||
|
||||
// Items returns the internal migrations list slice.
|
||||
func (l *MigrationsList) Items() []*Migration {
|
||||
return l.list
|
||||
}
|
||||
|
||||
// Register adds new migration definition to the list.
|
||||
//
|
||||
// If `optFilename` is not provided, it will try to get the name from its .go file.
|
||||
//
|
||||
// The list will be sorted automatically based on the migrations file name.
|
||||
func (l *MigrationsList) Register(
|
||||
up func(db dbx.Builder) error,
|
||||
down func(db dbx.Builder) error,
|
||||
optFilename ...string,
|
||||
) {
|
||||
var file string
|
||||
if len(optFilename) > 0 {
|
||||
file = optFilename[0]
|
||||
} else {
|
||||
_, path, _, _ := runtime.Caller(1)
|
||||
file = filepath.Base(path)
|
||||
}
|
||||
|
||||
l.list = append(l.list, &Migration{
|
||||
File: file,
|
||||
Up: up,
|
||||
Down: down,
|
||||
})
|
||||
|
||||
sort.Slice(l.list, func(i int, j int) bool {
|
||||
return l.list[i].File < l.list[j].File
|
||||
})
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMigrationsList(t *testing.T) {
|
||||
l := MigrationsList{}
|
||||
|
||||
l.Register(nil, nil, "3_test.go")
|
||||
l.Register(nil, nil, "1_test.go")
|
||||
l.Register(nil, nil, "2_test.go")
|
||||
l.Register(nil, nil /* auto detect file name */)
|
||||
|
||||
expected := []string{
|
||||
"1_test.go",
|
||||
"2_test.go",
|
||||
"3_test.go",
|
||||
"list_test.go",
|
||||
}
|
||||
|
||||
items := l.Items()
|
||||
if len(items) != len(expected) {
|
||||
t.Fatalf("Expected %d items, got %d: \n%#v", len(expected), len(items), items)
|
||||
}
|
||||
|
||||
for i, name := range expected {
|
||||
item := l.Item(i)
|
||||
if item.File != name {
|
||||
t.Fatalf("Expected name %s for index %d, got %s", name, i, item.File)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,275 +0,0 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
const DefaultMigrationsTable = "_migrations"
|
||||
|
||||
// Runner defines a simple struct for managing the execution of db migrations.
|
||||
type Runner struct {
|
||||
db *dbx.DB
|
||||
migrationsList MigrationsList
|
||||
tableName string
|
||||
}
|
||||
|
||||
// NewRunner creates and initializes a new db migrations Runner instance.
|
||||
func NewRunner(db *dbx.DB, migrationsList MigrationsList) (*Runner, error) {
|
||||
runner := &Runner{
|
||||
db: db,
|
||||
migrationsList: migrationsList,
|
||||
tableName: DefaultMigrationsTable,
|
||||
}
|
||||
|
||||
if err := runner.createMigrationsTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
// Run interactively executes the current runner with the provided args.
|
||||
//
|
||||
// The following commands are supported:
|
||||
// - up - applies all migrations
|
||||
// - down [n] - reverts the last n applied migrations
|
||||
func (r *Runner) Run(args ...string) error {
|
||||
cmd := "up"
|
||||
if len(args) > 0 {
|
||||
cmd = args[0]
|
||||
}
|
||||
|
||||
switch cmd {
|
||||
case "up":
|
||||
applied, err := r.Up()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(applied) == 0 {
|
||||
color.Green("No new migrations to apply.")
|
||||
} else {
|
||||
for _, file := range applied {
|
||||
color.Green("Applied %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "down":
|
||||
toRevertCount := 1
|
||||
if len(args) > 1 {
|
||||
toRevertCount = cast.ToInt(args[1])
|
||||
if toRevertCount < 0 {
|
||||
// revert all applied migrations
|
||||
toRevertCount = len(r.migrationsList.Items())
|
||||
}
|
||||
}
|
||||
|
||||
names, err := r.lastAppliedMigrations(toRevertCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
confirm := false
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf(
|
||||
"\n%v\nDo you really want to revert the last %d applied migration(s)?",
|
||||
strings.Join(names, "\n"),
|
||||
toRevertCount,
|
||||
),
|
||||
}
|
||||
survey.AskOne(prompt, &confirm)
|
||||
if !confirm {
|
||||
fmt.Println("The command has been cancelled")
|
||||
return nil
|
||||
}
|
||||
|
||||
reverted, err := r.Down(toRevertCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(reverted) == 0 {
|
||||
color.Green("No migrations to revert.")
|
||||
} else {
|
||||
for _, file := range reverted {
|
||||
color.Green("Reverted %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "history-sync":
|
||||
if err := r.removeMissingAppliedMigrations(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
color.Green("The %s table was synced with the available migrations.", r.tableName)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Unsupported command: %q\n", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Up executes all unapplied migrations for the provided runner.
|
||||
//
|
||||
// On success returns list with the applied migrations file names.
|
||||
func (r *Runner) Up() ([]string, error) {
|
||||
applied := []string{}
|
||||
|
||||
err := r.db.Transactional(func(tx *dbx.Tx) error {
|
||||
for _, m := range r.migrationsList.Items() {
|
||||
// skip applied
|
||||
if r.isMigrationApplied(tx, m.File) {
|
||||
continue
|
||||
}
|
||||
|
||||
// ignore empty Up action
|
||||
if m.Up != nil {
|
||||
if err := m.Up(tx); err != nil {
|
||||
return fmt.Errorf("Failed to apply migration %s: %w", m.File, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.saveAppliedMigration(tx, m.File); err != nil {
|
||||
return fmt.Errorf("Failed to save applied migration info for %s: %w", m.File, err)
|
||||
}
|
||||
|
||||
applied = append(applied, m.File)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
// Down reverts the last `toRevertCount` applied migrations
|
||||
// (in the order they were applied).
|
||||
//
|
||||
// On success returns list with the reverted migrations file names.
|
||||
func (r *Runner) Down(toRevertCount int) ([]string, error) {
|
||||
reverted := make([]string, 0, toRevertCount)
|
||||
|
||||
names, appliedErr := r.lastAppliedMigrations(toRevertCount)
|
||||
if appliedErr != nil {
|
||||
return nil, appliedErr
|
||||
}
|
||||
|
||||
err := r.db.Transactional(func(tx *dbx.Tx) error {
|
||||
for _, name := range names {
|
||||
for _, m := range r.migrationsList.Items() {
|
||||
if m.File != name {
|
||||
continue
|
||||
}
|
||||
|
||||
// revert limit reached
|
||||
if toRevertCount-len(reverted) <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ignore empty Down action
|
||||
if m.Down != nil {
|
||||
if err := m.Down(tx); err != nil {
|
||||
return fmt.Errorf("Failed to revert migration %s: %w", m.File, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.saveRevertedMigration(tx, m.File); err != nil {
|
||||
return fmt.Errorf("Failed to save reverted migration info for %s: %w", m.File, err)
|
||||
}
|
||||
|
||||
reverted = append(reverted, m.File)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reverted, nil
|
||||
}
|
||||
|
||||
func (r *Runner) createMigrationsTable() error {
|
||||
rawQuery := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %v (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
|
||||
r.db.QuoteTableName(r.tableName),
|
||||
)
|
||||
|
||||
_, err := r.db.NewQuery(rawQuery).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) isMigrationApplied(tx dbx.Builder, file string) bool {
|
||||
var exists bool
|
||||
|
||||
err := tx.Select("count(*)").
|
||||
From(r.tableName).
|
||||
Where(dbx.HashExp{"file": file}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && exists
|
||||
}
|
||||
|
||||
func (r *Runner) saveAppliedMigration(tx dbx.Builder, file string) error {
|
||||
_, err := tx.Insert(r.tableName, dbx.Params{
|
||||
"file": file,
|
||||
"applied": time.Now().UnixMicro(),
|
||||
}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) saveRevertedMigration(tx dbx.Builder, file string) error {
|
||||
_, err := tx.Delete(r.tableName, dbx.HashExp{"file": file}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) lastAppliedMigrations(limit int) ([]string, error) {
|
||||
var files = make([]string, 0, limit)
|
||||
|
||||
err := r.db.Select("file").
|
||||
From(r.tableName).
|
||||
Where(dbx.Not(dbx.HashExp{"applied": nil})).
|
||||
// unify microseconds and seconds applied time for backward compatibility
|
||||
OrderBy("substr(applied||'0000000000000000', 0, 17) DESC").
|
||||
AndOrderBy("file DESC").
|
||||
Limit(int64(limit)).
|
||||
Column(&files)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func (r *Runner) removeMissingAppliedMigrations() error {
|
||||
loadedMigrations := r.migrationsList.Items()
|
||||
|
||||
names := make([]any, len(loadedMigrations))
|
||||
for i, migration := range loadedMigrations {
|
||||
names[i] = migration.File
|
||||
}
|
||||
|
||||
_, err := r.db.Delete(r.tableName, dbx.Not(dbx.HashExp{
|
||||
"file": names,
|
||||
})).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestNewRunner(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
l := MigrationsList{}
|
||||
l.Register(nil, nil, "1_test.go")
|
||||
l.Register(nil, nil, "2_test.go")
|
||||
l.Register(nil, nil, "3_test.go")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(r.migrationsList.Items()) != len(l.Items()) {
|
||||
t.Fatalf("Expected the same migrations list to be assigned, got \n%#v", r.migrationsList)
|
||||
}
|
||||
|
||||
expectedQueries := []string{
|
||||
"CREATE TABLE IF NOT EXISTS `_migrations` (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
|
||||
}
|
||||
if len(expectedQueries) != len(testDB.CalledQueries) {
|
||||
t.Fatalf("Expected %d queries, got %d: \n%v", len(expectedQueries), len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
for _, q := range expectedQueries {
|
||||
if !list.ExistInSlice(q, testDB.CalledQueries) {
|
||||
t.Fatalf("Query %s was not found in \n%v", q, testDB.CalledQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerUpAndDown(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
callsOrder := []string{}
|
||||
|
||||
l := MigrationsList{}
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "up2")
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "down2")
|
||||
return nil
|
||||
}, "2_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "up3")
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "down3")
|
||||
return nil
|
||||
}, "3_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "up1")
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "down1")
|
||||
return nil
|
||||
}, "1_test")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// simulate partially out-of-order run migration
|
||||
r.saveAppliedMigration(testDB, "2_test")
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Up()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
if _, err := r.Up(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedUpCallsOrder := `["up1","up3"]` // skip up2 since it was applied previously
|
||||
|
||||
upCallsOrder, err := json.Marshal(callsOrder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := string(upCallsOrder); v != expectedUpCallsOrder {
|
||||
t.Fatalf("Expected Up() calls order %s, got %s", expectedUpCallsOrder, upCallsOrder)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// reset callsOrder
|
||||
callsOrder = []string{}
|
||||
|
||||
// simulate unrun migration
|
||||
r.migrationsList.Register(nil, func(db dbx.Builder) error {
|
||||
callsOrder = append(callsOrder, "down4")
|
||||
return nil
|
||||
}, "4_test")
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Down()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
if _, err := r.Down(2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedDownCallsOrder := `["down3","down1"]` // revert in the applied order
|
||||
|
||||
downCallsOrder, err := json.Marshal(callsOrder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := string(downCallsOrder); v != expectedDownCallsOrder {
|
||||
t.Fatalf("Expected Down() calls order %s, got %s", expectedDownCallsOrder, downCallsOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistorySync(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
// mock migrations history
|
||||
l := MigrationsList{}
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
return nil
|
||||
}, "1_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
return nil
|
||||
}, "2_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
return nil
|
||||
}, "3_test")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize the runner: %v", err)
|
||||
}
|
||||
|
||||
if _, err := r.Up(); err != nil {
|
||||
t.Fatalf("Failed to apply the mock migrations: %v", err)
|
||||
}
|
||||
|
||||
if !r.isMigrationApplied(testDB.DB, "2_test") {
|
||||
t.Fatalf("Expected 2_test migration to be applied")
|
||||
}
|
||||
|
||||
// mock deleted migrations
|
||||
r.migrationsList.list = []*Migration{r.migrationsList.list[0], r.migrationsList.list[2]}
|
||||
|
||||
if err := r.removeMissingAppliedMigrations(); err != nil {
|
||||
t.Fatalf("Failed to remove missing applied migrations: %v", err)
|
||||
}
|
||||
|
||||
if r.isMigrationApplied(testDB.DB, "2_test") {
|
||||
t.Fatalf("Expected 2_test migration to NOT be applied")
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testDB struct {
|
||||
*dbx.DB
|
||||
CalledQueries []string
|
||||
}
|
||||
|
||||
// NB! Don't forget to call `db.Close()` at the end of the test.
|
||||
func createTestDB() (*testDB, error) {
|
||||
sqlDB, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package rest
|
||||
package picker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -10,6 +10,12 @@ import (
|
||||
"golang.org/x/net/html"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Modifiers["excerpt"] = func(args ...string) (Modifier, error) {
|
||||
return newExcerptModifier(args...)
|
||||
}
|
||||
}
|
||||
|
||||
var whitespaceRegex = regexp.MustCompile(`\s+`)
|
||||
|
||||
var excludeTags = []string{
|
||||
@@ -24,7 +30,7 @@ var inlineTags = []string{
|
||||
"strong", "strike", "sub", "sup", "time",
|
||||
}
|
||||
|
||||
var _ FieldModifier = (*excerptModifier)(nil)
|
||||
var _ Modifier = (*excerptModifier)(nil)
|
||||
|
||||
type excerptModifier struct {
|
||||
max int // approximate max excerpt length
|
||||
@@ -59,7 +65,7 @@ func newExcerptModifier(args ...string) (*excerptModifier, error) {
|
||||
return &excerptModifier{max, withEllipsis}, nil
|
||||
}
|
||||
|
||||
// Modify implements the [FieldModifier.Modify] interface method.
|
||||
// Modify implements the [Modifier.Modify] interface method.
|
||||
//
|
||||
// It returns a plain text excerpt/short-description from a formatted
|
||||
// html string (non-string values are kept untouched).
|
||||
@@ -1,4 +1,4 @@
|
||||
package rest
|
||||
package picker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -0,0 +1,41 @@
|
||||
package picker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/tokenizer"
|
||||
)
|
||||
|
||||
var Modifiers = map[string]ModifierFactoryFunc{}
|
||||
|
||||
type ModifierFactoryFunc func(args ...string) (Modifier, error)
|
||||
|
||||
type Modifier interface {
|
||||
// Modify executes the modifier and returns a new modified value.
|
||||
Modify(value any) (any, error)
|
||||
}
|
||||
|
||||
func initModifer(rawModifier string) (Modifier, error) {
|
||||
t := tokenizer.NewFromString(rawModifier)
|
||||
t.Separators('(', ')', ',', ' ')
|
||||
t.IgnoreParenthesis(true)
|
||||
|
||||
parts, err := t.ScanAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("invalid or empty modifier expression %q", rawModifier)
|
||||
}
|
||||
|
||||
name := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
factory, ok := Modifiers[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing or invalid modifier %q", name)
|
||||
}
|
||||
|
||||
return factory(args...)
|
||||
}
|
||||
@@ -1,76 +1,25 @@
|
||||
package rest
|
||||
package picker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
// Experimental!
|
||||
//
|
||||
// Need more tests before replacing encoding/json entirely.
|
||||
// Test also encoding/json/v2 once released (see https://github.com/golang/go/discussions/63397)
|
||||
goccy "github.com/goccy/go-json"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/tokenizer"
|
||||
)
|
||||
|
||||
type FieldModifier interface {
|
||||
// Modify executes the modifier and returns a new modified value.
|
||||
Modify(value any) (any, error)
|
||||
}
|
||||
|
||||
// Serializer represents custom REST JSON serializer based on echo.DefaultJSONSerializer,
|
||||
// with support for additional generic response data transformation (eg. fields picker).
|
||||
type Serializer struct {
|
||||
echo.DefaultJSONSerializer
|
||||
|
||||
FieldsParam string
|
||||
}
|
||||
|
||||
// Serialize converts an interface into a json and writes it to the response.
|
||||
// Pick converts data into a []any, map[string]any, etc. (using json marshal->unmarshal)
|
||||
// containing only the fields from the parsed rawFields expression.
|
||||
//
|
||||
// It also provides a generic response data fields picker via the FieldsParam query parameter (default to "fields").
|
||||
//
|
||||
// Note: for the places where it is safe, the std encoding/json is replaced
|
||||
// with goccy due to its slightly better Unmarshal/Marshal performance.
|
||||
func (s *Serializer) Serialize(c echo.Context, i any, indent string) error {
|
||||
fieldsParam := s.FieldsParam
|
||||
if fieldsParam == "" {
|
||||
fieldsParam = "fields"
|
||||
}
|
||||
|
||||
statusCode := c.Response().Status
|
||||
|
||||
rawFields := c.QueryParam(fieldsParam)
|
||||
if rawFields == "" || statusCode < 200 || statusCode > 299 {
|
||||
return s.DefaultJSONSerializer.Serialize(c, i, indent)
|
||||
}
|
||||
|
||||
decoded, err := PickFields(i, rawFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
enc := goccy.NewEncoder(c.Response())
|
||||
if indent != "" {
|
||||
enc.SetIndent("", indent)
|
||||
}
|
||||
|
||||
return enc.Encode(decoded)
|
||||
}
|
||||
|
||||
// PickFields parses the provided fields string expression and
|
||||
// returns a new subset of data with only the requested fields.
|
||||
//
|
||||
// Fields transformations with modifiers are also supported (see initModifer()).
|
||||
// rawFields is a comma separated string of the fields to include.
|
||||
// Nested fields should be listed with dot-notation.
|
||||
// Fields value modifiers are also supported using the `:modifier(args)` format (see Modifiers).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data := map[string]any{"a": 1, "b": 2, "c": map[string]any{"c1": 11, "c2": 22}}
|
||||
// PickFields(data, "a,c.c1") // map[string]any{"a": 1, "c": map[string]any{"c1": 11}}
|
||||
func PickFields(data any, rawFields string) (any, error) {
|
||||
// Pick(data, "a,c.c1") // map[string]any{"a": 1, "c": map[string]any{"c1": 11}}
|
||||
func Pick(data any, rawFields string) (any, error) {
|
||||
parsedFields, err := parseFields(rawFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -82,18 +31,18 @@ func PickFields(data any, rawFields string) (any, error) {
|
||||
//
|
||||
// @todo research other approaches to avoid the double serialization
|
||||
// ---
|
||||
encoded, err := json.Marshal(data) // use the std json since goccy has several bugs reported with struct marshaling and it is not safe
|
||||
encoded, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := goccy.Unmarshal(encoded, &decoded); err != nil {
|
||||
if err := json.Unmarshal(encoded, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ---
|
||||
|
||||
// special cases to preserve the same fields format when used with single item or array data.
|
||||
// special cases to preserve the same fields format when used with single item or search results data.
|
||||
var isSearchResult bool
|
||||
switch data.(type) {
|
||||
case search.Result, *search.Result:
|
||||
@@ -111,7 +60,7 @@ func PickFields(data any, rawFields string) (any, error) {
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func parseFields(rawFields string) (map[string]FieldModifier, error) {
|
||||
func parseFields(rawFields string) (map[string]Modifier, error) {
|
||||
t := tokenizer.NewFromString(rawFields)
|
||||
|
||||
fields, err := t.ScanAll()
|
||||
@@ -119,7 +68,7 @@ func parseFields(rawFields string) (map[string]FieldModifier, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]FieldModifier, len(fields))
|
||||
result := make(map[string]Modifier, len(fields))
|
||||
|
||||
for _, f := range fields {
|
||||
parts := strings.SplitN(strings.TrimSpace(f), ":", 2)
|
||||
@@ -138,36 +87,7 @@ func parseFields(rawFields string) (map[string]FieldModifier, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func initModifer(rawModifier string) (FieldModifier, error) {
|
||||
t := tokenizer.NewFromString(rawModifier)
|
||||
t.Separators('(', ')', ',', ' ')
|
||||
t.IgnoreParenthesis(true)
|
||||
|
||||
parts, err := t.ScanAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("invalid or empty modifier expression %q", rawModifier)
|
||||
}
|
||||
|
||||
name := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
switch name {
|
||||
case "excerpt":
|
||||
m, err := newExcerptModifier(args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid excerpt modifier: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("missing or invalid modifier %q", name)
|
||||
}
|
||||
|
||||
func pickParsedFields(data any, fields map[string]FieldModifier) error {
|
||||
func pickParsedFields(data any, fields map[string]Modifier) error {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
pickMapFields(v, fields)
|
||||
@@ -196,7 +116,7 @@ func pickParsedFields(data any, fields map[string]FieldModifier) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func pickMapFields(data map[string]any, fields map[string]FieldModifier) error {
|
||||
func pickMapFields(data map[string]any, fields map[string]Modifier) error {
|
||||
if len(fields) == 0 {
|
||||
return nil // nothing to pick
|
||||
}
|
||||
@@ -221,7 +141,7 @@ func pickMapFields(data map[string]any, fields map[string]FieldModifier) error {
|
||||
|
||||
DataLoop:
|
||||
for k := range data {
|
||||
matchingFields := make(map[string]FieldModifier, len(fields))
|
||||
matchingFields := make(map[string]Modifier, len(fields))
|
||||
for f, m := range fields {
|
||||
if strings.HasPrefix(f+".", k+".") {
|
||||
matchingFields[f] = m
|
||||
@@ -1,135 +1,13 @@
|
||||
package rest_test
|
||||
package picker_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
"github.com/pocketbase/pocketbase/tools/picker"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
serializer rest.Serializer
|
||||
statusCode int
|
||||
data any
|
||||
query string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"empty query",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"empty fields",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"missing fields",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing",
|
||||
`{}`,
|
||||
},
|
||||
{
|
||||
">299 response",
|
||||
rest.Serializer{},
|
||||
300,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"<200 response",
|
||||
rest.Serializer{},
|
||||
199,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"non map response",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
"test",
|
||||
"fields=a,b,test",
|
||||
`"test"`,
|
||||
},
|
||||
{
|
||||
"non slice of map response",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
[]any{"a", "b", "test"},
|
||||
"fields=a,test",
|
||||
`["a","b","test"]`,
|
||||
},
|
||||
{
|
||||
"map with no matching field",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing", // test individual fields trim
|
||||
`{}`,
|
||||
},
|
||||
{
|
||||
"map with existing and missing fields",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=a, c ,missing", // test individual fields trim
|
||||
`{"a":1,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"custom fields param",
|
||||
rest.Serializer{FieldsParam: "custom"},
|
||||
200,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"custom=a, c ,missing", // test individual fields trim
|
||||
`{"a":1,"c":"test"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.URL.RawQuery = s.query
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Status = s.statusCode
|
||||
|
||||
if err := s.serializer.Serialize(c, s.data, ""); err != nil {
|
||||
t.Fatalf("Serialize failure: %v", err)
|
||||
}
|
||||
|
||||
rawBody, err := io.ReadAll(rec.Result().Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read request body: %v", err)
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(string(rawBody)); v != s.expected {
|
||||
t.Fatalf("Expected body\n%v \ngot \n%v", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPickFields(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
@@ -374,7 +252,7 @@ func TestPickFields(t *testing.T) {
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
result, err := rest.PickFields(s.data, s.fields)
|
||||
result, err := picker.Pick(s.data, s.fields)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
@@ -1,170 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// MultipartJsonKey is the key for the special multipart/form-data
|
||||
// handling allowing reading serialized json payload without normalization.
|
||||
const MultipartJsonKey string = "@jsonPayload"
|
||||
|
||||
// MultiBinder is similar to [echo.DefaultBinder] but uses slightly different
|
||||
// application/json and multipart/form-data bind methods to accommodate better
|
||||
// the PocketBase router needs.
|
||||
type MultiBinder struct{}
|
||||
|
||||
// Bind implements the [Binder.Bind] method.
|
||||
//
|
||||
// Bind is almost identical to [echo.DefaultBinder.Bind] but uses the
|
||||
// [rest.BindBody] function for binding the request body.
|
||||
func (b *MultiBinder) Bind(c echo.Context, i interface{}) (err error) {
|
||||
if err := echo.BindPathParams(c, i); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body.
|
||||
// For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues.
|
||||
method := c.Request().Method
|
||||
if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
|
||||
if err = echo.BindQueryParams(c, i); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return BindBody(c, i)
|
||||
}
|
||||
|
||||
// BindBody binds request body content to i.
|
||||
//
|
||||
// This is similar to `echo.BindBody()`, but for JSON requests uses
|
||||
// custom json reader that **copies** the request body, allowing multiple reads.
|
||||
func BindBody(c echo.Context, i any) error {
|
||||
req := c.Request()
|
||||
if req.ContentLength == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctype := req.Header.Get(echo.HeaderContentType)
|
||||
switch {
|
||||
case strings.HasPrefix(ctype, echo.MIMEApplicationJSON):
|
||||
err := CopyJsonBody(c.Request(), i)
|
||||
if err != nil {
|
||||
return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
return nil
|
||||
case strings.HasPrefix(ctype, echo.MIMEApplicationForm), strings.HasPrefix(ctype, echo.MIMEMultipartForm):
|
||||
return bindFormData(c, i)
|
||||
}
|
||||
|
||||
// fallback to the default binder
|
||||
return echo.BindBody(c, i)
|
||||
}
|
||||
|
||||
// CopyJsonBody reads the request body into i by
|
||||
// creating a copy of `r.Body` to allow multiple reads.
|
||||
func CopyJsonBody(r *http.Request, i any) error {
|
||||
body := r.Body
|
||||
|
||||
// this usually shouldn't be needed because the Server calls close
|
||||
// for us but we are changing the request body with a new reader
|
||||
defer body.Close()
|
||||
|
||||
limitReader := io.LimitReader(body, DefaultMaxMemory)
|
||||
|
||||
bodyBytes, readErr := io.ReadAll(limitReader)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
|
||||
err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(i)
|
||||
|
||||
// set new body reader
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Custom multipart/form-data binder that implements an additional handling like
|
||||
// loading a serialized json payload or properly scan array values when a map destination is used.
|
||||
func bindFormData(c echo.Context, i any) error {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
values, err := c.FormValues()
|
||||
if err != nil {
|
||||
return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// special case to allow submitting json without normalization
|
||||
// alongside the other multipart/form-data values
|
||||
jsonPayloadValues := values[MultipartJsonKey]
|
||||
for _, payload := range jsonPayloadValues {
|
||||
json.Unmarshal([]byte(payload), i)
|
||||
}
|
||||
|
||||
rt := reflect.TypeOf(i).Elem()
|
||||
|
||||
// map
|
||||
if rt.Kind() == reflect.Map {
|
||||
rv := reflect.ValueOf(i).Elem()
|
||||
|
||||
for k, v := range values {
|
||||
if k == MultipartJsonKey {
|
||||
continue
|
||||
}
|
||||
|
||||
if total := len(v); total == 1 {
|
||||
rv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalizeMultipartValue(v[0])))
|
||||
} else {
|
||||
normalized := make([]any, total)
|
||||
for i, vItem := range v {
|
||||
normalized[i] = normalizeMultipartValue(vItem)
|
||||
}
|
||||
rv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalized))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// anything else
|
||||
return echo.BindBody(c, i)
|
||||
}
|
||||
|
||||
// In order to support more seamlessly both json and multipart/form-data requests,
|
||||
// the following normalization rules are applied for plain multipart string values:
|
||||
// - "true" is converted to the json `true`
|
||||
// - "false" is converted to the json `false`
|
||||
// - numeric (non-scientific) strings are converted to json number
|
||||
// - any other string (empty string too) is left as it is
|
||||
func normalizeMultipartValue(raw string) any {
|
||||
switch raw {
|
||||
case "":
|
||||
return raw
|
||||
case "true":
|
||||
return true
|
||||
case "false":
|
||||
return false
|
||||
default:
|
||||
if raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9') {
|
||||
if v, err := cast.ToFloat64E(raw); err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return raw
|
||||
}
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
)
|
||||
|
||||
func TestMultiBinderBind(t *testing.T) {
|
||||
binder := rest.MultiBinder{}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test?query=123", strings.NewReader(`{"body":"456"}`))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
e.Any("/:name", func(c echo.Context) error {
|
||||
// bind twice to ensure that the json body reader copy is invoked
|
||||
for i := 0; i < 2; i++ {
|
||||
data := struct {
|
||||
Name string `param:"name"`
|
||||
Query string `query:"query"`
|
||||
Body string `form:"body"`
|
||||
}{}
|
||||
|
||||
if err := binder.Bind(c, &data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if data.Name != "test" {
|
||||
t.Fatalf("Expected Name %q, got %q", "test", data.Name)
|
||||
}
|
||||
|
||||
if data.Query != "123" {
|
||||
t.Fatalf("Expected Query %q, got %q", "123", data.Query)
|
||||
}
|
||||
|
||||
if data.Body != "456" {
|
||||
t.Fatalf("Expected Body %q, got %q", "456", data.Body)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
func TestBindBody(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
body io.Reader
|
||||
contentType string
|
||||
expectBody string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
strings.NewReader(""),
|
||||
echo.MIMEApplicationJSON,
|
||||
`{}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
strings.NewReader(`{"test":"invalid`),
|
||||
echo.MIMEApplicationJSON,
|
||||
`{}`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
strings.NewReader(`{"test":123}`),
|
||||
echo.MIMEApplicationJSON,
|
||||
`{"test":123}`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
strings.NewReader(
|
||||
url.Values{
|
||||
"string": []string{"str"},
|
||||
"stings": []string{"str1", "str2", ""},
|
||||
"number": []string{"-123"},
|
||||
"numbers": []string{"123", "456.789"},
|
||||
"bool": []string{"true"},
|
||||
"bools": []string{"true", "false"},
|
||||
rest.MultipartJsonKey: []string{`invalid`, `{"a":123}`, `{"b":456}`},
|
||||
}.Encode(),
|
||||
),
|
||||
echo.MIMEApplicationForm,
|
||||
`{"a":123,"b":456,"bool":true,"bools":[true,false],"number":-123,"numbers":[123,456.789],"stings":["str1","str2",""],"string":"str"}`,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", scenario.body)
|
||||
req.Header.Set(echo.HeaderContentType, scenario.contentType)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
data := map[string]any{}
|
||||
err := rest.BindBody(c, &data)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("[%d] Expected hasErr %v, got %v", i, scenario.expectError, hasErr)
|
||||
}
|
||||
|
||||
rawBody, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
t.Errorf("[%d] Failed to marshal binded body: %v", i, err)
|
||||
}
|
||||
|
||||
if scenario.expectBody != string(rawBody) {
|
||||
t.Errorf("[%d] Expected body \n%s, \ngot \n%s", i, scenario.expectBody, rawBody)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyJsonBody(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"test":"test123"}`))
|
||||
|
||||
// simulate multiple reads from the same request
|
||||
result1 := map[string]string{}
|
||||
rest.CopyJsonBody(req, &result1)
|
||||
result2 := map[string]string{}
|
||||
rest.CopyJsonBody(req, &result2)
|
||||
|
||||
if len(result1) == 0 {
|
||||
t.Error("Expected result1 to be filled")
|
||||
}
|
||||
|
||||
if len(result2) == 0 {
|
||||
t.Error("Expected result2 to be filled")
|
||||
}
|
||||
|
||||
if v, ok := result1["test"]; !ok || v != "test123" {
|
||||
t.Errorf("Expected result1.test to be %q, got %q", "test123", v)
|
||||
}
|
||||
|
||||
if v, ok := result2["test"]; !ok || v != "test123" {
|
||||
t.Errorf("Expected result2.test to be %q, got %q", "test123", v)
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
)
|
||||
|
||||
// DefaultMaxMemory defines the default max memory bytes that
|
||||
// will be used when parsing a form request body.
|
||||
const DefaultMaxMemory = 32 << 20 // 32mb
|
||||
|
||||
// FindUploadedFiles extracts all form files of "key" from a http request
|
||||
// and returns a slice with filesystem.File instances (if any).
|
||||
func FindUploadedFiles(r *http.Request, key string) ([]*filesystem.File, error) {
|
||||
if r.MultipartForm == nil {
|
||||
err := r.ParseMultipartForm(DefaultMaxMemory)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if r.MultipartForm == nil || r.MultipartForm.File == nil || len(r.MultipartForm.File[key]) == 0 {
|
||||
return nil, http.ErrMissingFile
|
||||
}
|
||||
|
||||
result := make([]*filesystem.File, 0, len(r.MultipartForm.File[key]))
|
||||
|
||||
for _, fh := range r.MultipartForm.File[key] {
|
||||
file, err := filesystem.NewFileFromMultipart(fh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, file)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
)
|
||||
|
||||
func TestFindUploadedFiles(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
filename string
|
||||
expectedPattern string
|
||||
}{
|
||||
{"ab.png", `^ab\w{10}_\w{10}\.png$`},
|
||||
{"test", `^test_\w{10}\.txt$`},
|
||||
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
|
||||
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
// create multipart form file body
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
w, err := mp.CreateFormFile("test", s.filename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w.Write([]byte("test"))
|
||||
mp.Close()
|
||||
// ---
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
result, err := rest.FindUploadedFiles(req, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Errorf("[%d] Expected 1 file, got %d", i, len(result))
|
||||
}
|
||||
|
||||
if result[0].Size != 4 {
|
||||
t.Errorf("[%d] Expected the file size to be 4 bytes, got %d", i, result[0].Size)
|
||||
}
|
||||
|
||||
pattern, err := regexp.Compile(s.expectedPattern)
|
||||
if err != nil {
|
||||
t.Errorf("[%d] Invalid filename pattern %q: %v", i, s.expectedPattern, err)
|
||||
}
|
||||
if !pattern.MatchString(result[0].Name) {
|
||||
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUploadedFilesMissing(t *testing.T) {
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
mp.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
result, err := rest.FindUploadedFiles(req, "test")
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected result to be nil, got %v", result)
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeUrl removes duplicated slashes from a url path.
|
||||
func NormalizeUrl(originalUrl string) (string, error) {
|
||||
u, err := url.Parse(originalUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hasSlash := strings.HasSuffix(u.Path, "/")
|
||||
|
||||
// clean up path by removing duplicated /
|
||||
u.Path = path.Clean(u.Path)
|
||||
u.RawPath = path.Clean(u.RawPath)
|
||||
|
||||
// restore original trailing slash
|
||||
if hasSlash && !strings.HasSuffix(u.Path, "/") {
|
||||
u.Path += "/"
|
||||
u.RawPath += "/"
|
||||
}
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
)
|
||||
|
||||
func TestNormalizeUrl(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
url string
|
||||
expectError bool
|
||||
expectUrl string
|
||||
}{
|
||||
{":/", true, ""},
|
||||
{"./", false, "./"},
|
||||
{"../../test////", false, "../../test/"},
|
||||
{"/a/b/c", false, "/a/b/c"},
|
||||
{"a/////b//c/", false, "a/b/c/"},
|
||||
{"/a/////b//c", false, "/a/b/c"},
|
||||
{"///a/b/c", false, "/a/b/c"},
|
||||
{"//a/b/c", false, "//a/b/c"}, // preserve "auto-schema"
|
||||
{"http://a/b/c", false, "http://a/b/c"},
|
||||
{"a//bc?test=1//dd", false, "a/bc?test=1//dd"}, // only the path is normalized
|
||||
{"a//bc?test=1#12///3", false, "a/bc?test=1#12///3"}, // only the path is normalized
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := rest.NormalizeUrl(s.url)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v", i, s.expectError, hasErr)
|
||||
}
|
||||
|
||||
if result != s.expectUrl {
|
||||
t.Errorf("(%d) Expected url %q, got %q", i, s.expectUrl, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
)
|
||||
|
||||
// SafeErrorItem defines a common error interface for a printable public safe error.
|
||||
type SafeErrorItem interface {
|
||||
// Code represents a fixed unique identifier of the error (usually used as translation key).
|
||||
Code() string
|
||||
|
||||
// Error is the default English human readable error message that will be returned.
|
||||
Error() string
|
||||
}
|
||||
|
||||
// SafeErrorParamsResolver defines an optional interface for specifying dynamic error parameters.
|
||||
type SafeErrorParamsResolver interface {
|
||||
// Params defines a map with dynamic parameters to return as part of the public safe error view.
|
||||
Params() map[string]any
|
||||
}
|
||||
|
||||
// SafeErrorResolver defines an error interface for resolving the public safe error fields.
|
||||
type SafeErrorResolver interface {
|
||||
// Resolve allows modifying and returning a new public safe error data map.
|
||||
Resolve(errData map[string]any) any
|
||||
}
|
||||
|
||||
// ApiError defines the struct for a basic api error response.
|
||||
type ApiError struct {
|
||||
rawData any
|
||||
|
||||
Data map[string]any `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// Error makes it compatible with the `error` interface.
|
||||
func (e *ApiError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// RawData returns the unformatted error data (could be an internal error, text, etc.)
|
||||
func (e *ApiError) RawData() any {
|
||||
return e.rawData
|
||||
}
|
||||
|
||||
// Is reports whether the current ApiError wraps the target.
|
||||
func (e *ApiError) Is(target error) bool {
|
||||
err, ok := e.rawData.(error)
|
||||
if ok {
|
||||
return errors.Is(err, target)
|
||||
}
|
||||
|
||||
apiErr, ok := target.(*ApiError)
|
||||
|
||||
return ok && e == apiErr
|
||||
}
|
||||
|
||||
// NewNotFoundError creates and returns 404 ApiError.
|
||||
func NewNotFoundError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "The requested resource wasn't found."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusNotFound, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewBadRequestError creates and returns 400 ApiError.
|
||||
func NewBadRequestError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Something went wrong while processing your request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusBadRequest, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewForbiddenError creates and returns 403 ApiError.
|
||||
func NewForbiddenError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "You are not allowed to perform this request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusForbidden, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError creates and returns 401 ApiError.
|
||||
func NewUnauthorizedError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Missing or invalid authentication."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusUnauthorized, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewInternalServerError creates and returns 500 ApiError.
|
||||
func NewInternalServerError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Something went wrong while processing your request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusInternalServerError, message, rawErrData)
|
||||
}
|
||||
|
||||
func NewTooManyRequestsError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Too Many Requests."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusTooManyRequests, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewApiError creates and returns new normalized ApiError instance.
|
||||
func NewApiError(status int, message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
|
||||
return &ApiError{
|
||||
rawData: rawErrData,
|
||||
Data: safeErrorsData(rawErrData),
|
||||
Status: status,
|
||||
Message: strings.TrimSpace(inflector.Sentenize(message)),
|
||||
}
|
||||
}
|
||||
|
||||
// ToApiError wraps err into ApiError instance (if not already).
|
||||
func ToApiError(err error) *ApiError {
|
||||
var apiErr *ApiError
|
||||
|
||||
if !errors.As(err, &apiErr) {
|
||||
// no ApiError found -> assign a generic one
|
||||
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, fs.ErrNotExist) {
|
||||
apiErr = NewNotFoundError("", err)
|
||||
} else {
|
||||
apiErr = NewBadRequestError("", err)
|
||||
}
|
||||
}
|
||||
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func safeErrorsData(data any) map[string]any {
|
||||
switch v := data.(type) {
|
||||
case validation.Errors:
|
||||
return resolveSafeErrorsData(v)
|
||||
case error:
|
||||
validationErrors := validation.Errors{}
|
||||
if errors.As(v, &validationErrors) {
|
||||
return resolveSafeErrorsData(validationErrors)
|
||||
}
|
||||
return map[string]any{} // not nil to ensure that is json serialized as object
|
||||
case map[string]validation.Error:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]SafeErrorItem:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]error:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]string:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]any:
|
||||
return resolveSafeErrorsData(v)
|
||||
default:
|
||||
return map[string]any{} // not nil to ensure that is json serialized as object
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSafeErrorsData[T any](data map[string]T) map[string]any {
|
||||
result := map[string]any{}
|
||||
|
||||
for name, err := range data {
|
||||
if isNestedError(err) {
|
||||
result[name] = safeErrorsData(err)
|
||||
} else {
|
||||
result[name] = resolveSafeErrorItem(err)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isNestedError(err any) bool {
|
||||
switch err.(type) {
|
||||
case validation.Errors,
|
||||
map[string]validation.Error,
|
||||
map[string]SafeErrorItem,
|
||||
map[string]error,
|
||||
map[string]string,
|
||||
map[string]any:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveSafeErrorItem extracts from each validation error its
|
||||
// public safe error code and message.
|
||||
func resolveSafeErrorItem(err any) any {
|
||||
data := map[string]any{}
|
||||
|
||||
if obj, ok := err.(SafeErrorItem); ok {
|
||||
// extract the specific error code and message
|
||||
data["code"] = obj.Code()
|
||||
data["message"] = inflector.Sentenize(obj.Error())
|
||||
} else {
|
||||
// fallback to the default public safe values
|
||||
data["code"] = "validation_invalid_value"
|
||||
data["message"] = "Invalid value."
|
||||
}
|
||||
|
||||
if s, ok := err.(SafeErrorParamsResolver); ok {
|
||||
params := s.Params()
|
||||
if len(params) > 0 {
|
||||
data["params"] = params
|
||||
}
|
||||
}
|
||||
|
||||
if s, ok := err.(SafeErrorResolver); ok {
|
||||
return s.Resolve(data)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
@@ -0,0 +1,358 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestNewApiErrorWithRawData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := router.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
"rawData_test",
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"data":{},"message":"Message_test.","status":300}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%v\ngot\n%v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() != "rawData_test" {
|
||||
t.Errorf("Expected rawData\n%v\ngot\n%v", "rawData_test", e.RawData())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewApiErrorWithValidationData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := router.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
map[string]any{
|
||||
"err1": errors.New("test error"), // should be normalized
|
||||
"err2": validation.ErrRequired,
|
||||
"err3": validation.Errors{
|
||||
"err3.1": errors.New("test error"), // should be normalized
|
||||
"err3.2": validation.ErrRequired,
|
||||
"err3.3": validation.Errors{
|
||||
"err3.3.1": validation.ErrRequired,
|
||||
},
|
||||
},
|
||||
"err4": &mockSafeErrorItem{},
|
||||
"err5": map[string]error{
|
||||
"err5.1": validation.ErrRequired,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"data":{"err1":{"code":"validation_invalid_value","message":"Invalid value."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"err3.1":{"code":"validation_invalid_value","message":"Invalid value."},"err3.2":{"code":"validation_required","message":"Cannot be blank."},"err3.3":{"err3.3.1":{"code":"validation_required","message":"Cannot be blank."}}},"err4":{"code":"mock_code","message":"Mock_error.","mock_resolve":123},"err5":{"err5.1":{"code":"validation_required","message":"Cannot be blank."}}},"message":"Message_test.","status":300}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected \n%v, \ngot \n%v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() == nil {
|
||||
t.Error("Expected non-nil rawData")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"The requested resource wasn't found.","status":404}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":404}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":404}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewNotFoundError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadRequestError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":400}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":400}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":400}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewBadRequestError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForbiddenError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"You are not allowed to perform this request.","status":403}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":403}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":403}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewForbiddenError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUnauthorizedError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Missing or invalid authentication.","status":401}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":401}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":401}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewUnauthorizedError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInternalServerError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":500}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":500}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":500}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewInternalServerError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTooManyRequestsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Too Many Requests.","status":429}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":429}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message").SetParams(map[string]any{"test": 123})}, `{"data":{"err1":{"code":"test_code","message":"Test_message.","params":{"test":123}}},"message":"Demo.","status":429}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewTooManyRequestsError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiErrorIs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err0 := router.NewInternalServerError("", nil)
|
||||
err1 := router.NewInternalServerError("", nil)
|
||||
err2 := errors.New("test")
|
||||
err3 := fmt.Errorf("wrapped: %w", err0)
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
err error
|
||||
target error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
"nil error",
|
||||
err0,
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"non ApiError",
|
||||
err0,
|
||||
err1,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"different ApiError",
|
||||
err0,
|
||||
err2,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"same ApiError",
|
||||
err0,
|
||||
err0,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"wrapped ApiError",
|
||||
err3,
|
||||
err0,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
is := errors.Is(s.err, s.target)
|
||||
|
||||
if is != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, is)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToApiError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
err error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"regular error",
|
||||
errors.New("test"),
|
||||
`{"data":{},"message":"Something went wrong while processing your request.","status":400}`,
|
||||
},
|
||||
{
|
||||
"fs.ErrNotExist",
|
||||
fs.ErrNotExist,
|
||||
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
|
||||
},
|
||||
{
|
||||
"sql.ErrNoRows",
|
||||
sql.ErrNoRows,
|
||||
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
|
||||
},
|
||||
{
|
||||
"ApiError",
|
||||
router.NewForbiddenError("test", nil),
|
||||
`{"data":{},"message":"Test.","status":403}`,
|
||||
},
|
||||
{
|
||||
"wrapped ApiError",
|
||||
fmt.Errorf("wrapped: %w", router.NewForbiddenError("test", nil)),
|
||||
`{"data":{},"message":"Test.","status":403}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
raw, err := json.Marshal(router.ToApiError(s.err))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected error\n%vgot\n%v", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
_ router.SafeErrorItem = (*mockSafeErrorItem)(nil)
|
||||
_ router.SafeErrorResolver = (*mockSafeErrorItem)(nil)
|
||||
)
|
||||
|
||||
type mockSafeErrorItem struct {
|
||||
}
|
||||
|
||||
func (m *mockSafeErrorItem) Code() string {
|
||||
return "mock_code"
|
||||
}
|
||||
|
||||
func (m *mockSafeErrorItem) Error() string {
|
||||
return "mock_error"
|
||||
}
|
||||
|
||||
func (m *mockSafeErrorItem) Resolve(errData map[string]any) any {
|
||||
errData["mock_resolve"] = 123
|
||||
return errData
|
||||
}
|
||||
@@ -0,0 +1,369 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/picker"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
var ErrUnsupportedContentType = NewBadRequestError("Unsupported Content-Type", nil)
|
||||
var ErrInvalidRedirectStatusCode = NewInternalServerError("Invalid redirect status code", nil)
|
||||
var ErrFileNotFound = NewNotFoundError("File not found", nil)
|
||||
|
||||
const IndexPage = "index.html"
|
||||
|
||||
// Event specifies based Route handler event that is usually intended
|
||||
// to be embedded as part of a custom event struct.
|
||||
//
|
||||
// NB! It is expected that the Response and Request fields are always set.
|
||||
type Event struct {
|
||||
Response http.ResponseWriter
|
||||
Request *http.Request
|
||||
|
||||
hook.Event
|
||||
|
||||
data store.Store[any]
|
||||
}
|
||||
|
||||
// RWUnwrapper specifies that an http.ResponseWriter could be "unwrapped"
|
||||
// (usually used with [http.ResponseController]).
|
||||
type RWUnwrapper interface {
|
||||
Unwrap() http.ResponseWriter
|
||||
}
|
||||
|
||||
// Written reports whether the current response has already been written.
|
||||
//
|
||||
// This method always returns false if e.ResponseWritter doesn't implement the WriteTracker interface
|
||||
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
|
||||
func (e *Event) Written() bool {
|
||||
written, _ := getWritten(e.Response)
|
||||
return written
|
||||
}
|
||||
|
||||
// Status reports the status code of the current response.
|
||||
//
|
||||
// This method always returns 0 if e.Response doesn't implement the StatusTracker interface
|
||||
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
|
||||
func (e *Event) Status() int {
|
||||
status, _ := getStatus(e.Response)
|
||||
return status
|
||||
}
|
||||
|
||||
// Flush flushes buffered data to the current response.
|
||||
//
|
||||
// Returns [http.ErrNotSupported] if e.Response doesn't implement the [http.Flusher] interface
|
||||
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
|
||||
func (e *Event) Flush() error {
|
||||
return http.NewResponseController(e.Response).Flush()
|
||||
}
|
||||
|
||||
// IsTLS reports whether the connection on which the request was received is TLS.
|
||||
func (e *Event) IsTLS() bool {
|
||||
return e.Request.TLS != nil
|
||||
}
|
||||
|
||||
// SetCookie is an alias for [http.SetCookie].
|
||||
//
|
||||
// SetCookie adds a Set-Cookie header to the current response's headers.
|
||||
// The provided cookie must have a valid Name.
|
||||
// Invalid cookies may be silently dropped.
|
||||
func (e *Event) SetCookie(cookie *http.Cookie) {
|
||||
http.SetCookie(e.Response, cookie)
|
||||
}
|
||||
|
||||
// RemoteIP returns the IP address of the client that sent the request.
|
||||
//
|
||||
// IPv6 addresses are returned expanded.
|
||||
// For example, "2001:db8::1" becomes "2001:0db8:0000:0000:0000:0000:0000:0001".
|
||||
//
|
||||
// Note that if you are behind reverse proxy(ies), this method returns
|
||||
// the IP of the last connecting proxy.
|
||||
func (e *Event) RemoteIP() string {
|
||||
ip, _, _ := net.SplitHostPort(e.Request.RemoteAddr)
|
||||
parsed, _ := netip.ParseAddr(ip)
|
||||
return parsed.StringExpanded()
|
||||
}
|
||||
|
||||
// UnsafeRealIP returns the "real" client IP from common proxy headers
|
||||
// OR fallbacks to the RemoteIP if none is found.
|
||||
//
|
||||
// NB! The returned IP value could be anything and it shouldn't be trusted if not behind a trusted reverse proxy!
|
||||
func (e *Event) UnsafeRealIP() string {
|
||||
if ip := e.Request.Header.Get("CF-Connecting-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
if ip := e.Request.Header.Get("Fly-Client-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
if ip := e.Request.Header.Get("X-Real-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
if ipsList := e.Request.Header.Get("X-Forwarded-For"); ipsList != "" {
|
||||
// extract the first non-empty leftmost-ish ip
|
||||
ips := strings.Split(ipsList, ",")
|
||||
for _, ip := range ips {
|
||||
ip = strings.TrimSpace(ip)
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return e.RemoteIP()
|
||||
}
|
||||
|
||||
// Store
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// Get retrieves single value from the current event data store.
|
||||
func (e *Event) Get(key string) any {
|
||||
return e.data.Get(key)
|
||||
}
|
||||
|
||||
// GetAll returns a copy of the current event data store.
|
||||
func (e *Event) GetAll() map[string]any {
|
||||
return e.data.GetAll()
|
||||
}
|
||||
|
||||
// Set saves single value into the current event data store.
|
||||
func (e *Event) Set(key string, value any) {
|
||||
e.data.Set(key, value)
|
||||
}
|
||||
|
||||
// SetAll saves all items from m into the current event data store.
|
||||
func (e *Event) SetAll(m map[string]any) {
|
||||
for k, v := range m {
|
||||
e.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Response writers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const headerContentType = "Content-Type"
|
||||
|
||||
func (e *Event) setResponseHeaderIfEmpty(key, value string) {
|
||||
header := e.Response.Header()
|
||||
if header.Get(key) == "" {
|
||||
header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// String writes a plain string response.
|
||||
func (e *Event) String(status int, data string) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "text/plain; charset=utf-8")
|
||||
e.Response.WriteHeader(status)
|
||||
_, err := e.Response.Write([]byte(data))
|
||||
return err
|
||||
}
|
||||
|
||||
// HTML writes an HTML response.
|
||||
func (e *Event) HTML(status int, data string) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "text/html; charset=utf-8")
|
||||
e.Response.WriteHeader(status)
|
||||
_, err := e.Response.Write([]byte(data))
|
||||
return err
|
||||
}
|
||||
|
||||
const jsonFieldsParam = "fields"
|
||||
|
||||
// JSON writes a JSON response.
|
||||
//
|
||||
// It also provides a generic response data fields picker if the "fields" query parameter is set.
|
||||
func (e *Event) JSON(status int, data any) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "application/json")
|
||||
e.Response.WriteHeader(status)
|
||||
|
||||
rawFields := e.Request.URL.Query().Get(jsonFieldsParam)
|
||||
|
||||
// error response or no fields to pick
|
||||
if rawFields == "" || status < 200 || status > 299 {
|
||||
return json.NewEncoder(e.Response).Encode(data)
|
||||
}
|
||||
|
||||
// pick only the requested fields
|
||||
modified, err := picker.Pick(data, rawFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.NewEncoder(e.Response).Encode(modified)
|
||||
}
|
||||
|
||||
// XML writes an XML response.
|
||||
// It automatically prepends the generic [xml.Header] string to the response.
|
||||
func (e *Event) XML(status int, data any) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "application/xml; charset=utf-8")
|
||||
e.Response.WriteHeader(status)
|
||||
if _, err := e.Response.Write([]byte(xml.Header)); err != nil {
|
||||
return err
|
||||
}
|
||||
return xml.NewEncoder(e.Response).Encode(data)
|
||||
}
|
||||
|
||||
// Stream streams the specified reader into the response.
|
||||
func (e *Event) Stream(status int, contentType string, reader io.Reader) error {
|
||||
e.Response.Header().Set(headerContentType, contentType)
|
||||
e.Response.WriteHeader(status)
|
||||
_, err := io.Copy(e.Response, reader)
|
||||
return err
|
||||
}
|
||||
|
||||
// FileFS serves the specified filename from fsys.
|
||||
//
|
||||
// It is similar to [echo.FileFS] for consistency with earlier versions.
|
||||
func (e *Event) FileFS(fsys fs.FS, filename string) error {
|
||||
f, err := fsys.Open(filename)
|
||||
if err != nil {
|
||||
return ErrFileNotFound
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if it is a directory try to open its index.html file
|
||||
if fi.IsDir() {
|
||||
filename = filepath.ToSlash(filepath.Join(filename, IndexPage))
|
||||
f, err = fsys.Open(filename)
|
||||
if err != nil {
|
||||
return ErrFileNotFound
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err = f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ff, ok := f.(io.ReadSeeker)
|
||||
if !ok {
|
||||
return errors.New("[FileFS] file does not implement io.ReadSeeker")
|
||||
}
|
||||
|
||||
http.ServeContent(e.Response, e.Request, fi.Name(), fi.ModTime(), ff)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NoContent writes a response with no body (ex. 204).
|
||||
func (e *Event) NoContent(status int) error {
|
||||
e.Response.WriteHeader(status)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Redirect writes a redirect response to the specified url.
|
||||
// The status code must be in between 300 – 399 range.
|
||||
func (e *Event) Redirect(status int, url string) error {
|
||||
if status < 300 || status > 399 {
|
||||
return ErrInvalidRedirectStatusCode
|
||||
}
|
||||
e.Response.Header().Set("Location", url)
|
||||
e.Response.WriteHeader(status)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApiError helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func (e *Event) Error(status int, message string, errData any) *ApiError {
|
||||
return NewApiError(status, message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) BadRequestError(message string, errData any) *ApiError {
|
||||
return NewBadRequestError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) NotFoundError(message string, errData any) *ApiError {
|
||||
return NewNotFoundError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) ForbiddenError(message string, errData any) *ApiError {
|
||||
return NewForbiddenError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) UnauthorizedError(message string, errData any) *ApiError {
|
||||
return NewUnauthorizedError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) TooManyRequestsError(message string, errData any) *ApiError {
|
||||
return NewTooManyRequestsError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) InternalServerError(message string, errData any) *ApiError {
|
||||
return NewInternalServerError(message, errData)
|
||||
}
|
||||
|
||||
// Binders
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const DefaultMaxMemory = 32 << 20 // 32mb
|
||||
|
||||
// Supports the following content-types:
|
||||
//
|
||||
// - application/json
|
||||
// - multipart/form-data
|
||||
// - application/x-www-form-urlencoded
|
||||
// - text/xml, application/xml
|
||||
func (e *Event) BindBody(dst any) error {
|
||||
if e.Request.ContentLength == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
contentType := e.Request.Header.Get(headerContentType)
|
||||
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
dec := json.NewDecoder(e.Request.Body)
|
||||
err := dec.Decode(dst)
|
||||
if err == nil {
|
||||
// manually call Reread because single call of json.Decoder.Decode()
|
||||
// doesn't ensure that the entire body is a valid json string
|
||||
// and it is not guaranteed that it will reach EOF to trigger the reread reset
|
||||
// (ex. in case of trailing spaces or invalid trailing parts like: `{"test":1},something`)
|
||||
if body, ok := e.Request.Body.(Rereader); ok {
|
||||
body.Reread()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
if err := e.Request.ParseMultipartForm(DefaultMaxMemory); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalRequestData(e.Request.Form, dst, "", "")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
|
||||
if err := e.Request.ParseForm(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalRequestData(e.Request.Form, dst, "", "")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(contentType, "text/xml") ||
|
||||
strings.HasPrefix(contentType, "application/xml") {
|
||||
return xml.NewDecoder(e.Request.Body).Decode(dst)
|
||||
}
|
||||
|
||||
return ErrUnsupportedContentType
|
||||
}
|
||||
@@ -0,0 +1,924 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
type unwrapTester struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (ut unwrapTester) Unwrap() http.ResponseWriter {
|
||||
return ut.ResponseWriter
|
||||
}
|
||||
|
||||
func TestEventWritten(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
res2.Write([]byte("test"))
|
||||
|
||||
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
|
||||
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
res4.Write([]byte("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
response http.ResponseWriter
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "non-written non-WriteTracker",
|
||||
response: res1,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "written non-WriteTracker",
|
||||
response: res2,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-written WriteTracker",
|
||||
response: res3,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "written WriteTracker",
|
||||
response: res4,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
event := router.Event{
|
||||
Response: s.response,
|
||||
}
|
||||
|
||||
result := event.Written()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
res2.WriteHeader(123)
|
||||
|
||||
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
|
||||
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
res4.WriteHeader(123)
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
response http.ResponseWriter
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "non-written non-StatusTracker",
|
||||
response: res1,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "written non-StatusTracker",
|
||||
response: res2,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "non-written StatusTracker",
|
||||
response: res3,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "written StatusTracker",
|
||||
response: res4,
|
||||
expected: 123,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
event := router.Event{
|
||||
Response: s.response,
|
||||
}
|
||||
|
||||
result := event.Status()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %d, got %d", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventIsTLS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event := router.Event{Request: req}
|
||||
|
||||
// without TLS
|
||||
if event.IsTLS() {
|
||||
t.Fatalf("Expected IsTLS false")
|
||||
}
|
||||
|
||||
// dummy TLS state
|
||||
req.TLS = new(tls.ConnectionState)
|
||||
|
||||
// with TLS
|
||||
if !event.IsTLS() {
|
||||
t.Fatalf("Expected IsTLS true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
event := router.Event{
|
||||
Response: httptest.NewRecorder(),
|
||||
}
|
||||
|
||||
cookie := event.Response.Header().Get("set-cookie")
|
||||
if cookie != "" {
|
||||
t.Fatalf("Expected empty cookie string, got %q", cookie)
|
||||
}
|
||||
|
||||
event.SetCookie(&http.Cookie{Name: "test", Value: "a"})
|
||||
|
||||
expected := "test=a"
|
||||
|
||||
cookie = event.Response.Header().Get("set-cookie")
|
||||
if cookie != expected {
|
||||
t.Fatalf("Expected cookie %q, got %q", expected, cookie)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventRemoteIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{"", "invalid IP"},
|
||||
{"1.2.3.4", "invalid IP"},
|
||||
{"1.2.3.4:8090", "1.2.3.4"},
|
||||
{"[0000:0000:0000:0000:0000:0000:0000:0002]:80", "0000:0000:0000:0000:0000:0000:0000:0002"},
|
||||
{"[::2]:80", "0000:0000:0000:0000:0000:0000:0000:0002"}, // should always return the expanded version
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.remoteAddr, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.RemoteAddr = s.remoteAddr
|
||||
|
||||
event := router.Event{Request: req}
|
||||
|
||||
ip := event.RemoteIP()
|
||||
|
||||
if ip != s.expected {
|
||||
t.Fatalf("Expected IP %q, got %q", s.expected, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventUnsafeRealIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
headers map[string]string
|
||||
expected string
|
||||
}{
|
||||
{nil, "1.2.3.4"},
|
||||
{
|
||||
map[string]string{"CF-Connecting-IP": "test"},
|
||||
"test",
|
||||
},
|
||||
{
|
||||
map[string]string{"Fly-Client-IP": "test"},
|
||||
"test",
|
||||
},
|
||||
{
|
||||
map[string]string{"X-Real-IP": "test"},
|
||||
"test",
|
||||
},
|
||||
{
|
||||
map[string]string{"X-Forwarded-For": "test1,test2,test3"},
|
||||
"test1",
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
keys := make([]string, 0, len(s.headers))
|
||||
for h := range s.headers {
|
||||
keys = append(keys, h)
|
||||
}
|
||||
|
||||
testName := strings.Join(keys, "_")
|
||||
if testName == "" {
|
||||
testName = "no_headers" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.RemoteAddr = "1.2.3.4:80" // fallback
|
||||
|
||||
for k, v := range s.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
event := router.Event{Request: req}
|
||||
|
||||
ip := event.UnsafeRealIP()
|
||||
|
||||
if ip != s.expected {
|
||||
t.Fatalf("Expected IP %q, got %q", s.expected, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetGet(t *testing.T) {
|
||||
event := router.Event{}
|
||||
|
||||
// get before any set (ensures that doesn't panic)
|
||||
if v := event.Get("test"); v != nil {
|
||||
t.Fatalf("Expected nil value, got %v", v)
|
||||
}
|
||||
|
||||
event.Set("a", 123)
|
||||
event.Set("b", 456)
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
expected any
|
||||
}{
|
||||
{"", nil},
|
||||
{"missing", nil},
|
||||
{"a", 123},
|
||||
{"b", 456},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
|
||||
result := event.Get(s.key)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetAllGetAll(t *testing.T) {
|
||||
data := map[string]any{
|
||||
"a": 123,
|
||||
"b": 456,
|
||||
}
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event := router.Event{}
|
||||
event.SetAll(data)
|
||||
|
||||
// modify the data to ensure that the map was shallow coppied
|
||||
data["c"] = 789
|
||||
|
||||
result := event.GetAll()
|
||||
rawResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(rawResult) == 0 || !bytes.Equal(rawData, rawResult) {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", rawData, rawResult)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventString(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 123,
|
||||
headers: nil,
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/plain; charset=utf-8"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type",
|
||||
status: 123,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.String(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventHTML(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 123,
|
||||
headers: nil,
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/html; charset=utf-8"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type",
|
||||
status: 123,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.HTML(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventJSON(t *testing.T) {
|
||||
body := map[string]any{"a": 123, "b": 456, "c": "test"}
|
||||
expectedPickedBody := `{"a":123,"c":"test"}` + "\n"
|
||||
expectedFullBody := `{"a":123,"b":456,"c":"test"}` + "\n"
|
||||
|
||||
scenarios := []testResponseWriteScenario[any]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 200,
|
||||
headers: nil,
|
||||
body: body,
|
||||
expectedStatus: 200,
|
||||
expectedHeaders: map[string]string{"content-type": "application/json"},
|
||||
expectedBody: expectedPickedBody,
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type (200)",
|
||||
status: 200,
|
||||
headers: map[string]string{"content-type": "application/test"},
|
||||
body: body,
|
||||
expectedStatus: 200,
|
||||
expectedHeaders: map[string]string{"content-type": "application/test"},
|
||||
expectedBody: expectedPickedBody,
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type (400)", // no fields picker
|
||||
status: 400,
|
||||
headers: map[string]string{"content-type": "application/test"},
|
||||
body: body,
|
||||
expectedStatus: 400,
|
||||
expectedHeaders: map[string]string{"content-type": "application/test"},
|
||||
expectedBody: expectedFullBody,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
e.Request.URL.RawQuery = "fields=a,c" // ensures that the picker is invoked
|
||||
return e.JSON(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventXML(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 234,
|
||||
headers: nil,
|
||||
body: "test",
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "application/xml; charset=utf-8"},
|
||||
expectedBody: xml.Header + "<string>test</string>",
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type",
|
||||
status: 234,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: xml.Header + "<string>test</string>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.XML(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventStream(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "stream",
|
||||
status: 234,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.Stream(s.status, s.headers["content-type"], strings.NewReader(s.body))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventNoContent(t *testing.T) {
|
||||
s := testResponseWriteScenario[any]{
|
||||
name: "no content",
|
||||
status: 234,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: nil,
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "",
|
||||
}
|
||||
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.NoContent(s.status)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventFlush(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := &router.Event{
|
||||
Response: unwrapTester{&router.ResponseWriter{ResponseWriter: rec}},
|
||||
}
|
||||
event.Response.Write([]byte("test"))
|
||||
event.Flush()
|
||||
|
||||
if !rec.Flushed {
|
||||
t.Fatal("Expected response to be flushed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventRedirect(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[any]{
|
||||
{
|
||||
name: "non-30x status",
|
||||
status: 200,
|
||||
expectedStatus: 200,
|
||||
expectedError: router.ErrInvalidRedirectStatusCode,
|
||||
},
|
||||
{
|
||||
name: "30x status",
|
||||
status: 302,
|
||||
headers: map[string]string{"location": "test"}, // should be overwritten with the argument
|
||||
expectedStatus: 302,
|
||||
expectedHeaders: map[string]string{"location": "example"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.Redirect(s.status, "example")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventFileFS(t *testing.T) {
|
||||
// stub test files
|
||||
// ---
|
||||
dir, err := os.MkdirTemp("", "EventFileFS")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dir, "index.html"), []byte("index"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(filepath.Join(dir, "test.txt"), []byte("test"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create sub directory with an index.html file inside it
|
||||
err = os.MkdirAll(filepath.Join(dir, "sub1"), os.ModePerm)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = os.WriteFile(filepath.Join(dir, "sub1", "index.html"), []byte("sub1 index"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = os.MkdirAll(filepath.Join(dir, "sub2"), os.ModePerm)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = os.WriteFile(filepath.Join(dir, "sub2", "test.txt"), []byte("sub2 test"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"missing file", "", ""},
|
||||
{"root with no explicit file", "", ""},
|
||||
{"root with explicit file", "test.txt", "test"},
|
||||
{"sub dir with no explicit file", "sub1", "sub1 index"},
|
||||
{"sub dir with no explicit file (no index.html)", "sub2", ""},
|
||||
{"sub dir explicit file", "sub2/test.txt", "sub2 test"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := &router.Event{
|
||||
Request: req,
|
||||
Response: rec,
|
||||
}
|
||||
|
||||
err = event.FileFS(os.DirFS(dir), s.path)
|
||||
|
||||
hasErr := err != nil
|
||||
expectErr := s.expected == ""
|
||||
if hasErr != expectErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", expectErr, hasErr, err)
|
||||
}
|
||||
|
||||
result := rec.Result()
|
||||
|
||||
raw, err := io.ReadAll(result.Body)
|
||||
result.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(raw) != s.expected {
|
||||
t.Fatalf("Expected body\n%s\ngot\n%s", s.expected, raw)
|
||||
}
|
||||
|
||||
// ensure that the proper file headers are added
|
||||
// (aka. http.ServeContent is invoked)
|
||||
length, _ := strconv.Atoi(result.Header.Get("content-length"))
|
||||
if length != len(s.expected) {
|
||||
t.Fatalf("Expected Content-Length %d, got %d", len(s.expected), length)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventError(t *testing.T) {
|
||||
err := new(router.Event).Error(123, "message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":123}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBadRequestError(t *testing.T) {
|
||||
err := new(router.Event).BadRequestError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":400}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventNotFoundError(t *testing.T) {
|
||||
err := new(router.Event).NotFoundError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":404}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventForbiddenError(t *testing.T) {
|
||||
err := new(router.Event).ForbiddenError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":403}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventUnauthorizedError(t *testing.T) {
|
||||
err := new(router.Event).UnauthorizedError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":401}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventTooManyRequestsError(t *testing.T) {
|
||||
err := new(router.Event).TooManyRequestsError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":429}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventInternalServerError(t *testing.T) {
|
||||
err := new(router.Event).InternalServerError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":500}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBindBody(t *testing.T) {
|
||||
type testDstStruct struct {
|
||||
A int `json:"a" xml:"a" form:"a"`
|
||||
B int `json:"b" xml:"b" form:"b"`
|
||||
C string `json:"c" xml:"c" form:"c"`
|
||||
}
|
||||
|
||||
emptyDst := `{"a":0,"b":0,"c":""}`
|
||||
|
||||
queryDst := `a=123&b=-456&c=test`
|
||||
|
||||
xmlDst := `
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<root>
|
||||
<a>123</a>
|
||||
<b>-456</b>
|
||||
<c>test</c>
|
||||
</root>
|
||||
`
|
||||
|
||||
jsonDst := `{"a":123,"b":-456,"c":"test"}`
|
||||
|
||||
// multipart
|
||||
mpBody := &bytes.Buffer{}
|
||||
mpWriter := multipart.NewWriter(mpBody)
|
||||
mpWriter.WriteField("@jsonPayload", `{"a":123}`)
|
||||
mpWriter.WriteField("b", "-456")
|
||||
mpWriter.WriteField("c", "test")
|
||||
if err := mpWriter.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
contentType string
|
||||
body io.Reader
|
||||
expectDst string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
contentType: "",
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: emptyDst,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
contentType: "application/rtf", // unsupported
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: emptyDst,
|
||||
expectError: true,
|
||||
},
|
||||
// empty body
|
||||
{
|
||||
contentType: "application/json;charset=emptybody",
|
||||
body: strings.NewReader(""),
|
||||
expectDst: emptyDst,
|
||||
},
|
||||
// json
|
||||
{
|
||||
contentType: "application/json",
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/json;charset=abc",
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
// xml
|
||||
{
|
||||
contentType: "text/xml",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "text/xml;charset=abc",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/xml",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/xml;charset=abc",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
// x-www-form-urlencoded
|
||||
{
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
body: strings.NewReader(queryDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/x-www-form-urlencoded;charset=abc",
|
||||
body: strings.NewReader(queryDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
// multipart
|
||||
{
|
||||
contentType: mpWriter.FormDataContentType(),
|
||||
body: mpBody,
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.contentType, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "/", s.body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Add("content-type", s.contentType)
|
||||
|
||||
event := &router.Event{Request: req}
|
||||
|
||||
dst := testDstStruct{}
|
||||
|
||||
err = event.BindBody(&dst)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
dstRaw, err := json.Marshal(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(dstRaw) != s.expectDst {
|
||||
t.Fatalf("Expected dst\n%s\ngot\n%s", s.expectDst, dstRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testResponseWriteScenario[T any] struct {
|
||||
name string
|
||||
status int
|
||||
headers map[string]string
|
||||
body T
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
expectedBody string
|
||||
expectedError error
|
||||
}
|
||||
|
||||
func testEventResponseWrite[T any](
|
||||
t *testing.T,
|
||||
scenario testResponseWriteScenario[T],
|
||||
writeFunc func(e *router.Event) error,
|
||||
) {
|
||||
t.Run(scenario.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
event := &router.Event{
|
||||
Request: req,
|
||||
Response: &router.ResponseWriter{ResponseWriter: rec},
|
||||
}
|
||||
|
||||
for k, v := range scenario.headers {
|
||||
event.Response.Header().Add(k, v)
|
||||
}
|
||||
|
||||
err = writeFunc(event)
|
||||
if (scenario.expectedError != nil || err != nil) && !errors.Is(err, scenario.expectedError) {
|
||||
t.Fatalf("Expected error %v, got %v", scenario.expectedError, err)
|
||||
}
|
||||
|
||||
result := rec.Result()
|
||||
|
||||
if result.StatusCode != scenario.expectedStatus {
|
||||
t.Fatalf("Expected status code %d, got %d", scenario.expectedStatus, result.StatusCode)
|
||||
}
|
||||
|
||||
resultBody, err := io.ReadAll(result.Body)
|
||||
result.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
resultBody, err = json.Marshal(string(resultBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedBody, err := json.Marshal(scenario.expectedBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(resultBody, expectedBody) {
|
||||
t.Fatalf("Expected body\n%s\ngot\n%s", expectedBody, resultBody)
|
||||
}
|
||||
|
||||
for k, ev := range scenario.expectedHeaders {
|
||||
if v := result.Header.Get(k); v != ev {
|
||||
t.Fatalf("Expected %q header to be %q, got %q", k, ev, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
// (note: the struct is named RouterGroup instead of Group so that it can
|
||||
// be embedded in the Router without conflicting with the Group method)
|
||||
|
||||
// RouterGroup represents a collection of routes and other sub groups
|
||||
// that share common pattern prefix and middlewares.
|
||||
type RouterGroup[T hook.Resolver] struct {
|
||||
excludedMiddlewares map[string]struct{}
|
||||
children []any // Route or RouterGroup
|
||||
|
||||
Prefix string
|
||||
Middlewares []*hook.Handler[T]
|
||||
}
|
||||
|
||||
// Group creates and register a new child Group into the current one
|
||||
// with the specified prefix.
|
||||
//
|
||||
// The prefix follows the standard Go net/http ServeMux pattern format ("[HOST]/[PATH]")
|
||||
// and will be concatenated recursively into the final route path, meaning that
|
||||
// only the root level group could have HOST as part of the prefix.
|
||||
//
|
||||
// Returns the newly created group to allow chaining and registering
|
||||
// sub-routes and group specific middlewares.
|
||||
func (group *RouterGroup[T]) Group(prefix string) *RouterGroup[T] {
|
||||
newGroup := &RouterGroup[T]{}
|
||||
newGroup.Prefix = prefix
|
||||
|
||||
group.children = append(group.children, newGroup)
|
||||
|
||||
return newGroup
|
||||
}
|
||||
|
||||
// BindFunc registers one or multiple middleware functions to the current group.
|
||||
//
|
||||
// The registered middleware functions are "anonymous" and with default priority,
|
||||
// aka. executes in the order they were registered.
|
||||
//
|
||||
// If you need to specify a named middleware (ex. so that it can be removed)
|
||||
// or middleware with custom exec prirority, use [Group.Bind] method.
|
||||
func (group *RouterGroup[T]) BindFunc(middlewareFuncs ...hook.HandlerFunc[T]) *RouterGroup[T] {
|
||||
for _, m := range middlewareFuncs {
|
||||
group.Middlewares = append(group.Middlewares, &hook.Handler[T]{Func: m})
|
||||
}
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// Bind registers one or multiple middleware handlers to the current group.
|
||||
func (group *RouterGroup[T]) Bind(middlewares ...*hook.Handler[T]) *RouterGroup[T] {
|
||||
group.Middlewares = append(group.Middlewares, middlewares...)
|
||||
|
||||
// unmark the newly added middlewares in case they were previously "excluded"
|
||||
if group.excludedMiddlewares != nil {
|
||||
for _, m := range middlewares {
|
||||
if m.Id != "" {
|
||||
delete(group.excludedMiddlewares, m.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// Unbind removes one or more middlewares with the specified id(s)
|
||||
// from the current group and its children (if any).
|
||||
//
|
||||
// Anonymous middlewares are not removable, aka. this method does nothing
|
||||
// if the middleware id is an empty string.
|
||||
func (group *RouterGroup[T]) Unbind(middlewareIds ...string) *RouterGroup[T] {
|
||||
for _, middlewareId := range middlewareIds {
|
||||
if middlewareId == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// remove from the group middlwares
|
||||
for i := len(group.Middlewares) - 1; i >= 0; i-- {
|
||||
if group.Middlewares[i].Id == middlewareId {
|
||||
group.Middlewares = append(group.Middlewares[:i], group.Middlewares[i+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
// remove from the group children
|
||||
for i := len(group.children) - 1; i >= 0; i-- {
|
||||
switch v := group.children[i].(type) {
|
||||
case *RouterGroup[T]:
|
||||
v.Unbind(middlewareId)
|
||||
case *Route[T]:
|
||||
v.Unbind(middlewareId)
|
||||
}
|
||||
}
|
||||
|
||||
// add to the exclude list
|
||||
if group.excludedMiddlewares == nil {
|
||||
group.excludedMiddlewares = map[string]struct{}{}
|
||||
}
|
||||
group.excludedMiddlewares[middlewareId] = struct{}{}
|
||||
}
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// Route registers a single route into the current group.
|
||||
//
|
||||
// Note that the final route path will be the concatenation of all parent groups prefixes + the route path.
|
||||
// The path follows the standard Go net/http ServeMux format ("[HOST]/[PATH]"),
|
||||
// meaning that only a top level group route could have HOST as part of the prefix.
|
||||
//
|
||||
// Returns the newly created route to allow attaching route-only middlewares.
|
||||
func (group *RouterGroup[T]) Route(method string, path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
route := &Route[T]{
|
||||
Method: method,
|
||||
Path: path,
|
||||
Action: action,
|
||||
}
|
||||
|
||||
group.children = append(group.children, route)
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// Any is a shorthand for [Group.AddRoute] with "" as route method (aka. matches any method).
|
||||
func (group *RouterGroup[T]) Any(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route("", path, action)
|
||||
}
|
||||
|
||||
// GET is a shorthand for [Group.AddRoute] with GET as route method.
|
||||
func (group *RouterGroup[T]) GET(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route(http.MethodGet, path, action)
|
||||
}
|
||||
|
||||
// POST is a shorthand for [Group.AddRoute] with POST as route method.
|
||||
func (group *RouterGroup[T]) POST(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route(http.MethodPost, path, action)
|
||||
}
|
||||
|
||||
// DELETE is a shorthand for [Group.AddRoute] with DELETE as route method.
|
||||
func (group *RouterGroup[T]) DELETE(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route(http.MethodDelete, path, action)
|
||||
}
|
||||
|
||||
// PATCH is a shorthand for [Group.AddRoute] with PATCH as route method.
|
||||
func (group *RouterGroup[T]) PATCH(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route(http.MethodPatch, path, action)
|
||||
}
|
||||
|
||||
// PUT is a shorthand for [Group.AddRoute] with PUT as route method.
|
||||
func (group *RouterGroup[T]) PUT(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route(http.MethodPut, path, action)
|
||||
}
|
||||
|
||||
// HEAD is a shorthand for [Group.AddRoute] with HEAD as route method.
|
||||
func (group *RouterGroup[T]) HEAD(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route(http.MethodHead, path, action)
|
||||
}
|
||||
|
||||
// OPTIONS is a shorthand for [Group.AddRoute] with OPTIONS as route method.
|
||||
func (group *RouterGroup[T]) OPTIONS(path string, action hook.HandlerFunc[T]) *Route[T] {
|
||||
return group.Route(http.MethodOptions, path, action)
|
||||
}
|
||||
|
||||
// HasRoute checks whether the specified route pattern (method + path)
|
||||
// is registered in the current group or its children.
|
||||
//
|
||||
// This could be useful to conditionally register and checks for routes
|
||||
// in order prevent panic on duplicated routes.
|
||||
//
|
||||
// Note that routes with anonymous and named wildcard placeholder are treated as equal,
|
||||
// aka. "GET /abc/" is considered the same as "GET /abc/{something...}".
|
||||
func (group *RouterGroup[T]) HasRoute(method string, path string) bool {
|
||||
pattern := path
|
||||
if method != "" {
|
||||
pattern = strings.ToUpper(method) + " " + pattern
|
||||
}
|
||||
|
||||
return group.hasRoute(pattern, nil)
|
||||
}
|
||||
|
||||
func (group *RouterGroup[T]) hasRoute(pattern string, parents []*RouterGroup[T]) bool {
|
||||
for _, child := range group.children {
|
||||
switch v := child.(type) {
|
||||
case *RouterGroup[T]:
|
||||
if v.hasRoute(pattern, append(parents, group)) {
|
||||
return true
|
||||
}
|
||||
case *Route[T]:
|
||||
var result string
|
||||
|
||||
if v.Method != "" {
|
||||
result += v.Method + " "
|
||||
}
|
||||
|
||||
// add parent groups prefixes
|
||||
for _, p := range parents {
|
||||
result += p.Prefix
|
||||
}
|
||||
|
||||
// add current group prefix
|
||||
result += group.Prefix
|
||||
|
||||
// add current route path
|
||||
result += v.Path
|
||||
|
||||
if result == pattern || // direct match
|
||||
// compares without the named wildcard, aka. /abc/{test...} is equal to /abc/
|
||||
stripWildcard(result) == stripWildcard(pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var wildcardPlaceholderRegex = regexp.MustCompile(`/{.+\.\.\.}$`)
|
||||
|
||||
func stripWildcard(pattern string) string {
|
||||
return wildcardPlaceholderRegex.ReplaceAllString(pattern, "/")
|
||||
}
|
||||
@@ -0,0 +1,425 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
func TestRouterGroupGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g0 := RouterGroup[*Event]{}
|
||||
|
||||
g1 := g0.Group("test1")
|
||||
g2 := g0.Group("test2")
|
||||
|
||||
if total := len(g0.children); total != 2 {
|
||||
t.Fatalf("Expected %d child groups, got %d", 2, total)
|
||||
}
|
||||
|
||||
if g1.Prefix != "test1" {
|
||||
t.Fatalf("Expected g1 with prefix %q, got %q", "test1", g1.Prefix)
|
||||
}
|
||||
if g2.Prefix != "test2" {
|
||||
t.Fatalf("Expected g2 with prefix %q, got %q", "test2", g2.Prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupBindFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g := RouterGroup[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one function
|
||||
g.BindFunc(func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
})
|
||||
|
||||
// append multiple functions
|
||||
g.BindFunc(
|
||||
func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(g.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range g.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupBind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g := RouterGroup[*Event]{
|
||||
// mock excluded middlewares to check whether the entry will be deleted
|
||||
excludedMiddlewares: map[string]struct{}{"test2": {}},
|
||||
}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one handler
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// append multiple handlers
|
||||
g.Bind(
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(g.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range g.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls %q, got %q", "abc", calls)
|
||||
}
|
||||
|
||||
// ensures that the previously excluded middleware was removed
|
||||
if len(g.excludedMiddlewares) != 0 {
|
||||
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", g.excludedMiddlewares)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupUnbind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g := RouterGroup[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// anonymous middlewares
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// middlewares with id
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Id: "test3",
|
||||
Func: func(e *Event) error {
|
||||
calls += "d"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// remove
|
||||
g.Unbind("") // should be no-op
|
||||
g.Unbind("test1", "test3")
|
||||
|
||||
if total := len(g.Middlewares); total != 2 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 2, total)
|
||||
}
|
||||
|
||||
for _, h := range g.Middlewares {
|
||||
if err := h.Func(nil); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if calls != "ac" {
|
||||
t.Fatalf("Expected calls %q, got %q", "ac", calls)
|
||||
}
|
||||
|
||||
// ensure that the ids were added in the exclude list
|
||||
excluded := []string{"test1", "test3"}
|
||||
if len(g.excludedMiddlewares) != len(excluded) {
|
||||
t.Fatalf("Expected excludes %v, got %v", excluded, g.excludedMiddlewares)
|
||||
}
|
||||
for id := range g.excludedMiddlewares {
|
||||
if !slices.Contains(excluded, id) {
|
||||
t.Fatalf("Expected %q to be marked as excluded", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group := RouterGroup[*Event]{}
|
||||
|
||||
sub := group.Group("sub")
|
||||
|
||||
var called bool
|
||||
route := group.Route(http.MethodPost, "/test", func(e *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// ensure that the route was registered only to the main one
|
||||
// ---
|
||||
if len(sub.children) != 0 {
|
||||
t.Fatalf("Expected no sub children, got %d", len(sub.children))
|
||||
}
|
||||
|
||||
if len(group.children) != 2 {
|
||||
t.Fatalf("Expected %d group children, got %d", 2, len(group.children))
|
||||
}
|
||||
// ---
|
||||
|
||||
// check the registered route
|
||||
// ---
|
||||
if route != group.children[1] {
|
||||
t.Fatalf("Expected group children %v, got %v", route, group.children[1])
|
||||
}
|
||||
|
||||
if route.Method != http.MethodPost {
|
||||
t.Fatalf("Expected route method %q, got %q", http.MethodPost, route.Method)
|
||||
}
|
||||
|
||||
if route.Path != "/test" {
|
||||
t.Fatalf("Expected route path %q, got %q", "/test", route.Path)
|
||||
}
|
||||
|
||||
route.Action(nil)
|
||||
if !called {
|
||||
t.Fatal("Expected route action to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupRouteAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group := RouterGroup[*Event]{}
|
||||
|
||||
testErr := errors.New("test")
|
||||
|
||||
testAction := func(e *Event) error {
|
||||
return testErr
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
route *Route[*Event]
|
||||
expectMethod string
|
||||
expectPath string
|
||||
}{
|
||||
{
|
||||
group.Any("/test", testAction),
|
||||
"",
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.GET("/test", testAction),
|
||||
http.MethodGet,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.POST("/test", testAction),
|
||||
http.MethodPost,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.DELETE("/test", testAction),
|
||||
http.MethodDelete,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.PATCH("/test", testAction),
|
||||
http.MethodPatch,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.PUT("/test", testAction),
|
||||
http.MethodPut,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.HEAD("/test", testAction),
|
||||
http.MethodHead,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.OPTIONS("/test", testAction),
|
||||
http.MethodOptions,
|
||||
"/test",
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.expectMethod, s.expectPath), func(t *testing.T) {
|
||||
if s.route.Method != s.expectMethod {
|
||||
t.Fatalf("Expected method %q, got %q", s.expectMethod, s.route.Method)
|
||||
}
|
||||
|
||||
if s.route.Path != s.expectPath {
|
||||
t.Fatalf("Expected path %q, got %q", s.expectPath, s.route.Path)
|
||||
}
|
||||
|
||||
if err := s.route.Action(nil); !errors.Is(err, testErr) {
|
||||
t.Fatal("Expected test action")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupHasRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group := RouterGroup[*Event]{}
|
||||
|
||||
group.Any("/any", nil)
|
||||
|
||||
group.GET("/base", nil)
|
||||
group.DELETE("/base", nil)
|
||||
|
||||
sub := group.Group("/sub1")
|
||||
sub.GET("/a", nil)
|
||||
sub.POST("/a", nil)
|
||||
|
||||
sub2 := sub.Group("/sub2")
|
||||
sub2.GET("/b", nil)
|
||||
sub2.GET("/b/{test}", nil)
|
||||
|
||||
// special cases to test the normalizations
|
||||
group.GET("/c/", nil) // the same as /c/{test...}
|
||||
group.GET("/d/{test...}", nil) // the same as /d/
|
||||
|
||||
scenarios := []struct {
|
||||
method string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
http.MethodGet,
|
||||
"",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"",
|
||||
"/any",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodPost,
|
||||
"/base",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/base",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodDelete,
|
||||
"/base",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/a",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodPost,
|
||||
"/sub1/a",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodDelete,
|
||||
"/sub1/a",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub2/b",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/sub2/b",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/sub2/b/{test}",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/sub2/b/{test2}",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/c/{test...}",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/d/",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.method+"_"+s.path, func(t *testing.T) {
|
||||
has := group.HasRoute(s.method, s.path)
|
||||
|
||||
if has != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, has)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
_ io.ReadCloser = (*RereadableReadCloser)(nil)
|
||||
_ Rereader = (*RereadableReadCloser)(nil)
|
||||
)
|
||||
|
||||
// Rereader defines an interface for rewindable readers.
|
||||
type Rereader interface {
|
||||
Reread()
|
||||
}
|
||||
|
||||
// RereadableReadCloser defines a wrapper around a io.ReadCloser reader
|
||||
// allowing to read the original reader multiple times.
|
||||
type RereadableReadCloser struct {
|
||||
io.ReadCloser
|
||||
|
||||
copy *bytes.Buffer
|
||||
active io.Reader
|
||||
}
|
||||
|
||||
// Read implements the standard io.Reader interface.
|
||||
//
|
||||
// It reads up to len(b) bytes into b and at at the same time writes
|
||||
// the read data into an internal bytes buffer.
|
||||
//
|
||||
// On EOF the r is "rewinded" to allow reading from r multiple times.
|
||||
func (r *RereadableReadCloser) Read(b []byte) (int, error) {
|
||||
if r.active == nil {
|
||||
if r.copy == nil {
|
||||
r.copy = &bytes.Buffer{}
|
||||
}
|
||||
r.active = io.TeeReader(r.ReadCloser, r.copy)
|
||||
}
|
||||
|
||||
n, err := r.active.Read(b)
|
||||
if err == io.EOF {
|
||||
r.Reread()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Reread satisfies the [Rereader] interface and resets the r internal state to allow rereads.
|
||||
//
|
||||
// note: not named Reset to avoid conflicts with other reader interfaces.
|
||||
func (r *RereadableReadCloser) Reread() {
|
||||
if r.copy == nil || r.copy.Len() == 0 {
|
||||
return // nothing to reset or it has been already reset
|
||||
}
|
||||
|
||||
oldCopy := r.copy
|
||||
r.copy = &bytes.Buffer{}
|
||||
r.active = io.TeeReader(oldCopy, r.copy)
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestRereadableReadCloser(t *testing.T) {
|
||||
content := "test"
|
||||
|
||||
rereadable := &router.RereadableReadCloser{
|
||||
ReadCloser: io.NopCloser(strings.NewReader(content)),
|
||||
}
|
||||
|
||||
// read multiple times
|
||||
for i := 0; i < 3; i++ {
|
||||
result, err := io.ReadAll(rereadable)
|
||||
if err != nil {
|
||||
t.Fatalf("[read:%d] %v", i, err)
|
||||
}
|
||||
if str := string(result); str != content {
|
||||
t.Fatalf("[read:%d] Expected %q, got %q", i, content, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package router
|
||||
|
||||
import "github.com/pocketbase/pocketbase/tools/hook"
|
||||
|
||||
type Route[T hook.Resolver] struct {
|
||||
excludedMiddlewares map[string]struct{}
|
||||
|
||||
Action hook.HandlerFunc[T]
|
||||
Method string
|
||||
Path string
|
||||
Middlewares []*hook.Handler[T]
|
||||
}
|
||||
|
||||
// BindFunc registers one or multiple middleware functions to the current route.
|
||||
//
|
||||
// The registered middleware functions are "anonymous" and with default priority,
|
||||
// aka. executes in the order they were registered.
|
||||
//
|
||||
// If you need to specify a named middleware (ex. so that it can be removed)
|
||||
// or middleware with custom exec prirority, use the [Bind] method.
|
||||
func (route *Route[T]) BindFunc(middlewareFuncs ...hook.HandlerFunc[T]) *Route[T] {
|
||||
for _, m := range middlewareFuncs {
|
||||
route.Middlewares = append(route.Middlewares, &hook.Handler[T]{Func: m})
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// Bind registers one or multiple middleware handlers to the current route.
|
||||
func (route *Route[T]) Bind(middlewares ...*hook.Handler[T]) *Route[T] {
|
||||
route.Middlewares = append(route.Middlewares, middlewares...)
|
||||
|
||||
// unmark the newly added middlewares in case they were previously "excluded"
|
||||
if route.excludedMiddlewares != nil {
|
||||
for _, m := range middlewares {
|
||||
if m.Id != "" {
|
||||
delete(route.excludedMiddlewares, m.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// Unbind removes one or more middlewares with the specified id(s) from the current route.
|
||||
//
|
||||
// It also adds the removed middleware ids to an exclude list so that they could be skipped from
|
||||
// the execution chain in case the middleware is registered in a parent group.
|
||||
//
|
||||
// Anonymous middlewares are considered non-removable, aka. this method
|
||||
// does nothing if the middleware id is an empty string.
|
||||
func (route *Route[T]) Unbind(middlewareIds ...string) *Route[T] {
|
||||
for _, middlewareId := range middlewareIds {
|
||||
if middlewareId == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// remove from the route's middlewares
|
||||
for i := len(route.Middlewares) - 1; i >= 0; i-- {
|
||||
if route.Middlewares[i].Id == middlewareId {
|
||||
route.Middlewares = append(route.Middlewares[:i], route.Middlewares[i+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
// add to the exclude list
|
||||
if route.excludedMiddlewares == nil {
|
||||
route.excludedMiddlewares = map[string]struct{}{}
|
||||
}
|
||||
route.excludedMiddlewares[middlewareId] = struct{}{}
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
func TestRouteBindFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Route[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one function
|
||||
r.BindFunc(func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
})
|
||||
|
||||
// append multiple functions
|
||||
r.BindFunc(
|
||||
func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(r.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range r.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteBind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Route[*Event]{
|
||||
// mock excluded middlewares to check whether the entry will be deleted
|
||||
excludedMiddlewares: map[string]struct{}{"test2": {}},
|
||||
}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one handler
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// append multiple handlers
|
||||
r.Bind(
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(r.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range r.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls %q, got %q", "abc", calls)
|
||||
}
|
||||
|
||||
// ensures that the previously excluded middleware was removed
|
||||
if len(r.excludedMiddlewares) != 0 {
|
||||
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", r.excludedMiddlewares)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteUnbind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Route[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// anonymous middlewares
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// middlewares with id
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Id: "test3",
|
||||
Func: func(e *Event) error {
|
||||
calls += "d"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// remove
|
||||
r.Unbind("") // should be no-op
|
||||
r.Unbind("test1", "test3")
|
||||
|
||||
if total := len(r.Middlewares); total != 2 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 2, total)
|
||||
}
|
||||
|
||||
for _, h := range r.Middlewares {
|
||||
if err := h.Func(nil); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if calls != "ac" {
|
||||
t.Fatalf("Expected calls %q, got %q", "ac", calls)
|
||||
}
|
||||
|
||||
// ensure that the id was added in the exclude list
|
||||
excluded := []string{"test1", "test3"}
|
||||
if len(r.excludedMiddlewares) != len(excluded) {
|
||||
t.Fatalf("Expected excludes %v, got %v", excluded, r.excludedMiddlewares)
|
||||
}
|
||||
for id := range r.excludedMiddlewares {
|
||||
if !slices.Contains(excluded, id) {
|
||||
t.Fatalf("Expected %q to be marked as excluded", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
type EventCleanupFunc func()
|
||||
|
||||
// EventFactoryFunc defines the function responsible for creating a Route specific event
|
||||
// based on the provided request handler ServeHTTP data.
|
||||
//
|
||||
// Optionally return a clean up function that will be invoked right after the route execution.
|
||||
type EventFactoryFunc[T hook.Resolver] func(w http.ResponseWriter, r *http.Request) (T, EventCleanupFunc)
|
||||
|
||||
// Router defines a thin wrapper around the standard Go [http.ServeMux] by
|
||||
// adding support for routing sub-groups, middlewares and other common utils.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// r := NewRouter[*MyEvent](eventFactory)
|
||||
//
|
||||
// // middlewares
|
||||
// r.BindFunc(m1, m2)
|
||||
//
|
||||
// // routes
|
||||
// r.GET("/test", handler1)
|
||||
//
|
||||
// // sub-routers/groups
|
||||
// api := r.Group("/api")
|
||||
// api.GET("/admins", handler2)
|
||||
//
|
||||
// // generate a http.ServeMux instance based on the router configurations
|
||||
// mux, _ := r.BuildMux()
|
||||
//
|
||||
// http.ListenAndServe("localhost:8090", mux)
|
||||
type Router[T hook.Resolver] struct {
|
||||
*RouterGroup[T]
|
||||
|
||||
eventFactory EventFactoryFunc[T]
|
||||
}
|
||||
|
||||
// NewRouter creates a new Router instance with the provided event factory function.
|
||||
func NewRouter[T hook.Resolver](eventFactory EventFactoryFunc[T]) *Router[T] {
|
||||
return &Router[T]{
|
||||
RouterGroup: &RouterGroup[T]{},
|
||||
eventFactory: eventFactory,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildMux constructs a new mux [http.Handler] instance from the current router configurations.
|
||||
func (r *Router[T]) BuildMux() (http.Handler, error) {
|
||||
// Note that some of the default std Go handlers like the [http.NotFoundHandler]
|
||||
// cannot be currently extended and requires defining a custom "catch-all" route
|
||||
// so that the group middlewares could be executed.
|
||||
//
|
||||
// https://github.com/golang/go/issues/65648
|
||||
if !r.HasRoute("", "/") {
|
||||
r.Route("", "/", func(e T) error {
|
||||
return NewNotFoundError("", nil)
|
||||
})
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
if err := r.loadMux(mux, r.RouterGroup, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
func (r *Router[T]) loadMux(mux *http.ServeMux, group *RouterGroup[T], parents []*RouterGroup[T]) error {
|
||||
for _, child := range group.children {
|
||||
switch v := child.(type) {
|
||||
case *RouterGroup[T]:
|
||||
if err := r.loadMux(mux, v, append(parents, group)); err != nil {
|
||||
return err
|
||||
}
|
||||
case *Route[T]:
|
||||
routeHook := &hook.Hook[T]{}
|
||||
|
||||
var pattern string
|
||||
|
||||
if v.Method != "" {
|
||||
pattern = v.Method + " "
|
||||
}
|
||||
|
||||
// add parent groups middlewares
|
||||
for _, p := range parents {
|
||||
pattern += p.Prefix
|
||||
for _, h := range p.Middlewares {
|
||||
if _, ok := p.excludedMiddlewares[h.Id]; !ok {
|
||||
if _, ok = group.excludedMiddlewares[h.Id]; !ok {
|
||||
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
|
||||
routeHook.Bind(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add current groups middlewares
|
||||
pattern += group.Prefix
|
||||
for _, h := range group.Middlewares {
|
||||
if _, ok := group.excludedMiddlewares[h.Id]; !ok {
|
||||
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
|
||||
routeHook.Bind(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add current route middlewares
|
||||
pattern += v.Path
|
||||
for _, h := range v.Middlewares {
|
||||
if _, ok := v.excludedMiddlewares[h.Id]; !ok {
|
||||
routeHook.Bind(h)
|
||||
}
|
||||
}
|
||||
|
||||
// add global panic-recover middleware
|
||||
routeHook.Bind(&hook.Handler[T]{
|
||||
Func: r.panicHandler,
|
||||
Priority: -9999999, // before everything else
|
||||
})
|
||||
|
||||
mux.HandleFunc(pattern, func(resp http.ResponseWriter, req *http.Request) {
|
||||
// wrap the response to add write and status tracking
|
||||
resp = &ResponseWriter{ResponseWriter: resp}
|
||||
|
||||
// wrap the request body to allow multiple reads
|
||||
req.Body = &RereadableReadCloser{ReadCloser: req.Body}
|
||||
|
||||
event, cleanupFunc := r.eventFactory(resp, req)
|
||||
|
||||
// trigger the handler hook chain
|
||||
err := routeHook.Trigger(event, v.Action)
|
||||
if err != nil {
|
||||
ErrorHandler(resp, req, err)
|
||||
}
|
||||
|
||||
if cleanupFunc != nil {
|
||||
cleanupFunc()
|
||||
}
|
||||
})
|
||||
default:
|
||||
return errors.New("invalid Group item type")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// panicHandler registers a default panic-recover handling.
|
||||
func (r *Router[T]) panicHandler(event T) (err error) {
|
||||
// panic-recover
|
||||
defer func() {
|
||||
recoverResult := recover()
|
||||
if recoverResult == nil {
|
||||
return
|
||||
}
|
||||
|
||||
recoverErr, ok := recoverResult.(error)
|
||||
if !ok {
|
||||
recoverErr = fmt.Errorf("%v", recoverResult)
|
||||
} else if errors.Is(recoverErr, http.ErrAbortHandler) {
|
||||
// don't recover ErrAbortHandler so the response to the client can be aborted
|
||||
panic(recoverResult)
|
||||
}
|
||||
|
||||
stack := make([]byte, 2<<10) // 2 KB
|
||||
length := runtime.Stack(stack, true)
|
||||
err = NewInternalServerError("", fmt.Errorf("[PANIC RECOVER] %w %s", recoverErr, stack[:length]))
|
||||
}()
|
||||
|
||||
err = event.Next()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func ErrorHandler(resp http.ResponseWriter, req *http.Request, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ok, _ := getWritten(resp); ok {
|
||||
return // a response was already written (aka. already handled)
|
||||
}
|
||||
|
||||
header := resp.Header()
|
||||
if header.Get("Content-Type") == "" {
|
||||
header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
apiErr := ToApiError(err)
|
||||
|
||||
resp.WriteHeader(apiErr.Status)
|
||||
|
||||
if req.Method != http.MethodHead {
|
||||
if jsonErr := json.NewEncoder(resp).Encode(apiErr); jsonErr != nil {
|
||||
log.Println(jsonErr) // truly rare case, log to stderr only for dev purposes
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type WriteTracker interface {
|
||||
// Written reports whether a write operation has occurred.
|
||||
Written() bool
|
||||
}
|
||||
|
||||
type StatusTracker interface {
|
||||
// Status reports the written response status code.
|
||||
Status() int
|
||||
}
|
||||
|
||||
type flushErrorer interface {
|
||||
FlushError() error
|
||||
}
|
||||
|
||||
var (
|
||||
_ WriteTracker = (*ResponseWriter)(nil)
|
||||
_ StatusTracker = (*ResponseWriter)(nil)
|
||||
_ http.Flusher = (*ResponseWriter)(nil)
|
||||
_ http.Hijacker = (*ResponseWriter)(nil)
|
||||
_ http.Pusher = (*ResponseWriter)(nil)
|
||||
_ io.ReaderFrom = (*ResponseWriter)(nil)
|
||||
_ flushErrorer = (*ResponseWriter)(nil)
|
||||
)
|
||||
|
||||
// ResponseWriter wraps a http.ResponseWriter to track its write state.
|
||||
type ResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
|
||||
written bool
|
||||
status int
|
||||
}
|
||||
|
||||
func (rw *ResponseWriter) WriteHeader(status int) {
|
||||
if rw.written {
|
||||
return
|
||||
}
|
||||
|
||||
rw.written = true
|
||||
rw.status = status
|
||||
rw.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (rw *ResponseWriter) Write(b []byte) (int, error) {
|
||||
if !rw.written {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
return rw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// Written implements [WriteTracker] and returns whether the current response body has been already written.
|
||||
func (rw *ResponseWriter) Written() bool {
|
||||
return rw.written
|
||||
}
|
||||
|
||||
// Written implements [StatusTracker] and returns the written status code of the current response.
|
||||
func (rw *ResponseWriter) Status() int {
|
||||
return rw.status
|
||||
}
|
||||
|
||||
// Flush implements [http.Flusher] and allows an HTTP handler to flush buffered data to the client.
|
||||
// This method is no-op if the wrapped writer doesn't support it.
|
||||
func (rw *ResponseWriter) Flush() {
|
||||
_ = rw.FlushError()
|
||||
}
|
||||
|
||||
// FlushError is similar to [Flush] but returns [http.ErrNotSupported]
|
||||
// if the wrapped writer doesn't support it.
|
||||
func (rw *ResponseWriter) FlushError() error {
|
||||
err := http.NewResponseController(rw.ResponseWriter).Flush()
|
||||
if err == nil || !errors.Is(err, http.ErrNotSupported) {
|
||||
rw.written = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Hijack implements [http.Hijacker] and allows an HTTP handler to take over the current connection.
|
||||
func (rw *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return http.NewResponseController(rw.ResponseWriter).Hijack()
|
||||
}
|
||||
|
||||
// Pusher implements [http.Pusher] to indicate HTTP/2 server push support.
|
||||
func (rw *ResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||
w := rw.ResponseWriter
|
||||
for {
|
||||
switch p := w.(type) {
|
||||
case http.Pusher:
|
||||
return p.Push(target, opts)
|
||||
case RWUnwrapper:
|
||||
w = p.Unwrap()
|
||||
default:
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReaderFrom implements [io.ReaderFrom] by checking if the underlying writer supports it.
|
||||
// Otherwise calls [io.Copy].
|
||||
func (rw *ResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if !rw.written {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
w := rw.ResponseWriter
|
||||
for {
|
||||
switch rf := w.(type) {
|
||||
case io.ReaderFrom:
|
||||
return rf.ReadFrom(r)
|
||||
case RWUnwrapper:
|
||||
w = rf.Unwrap()
|
||||
default:
|
||||
return io.Copy(rw.ResponseWriter, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWritter instance (usually used by [http.ResponseController]).
|
||||
func (rw *ResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return rw.ResponseWriter
|
||||
}
|
||||
|
||||
func getWritten(rw http.ResponseWriter) (bool, error) {
|
||||
for {
|
||||
switch w := rw.(type) {
|
||||
case WriteTracker:
|
||||
return w.Written(), nil
|
||||
case RWUnwrapper:
|
||||
rw = w.Unwrap()
|
||||
default:
|
||||
return false, http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getStatus(rw http.ResponseWriter) (int, error) {
|
||||
for {
|
||||
switch w := rw.(type) {
|
||||
case StatusTracker:
|
||||
return w.Status(), nil
|
||||
case RWUnwrapper:
|
||||
rw = w.Unwrap()
|
||||
default:
|
||||
return 0, http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
calls := ""
|
||||
|
||||
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
|
||||
return &router.Event{
|
||||
Response: w,
|
||||
Request: r,
|
||||
},
|
||||
func() {
|
||||
calls += ":cleanup"
|
||||
}
|
||||
})
|
||||
|
||||
r.BindFunc(func(e *router.Event) error {
|
||||
calls += "root_m:"
|
||||
|
||||
err := e.Next()
|
||||
|
||||
if err != nil {
|
||||
calls += "/error"
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
r.Any("/any", func(e *router.Event) error {
|
||||
calls += "/any"
|
||||
return nil
|
||||
})
|
||||
|
||||
r.GET("/a", func(e *router.Event) error {
|
||||
calls += "/a"
|
||||
return nil
|
||||
})
|
||||
|
||||
g1 := r.Group("/a/b").BindFunc(func(e *router.Event) error {
|
||||
calls += "a_b_group_m:"
|
||||
return e.Next()
|
||||
})
|
||||
g1.GET("/1", func(e *router.Event) error {
|
||||
calls += "/1_get"
|
||||
return nil
|
||||
}).BindFunc(func(e *router.Event) error {
|
||||
calls += "1_get_m:"
|
||||
return e.Next()
|
||||
})
|
||||
g1.POST("/1", func(e *router.Event) error {
|
||||
calls += "/1_post"
|
||||
return nil
|
||||
})
|
||||
g1.GET("/{param}", func(e *router.Event) error {
|
||||
calls += "/" + e.Request.PathValue("param")
|
||||
return errors.New("test") // should be normalized to an ApiError
|
||||
})
|
||||
g1.GET("/panic", func(e *router.Event) error {
|
||||
calls += "/panic"
|
||||
panic("test")
|
||||
})
|
||||
|
||||
mux, err := r.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
defer ts.Close()
|
||||
|
||||
client := ts.Client()
|
||||
|
||||
scenarios := []struct {
|
||||
method string
|
||||
path string
|
||||
calls string
|
||||
}{
|
||||
{http.MethodGet, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodOptions, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodPatch, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodPut, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodPost, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodDelete, "/any", "root_m:/any:cleanup"},
|
||||
// ---
|
||||
{http.MethodPost, "/a", "root_m:/error:cleanup"}, // missing
|
||||
{http.MethodGet, "/a", "root_m:/a:cleanup"},
|
||||
{http.MethodHead, "/a", "root_m:/a:cleanup"}, // auto registered with the GET
|
||||
{http.MethodGet, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
|
||||
{http.MethodHead, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
|
||||
{http.MethodPost, "/a/b/1", "root_m:a_b_group_m:/1_post:cleanup"},
|
||||
{http.MethodGet, "/a/b/456", "root_m:a_b_group_m:/456/error:cleanup"},
|
||||
{http.MethodGet, "/a/b/panic", "root_m:a_b_group_m:/panic:cleanup"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.method+"_"+s.path, func(t *testing.T) {
|
||||
calls = "" // reset
|
||||
|
||||
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if calls != s.calls {
|
||||
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterUnbind(t *testing.T) {
|
||||
calls := ""
|
||||
|
||||
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
|
||||
return &router.Event{
|
||||
Response: w,
|
||||
Request: r,
|
||||
},
|
||||
func() {
|
||||
calls += ":cleanup"
|
||||
}
|
||||
})
|
||||
r.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "root_1",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "root_1:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "root_2",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "root_2:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "root_3",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "root_3:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
r.GET("/action", func(e *router.Event) error {
|
||||
calls += "root_action"
|
||||
return nil
|
||||
}).Unbind("root_1")
|
||||
|
||||
ga := r.Group("/group_a")
|
||||
ga.Unbind("root_1")
|
||||
ga.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_a_1",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_a_1:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
ga.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_a_2",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_a_2:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
ga.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_a_3",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_a_3:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
ga.GET("/action", func(e *router.Event) error {
|
||||
calls += "group_a_action"
|
||||
return nil
|
||||
}).Unbind("root_2", "group_b_1", "group_a_1")
|
||||
|
||||
gb := r.Group("/group_b")
|
||||
gb.Unbind("root_2")
|
||||
gb.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_b_1",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_b_1:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
gb.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_b_2",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_b_2:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
gb.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_b_3",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_b_3:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
gb.GET("/action", func(e *router.Event) error {
|
||||
calls += "group_b_action"
|
||||
return nil
|
||||
}).Unbind("group_b_3", "group_a_3", "root_3")
|
||||
|
||||
mux, err := r.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
defer ts.Close()
|
||||
|
||||
client := ts.Client()
|
||||
|
||||
scenarios := []struct {
|
||||
method string
|
||||
path string
|
||||
calls string
|
||||
}{
|
||||
{http.MethodGet, "/action", "root_2:root_3:root_action:cleanup"},
|
||||
{http.MethodGet, "/group_a/action", "root_3:group_a_2:group_a_3:group_a_action:cleanup"},
|
||||
{http.MethodGet, "/group_b/action", "root_1:group_b_1:group_b_2:group_b_action:cleanup"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.method+"_"+s.path, func(t *testing.T) {
|
||||
calls = "" // reset
|
||||
|
||||
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if calls != s.calls {
|
||||
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,330 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
|
||||
// JSONPayloadKey is the key for the special UnmarshalRequestData case
|
||||
// used for reading serialized json payload without normalization.
|
||||
const JSONPayloadKey string = "@jsonPayload"
|
||||
|
||||
// UnmarshalRequestData unmarshals url.Values type of data (query, multipart/form-data, etc.) into dst.
|
||||
//
|
||||
// dst must be a pointer to a map[string]any or struct.
|
||||
//
|
||||
// If dst is a map[string]any, each data value will be inferred and
|
||||
// converted to its bool, numeric, or string equivalent value
|
||||
// (refer to inferValue() for the exact rules).
|
||||
//
|
||||
// If dst is a struct, the following field types are supported:
|
||||
// - bool
|
||||
// - string
|
||||
// - int, int8, int16, int32, int64
|
||||
// - uint, uint8, uint16, uint32, uint64
|
||||
// - float32, float64
|
||||
// - serialized json string if submitted under the special "@jsonPayload" key
|
||||
// - encoding.TextUnmarshaler
|
||||
// - pointer and slice variations of the above primitives (ex. *string, []string, *[]string []*string, etc.)
|
||||
// - named/anonymous struct fields
|
||||
// Dot-notation is used to target nested fields, ex. "nestedStructField.title".
|
||||
// - embedded struct fields
|
||||
// The embedded struct fields are treated by default as if they were defined in their parent struct.
|
||||
// If the embedded struct has a tag matching structTagKey then to set its fields the data keys must be prefixed with that tag
|
||||
// similar to the regular nested struct fields.
|
||||
//
|
||||
// structTagKey and structPrefix are used only when dst is a struct.
|
||||
//
|
||||
// structTagKey represents the tag to use to match a data entry with a struct field (defaults to "form").
|
||||
// If the struct field doesn't have the structTagKey tag, then the exported struct field name will be used as it is.
|
||||
//
|
||||
// structPrefix could be provided if all of the data keys are prefixed with a common string
|
||||
// and you want the struct field to match only the value without the structPrefix
|
||||
// (ex. for "user.name", "user.email" data keys and structPrefix "user", it will match "name" and "email" struct fields).
|
||||
//
|
||||
// Note that while the method was inspired by binders from echo, gorrila/schema, ozzo-routing
|
||||
// and other similar common routing packages, it is not intended to be a drop-in replacement.
|
||||
//
|
||||
// @todo Consider adding support for dot-notation keys, in addition to the prefix, (ex. parent.child.title) to express nested object keys.
|
||||
func UnmarshalRequestData(data map[string][]string, dst any, structTagKey string, structPrefix string) error {
|
||||
if len(data) == 0 {
|
||||
return nil // nothing to unmarshal
|
||||
}
|
||||
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() != reflect.Pointer {
|
||||
return errors.New("dst must be a pointer")
|
||||
}
|
||||
|
||||
dstValue = dereference(dstValue)
|
||||
|
||||
dstType := dstValue.Type()
|
||||
|
||||
switch dstType.Kind() {
|
||||
case reflect.Map: // map[string]any
|
||||
if dstType.Elem().Kind() != reflect.Interface {
|
||||
return errors.New("dst map value type must be any/interface{}")
|
||||
}
|
||||
|
||||
for k, v := range data {
|
||||
if k == JSONPayloadKey {
|
||||
continue // unmarshalled separately
|
||||
}
|
||||
|
||||
total := len(v)
|
||||
|
||||
if total == 1 {
|
||||
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(inferValue(v[0])))
|
||||
} else {
|
||||
normalized := make([]any, total)
|
||||
for i, vItem := range v {
|
||||
normalized[i] = inferValue(vItem)
|
||||
}
|
||||
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalized))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
// set a default tag key
|
||||
if structTagKey == "" {
|
||||
structTagKey = "form"
|
||||
}
|
||||
|
||||
err := unmarshalInStructValue(data, dstValue, structTagKey, structPrefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("dst must be a map[string]any or struct")
|
||||
}
|
||||
|
||||
// @jsonPayload
|
||||
//
|
||||
// Special case to scan serialized json string without
|
||||
// normalization alongside the other data values
|
||||
// ---------------------------------------------------------------
|
||||
jsonPayloadValues := data[JSONPayloadKey]
|
||||
for _, payload := range jsonPayloadValues {
|
||||
if err := json.Unmarshal([]byte(payload), dst); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// unmarshalInStructValue unmarshals data into the provided struct reflect.Value fields.
|
||||
func unmarshalInStructValue(
|
||||
data map[string][]string,
|
||||
dstStructValue reflect.Value,
|
||||
structTagKey string,
|
||||
structPrefix string,
|
||||
) error {
|
||||
dstStructType := dstStructValue.Type()
|
||||
|
||||
for i := 0; i < dstStructValue.NumField(); i++ {
|
||||
fieldType := dstStructType.Field(i)
|
||||
|
||||
tag := fieldType.Tag.Get(structTagKey)
|
||||
|
||||
if tag == "-" || (!fieldType.Anonymous && !fieldType.IsExported()) {
|
||||
continue // disabled or unexported non-anonymous struct field
|
||||
}
|
||||
|
||||
fieldValue := dereference(dstStructValue.Field(i))
|
||||
|
||||
ft := fieldType.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
isSlice := ft.Kind() == reflect.Slice
|
||||
if isSlice {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
name := tag
|
||||
if name == "" && !fieldType.Anonymous {
|
||||
name = fieldType.Name
|
||||
}
|
||||
if name != "" && structPrefix != "" {
|
||||
name = structPrefix + "." + name
|
||||
}
|
||||
|
||||
// (*)encoding.TextUnmarshaler field
|
||||
// ---
|
||||
if ft.Implements(textUnmarshalerType) || reflect.PointerTo(ft).Implements(textUnmarshalerType) {
|
||||
values, ok := data[name]
|
||||
if !ok || len(values) == 0 || !fieldValue.CanSet() {
|
||||
continue // no value to load or the field cannot be set
|
||||
}
|
||||
|
||||
if isSlice {
|
||||
n := len(values)
|
||||
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
|
||||
for i, v := range values {
|
||||
unmarshaler, ok := dereference(slice.Index(i)).Addr().Interface().(encoding.TextUnmarshaler)
|
||||
if ok {
|
||||
if err := unmarshaler.UnmarshalText([]byte(v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
fieldValue.Set(slice)
|
||||
} else {
|
||||
unmarshaler, ok := fieldValue.Addr().Interface().(encoding.TextUnmarshaler)
|
||||
if ok {
|
||||
if err := unmarshaler.UnmarshalText([]byte(values[0])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// "regular" field
|
||||
// ---
|
||||
if ft.Kind() != reflect.Struct {
|
||||
values, ok := data[name]
|
||||
if !ok || len(values) == 0 || !fieldValue.CanSet() {
|
||||
continue // no value to load
|
||||
}
|
||||
|
||||
if isSlice {
|
||||
n := len(values)
|
||||
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
|
||||
for i, v := range values {
|
||||
if err := setRegularReflectedValue(dereference(slice.Index(i)), v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
fieldValue.Set(slice)
|
||||
} else {
|
||||
if err := setRegularReflectedValue(fieldValue, values[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// structs (embedded or nested)
|
||||
// ---
|
||||
// slice of structs
|
||||
if isSlice {
|
||||
// populating slice of structs is not supported at the moment
|
||||
// because the filling rules are ambiguous
|
||||
continue
|
||||
}
|
||||
|
||||
if tag != "" {
|
||||
structPrefix = tag
|
||||
} else {
|
||||
structPrefix = name // name is empty for anonymous structs -> no prefix
|
||||
}
|
||||
|
||||
if err := unmarshalInStructValue(data, fieldValue, structTagKey, structPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dereference returns the underlying value v points to.
|
||||
func dereference(v reflect.Value) reflect.Value {
|
||||
for v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
// initialize with a new value and continue searching
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// setRegularReflectedValue sets and casts value into rv.
|
||||
func setRegularReflectedValue(rv reflect.Value, value string) error {
|
||||
switch rv.Kind() {
|
||||
case reflect.String:
|
||||
rv.SetString(value)
|
||||
case reflect.Bool:
|
||||
if value == "" {
|
||||
value = "f"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetBool(v)
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
if value == "" {
|
||||
value = "0"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseInt(value, 0, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetInt(v)
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
if value == "" {
|
||||
value = "0"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseUint(value, 0, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetUint(v)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if value == "" {
|
||||
value = "0"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseFloat(value, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetFloat(v)
|
||||
default:
|
||||
return errors.New("unknown value type " + rv.Kind().String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// In order to support more seamlessly both json and multipart/form-data requests,
|
||||
// the following normalization rules are applied for plain multipart string values:
|
||||
// - "true" is converted to the json `true`
|
||||
// - "false" is converted to the json `false`
|
||||
// - numeric (non-scientific) strings are converted to json number
|
||||
// - any other string (empty string too) is left as it is
|
||||
func inferValue(raw string) any {
|
||||
switch raw {
|
||||
case "":
|
||||
return raw
|
||||
case "true":
|
||||
return true
|
||||
case "false":
|
||||
return false
|
||||
default:
|
||||
// try to convert to number
|
||||
if raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9') {
|
||||
v, err := strconv.ParseFloat(raw, 64)
|
||||
if err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return raw
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func pointer[T any](val T) *T {
|
||||
return &val
|
||||
}
|
||||
|
||||
func TestUnmarshalRequestData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mapData := map[string][]string{
|
||||
"number1": {"1"},
|
||||
"number2": {"2", "3"},
|
||||
"number3": {"2.1", "-3.4"},
|
||||
"string0": {""},
|
||||
"string1": {"a"},
|
||||
"string2": {"b", "c"},
|
||||
"bool1": {"true"},
|
||||
"bool2": {"true", "false"},
|
||||
"mixed": {"true", "123", "test"},
|
||||
"@jsonPayload": {`{"json_a":null,"json_b":123}`, `{"json_c":[1,2,3]}`},
|
||||
}
|
||||
|
||||
structData := map[string][]string{
|
||||
"stringTag": {"a", "b"},
|
||||
"StringPtr": {"b"},
|
||||
"StringSlice": {"a", "b", "c", ""},
|
||||
"stringSlicePtrTag": {"d", "e"},
|
||||
"StringSliceOfPtr": {"f", "g"},
|
||||
|
||||
"boolTag": {"true"},
|
||||
"BoolPtr": {"true"},
|
||||
"BoolSlice": {"true", "false", ""},
|
||||
"boolSlicePtrTag": {"false", "false", "true"},
|
||||
"BoolSliceOfPtr": {"false", "true", "false"},
|
||||
|
||||
"int8Tag": {"-1", "2"},
|
||||
"Int8Ptr": {"3"},
|
||||
"Int8Slice": {"4", "5", ""},
|
||||
"int8SlicePtrTag": {"5", "6"},
|
||||
"Int8SliceOfPtr": {"7", "8"},
|
||||
|
||||
"int16Tag": {"-1", "2"},
|
||||
"Int16Ptr": {"3"},
|
||||
"Int16Slice": {"4", "5", ""},
|
||||
"int16SlicePtrTag": {"5", "6"},
|
||||
"Int16SliceOfPtr": {"7", "8"},
|
||||
|
||||
"int32Tag": {"-1", "2"},
|
||||
"Int32Ptr": {"3"},
|
||||
"Int32Slice": {"4", "5", ""},
|
||||
"int32SlicePtrTag": {"5", "6"},
|
||||
"Int32SliceOfPtr": {"7", "8"},
|
||||
|
||||
"int64Tag": {"-1", "2"},
|
||||
"Int64Ptr": {"3"},
|
||||
"Int64Slice": {"4", "5", ""},
|
||||
"int64SlicePtrTag": {"5", "6"},
|
||||
"Int64SliceOfPtr": {"7", "8"},
|
||||
|
||||
"intTag": {"-1", "2"},
|
||||
"IntPtr": {"3"},
|
||||
"IntSlice": {"4", "5", ""},
|
||||
"intSlicePtrTag": {"5", "6"},
|
||||
"IntSliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint8Tag": {"1", "2"},
|
||||
"Uint8Ptr": {"3"},
|
||||
"Uint8Slice": {"4", "5", ""},
|
||||
"uint8SlicePtrTag": {"5", "6"},
|
||||
"Uint8SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint16Tag": {"1", "2"},
|
||||
"Uint16Ptr": {"3"},
|
||||
"Uint16Slice": {"4", "5", ""},
|
||||
"uint16SlicePtrTag": {"5", "6"},
|
||||
"Uint16SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint32Tag": {"1", "2"},
|
||||
"Uint32Ptr": {"3"},
|
||||
"Uint32Slice": {"4", "5", ""},
|
||||
"uint32SlicePtrTag": {"5", "6"},
|
||||
"Uint32SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint64Tag": {"1", "2"},
|
||||
"Uint64Ptr": {"3"},
|
||||
"Uint64Slice": {"4", "5", ""},
|
||||
"uint64SlicePtrTag": {"5", "6"},
|
||||
"Uint64SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uintTag": {"1", "2"},
|
||||
"UintPtr": {"3"},
|
||||
"UintSlice": {"4", "5", ""},
|
||||
"uintSlicePtrTag": {"5", "6"},
|
||||
"UintSliceOfPtr": {"7", "8"},
|
||||
|
||||
"float32Tag": {"-1.2"},
|
||||
"Float32Ptr": {"1.5", "2.0"},
|
||||
"Float32Slice": {"1", "2.3", "-0.3", ""},
|
||||
"float32SlicePtrTag": {"-1.3", "3"},
|
||||
"Float32SliceOfPtr": {"0", "1.2"},
|
||||
|
||||
"float64Tag": {"-1.2"},
|
||||
"Float64Ptr": {"1.5", "2.0"},
|
||||
"Float64Slice": {"1", "2.3", "-0.3", ""},
|
||||
"float64SlicePtrTag": {"-1.3", "3"},
|
||||
"Float64SliceOfPtr": {"0", "1.2"},
|
||||
|
||||
"timeTag": {"2009-11-10T15:00:00Z"},
|
||||
"TimePtr": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
|
||||
"TimeSlice": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
|
||||
"timeSlicePtrTag": {"2009-11-10T15:00:00Z", "2009-11-10T16:00:00Z"},
|
||||
"TimeSliceOfPtr": {"2009-11-10T17:00:00Z", "2009-11-10T18:00:00Z"},
|
||||
|
||||
// @jsonPayload fields
|
||||
"@jsonPayload": {
|
||||
`{"payloadA":"test", "shouldBeIgnored": "abc"}`,
|
||||
`{"payloadB":[1,2,3], "payloadC":true}`,
|
||||
},
|
||||
|
||||
// unexported fields or `-` tags
|
||||
"unexperted": {"test"},
|
||||
"SkipExported": {"test"},
|
||||
"unexportedStructFieldWithoutTag.Name": {"test"},
|
||||
"unexportedStruct.Name": {"test"},
|
||||
|
||||
// structs
|
||||
"StructWithoutTag.Name": {"test1"},
|
||||
"exportedStruct.Name": {"test2"},
|
||||
|
||||
// embedded
|
||||
"embed_name": {"test3"},
|
||||
"embed2.embed_name2": {"test4"},
|
||||
}
|
||||
|
||||
type embed1 struct {
|
||||
Name string `form:"embed_name" json:"embed_name"`
|
||||
}
|
||||
|
||||
type embed2 struct {
|
||||
Name string `form:"embed_name2" json:"embed_name2"`
|
||||
}
|
||||
|
||||
//nolint
|
||||
type TestStruct struct {
|
||||
String string `form:"stringTag" query:"stringTag2"`
|
||||
StringPtr *string
|
||||
StringSlice []string
|
||||
StringSlicePtr *[]string `form:"stringSlicePtrTag"`
|
||||
StringSliceOfPtr []*string
|
||||
|
||||
Bool bool `form:"boolTag" query:"boolTag2"`
|
||||
BoolPtr *bool
|
||||
BoolSlice []bool
|
||||
BoolSlicePtr *[]bool `form:"boolSlicePtrTag"`
|
||||
BoolSliceOfPtr []*bool
|
||||
|
||||
Int8 int8 `form:"int8Tag" query:"int8Tag2"`
|
||||
Int8Ptr *int8
|
||||
Int8Slice []int8
|
||||
Int8SlicePtr *[]int8 `form:"int8SlicePtrTag"`
|
||||
Int8SliceOfPtr []*int8
|
||||
|
||||
Int16 int16 `form:"int16Tag" query:"int16Tag2"`
|
||||
Int16Ptr *int16
|
||||
Int16Slice []int16
|
||||
Int16SlicePtr *[]int16 `form:"int16SlicePtrTag"`
|
||||
Int16SliceOfPtr []*int16
|
||||
|
||||
Int32 int32 `form:"int32Tag" query:"int32Tag2"`
|
||||
Int32Ptr *int32
|
||||
Int32Slice []int32
|
||||
Int32SlicePtr *[]int32 `form:"int32SlicePtrTag"`
|
||||
Int32SliceOfPtr []*int32
|
||||
|
||||
Int64 int64 `form:"int64Tag" query:"int64Tag2"`
|
||||
Int64Ptr *int64
|
||||
Int64Slice []int64
|
||||
Int64SlicePtr *[]int64 `form:"int64SlicePtrTag"`
|
||||
Int64SliceOfPtr []*int64
|
||||
|
||||
Int int `form:"intTag" query:"intTag2"`
|
||||
IntPtr *int
|
||||
IntSlice []int
|
||||
IntSlicePtr *[]int `form:"intSlicePtrTag"`
|
||||
IntSliceOfPtr []*int
|
||||
|
||||
Uint8 uint8 `form:"uint8Tag" query:"uint8Tag2"`
|
||||
Uint8Ptr *uint8
|
||||
Uint8Slice []uint8
|
||||
Uint8SlicePtr *[]uint8 `form:"uint8SlicePtrTag"`
|
||||
Uint8SliceOfPtr []*uint8
|
||||
|
||||
Uint16 uint16 `form:"uint16Tag" query:"uint16Tag2"`
|
||||
Uint16Ptr *uint16
|
||||
Uint16Slice []uint16
|
||||
Uint16SlicePtr *[]uint16 `form:"uint16SlicePtrTag"`
|
||||
Uint16SliceOfPtr []*uint16
|
||||
|
||||
Uint32 uint32 `form:"uint32Tag" query:"uint32Tag2"`
|
||||
Uint32Ptr *uint32
|
||||
Uint32Slice []uint32
|
||||
Uint32SlicePtr *[]uint32 `form:"uint32SlicePtrTag"`
|
||||
Uint32SliceOfPtr []*uint32
|
||||
|
||||
Uint64 uint64 `form:"uint64Tag" query:"uint64Tag2"`
|
||||
Uint64Ptr *uint64
|
||||
Uint64Slice []uint64
|
||||
Uint64SlicePtr *[]uint64 `form:"uint64SlicePtrTag"`
|
||||
Uint64SliceOfPtr []*uint64
|
||||
|
||||
Uint uint `form:"uintTag" query:"uintTag2"`
|
||||
UintPtr *uint
|
||||
UintSlice []uint
|
||||
UintSlicePtr *[]uint `form:"uintSlicePtrTag"`
|
||||
UintSliceOfPtr []*uint
|
||||
|
||||
Float32 float32 `form:"float32Tag" query:"float32Tag2"`
|
||||
Float32Ptr *float32
|
||||
Float32Slice []float32
|
||||
Float32SlicePtr *[]float32 `form:"float32SlicePtrTag"`
|
||||
Float32SliceOfPtr []*float32
|
||||
|
||||
Float64 float64 `form:"float64Tag" query:"float64Tag2"`
|
||||
Float64Ptr *float64
|
||||
Float64Slice []float64
|
||||
Float64SlicePtr *[]float64 `form:"float64SlicePtrTag"`
|
||||
Float64SliceOfPtr []*float64
|
||||
|
||||
// encoding.TextUnmarshaler
|
||||
Time time.Time `form:"timeTag" query:"timeTag2"`
|
||||
TimePtr *time.Time
|
||||
TimeSlice []time.Time
|
||||
TimeSlicePtr *[]time.Time `form:"timeSlicePtrTag"`
|
||||
TimeSliceOfPtr []*time.Time
|
||||
|
||||
// @jsonPayload fields
|
||||
JSONPayloadA string `form:"shouldBeIgnored" json:"payloadA"`
|
||||
JSONPayloadB []int `json:"payloadB"`
|
||||
JSONPayloadC bool `json:"-"`
|
||||
|
||||
// unexported fields or `-` tags
|
||||
unexported string
|
||||
SkipExported string `form:"-"`
|
||||
unexportedStructFieldWithoutTag struct {
|
||||
Name string `json:"unexportedStructFieldWithoutTag_name"`
|
||||
}
|
||||
unexportedStructFieldWithTag struct {
|
||||
Name string `json:"unexportedStructFieldWithTag_name"`
|
||||
} `form:"unexportedStruct"`
|
||||
|
||||
// structs
|
||||
StructWithoutTag struct {
|
||||
Name string `json:"StructWithoutTag_name"`
|
||||
}
|
||||
StructWithTag struct {
|
||||
Name string `json:"StructWithTag_name"`
|
||||
} `form:"exportedStruct"`
|
||||
|
||||
// embedded
|
||||
embed1
|
||||
embed2 `form:"embed2"`
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data map[string][]string
|
||||
dst any
|
||||
tag string
|
||||
prefix string
|
||||
error bool
|
||||
result string
|
||||
}{
|
||||
{
|
||||
name: "nil data",
|
||||
data: nil,
|
||||
dst: pointer(map[string]any{}),
|
||||
error: false,
|
||||
result: `{}`,
|
||||
},
|
||||
{
|
||||
name: "non-pointer map[string]any",
|
||||
data: mapData,
|
||||
dst: map[string]any{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported *map[string]string",
|
||||
data: mapData,
|
||||
dst: pointer(map[string]string{}),
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported *map[string][]string",
|
||||
data: mapData,
|
||||
dst: pointer(map[string][]string{}),
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "*map[string]any",
|
||||
data: mapData,
|
||||
dst: pointer(map[string]any{}),
|
||||
result: `{"bool1":true,"bool2":[true,false],"json_a":null,"json_b":123,"json_c":[1,2,3],"mixed":[true,123,"test"],"number1":1,"number2":[2,3],"number3":[2.1,-3.4],"string0":"","string1":"a","string2":["b","c"]}`,
|
||||
},
|
||||
{
|
||||
name: "valid pointer struct (all fields)",
|
||||
data: structData,
|
||||
dst: &TestStruct{},
|
||||
result: `{"String":"a","StringPtr":"b","StringSlice":["a","b","c",""],"StringSlicePtr":["d","e"],"StringSliceOfPtr":["f","g"],"Bool":true,"BoolPtr":true,"BoolSlice":[true,false,false],"BoolSlicePtr":[false,false,true],"BoolSliceOfPtr":[false,true,false],"Int8":-1,"Int8Ptr":3,"Int8Slice":[4,5,0],"Int8SlicePtr":[5,6],"Int8SliceOfPtr":[7,8],"Int16":-1,"Int16Ptr":3,"Int16Slice":[4,5,0],"Int16SlicePtr":[5,6],"Int16SliceOfPtr":[7,8],"Int32":-1,"Int32Ptr":3,"Int32Slice":[4,5,0],"Int32SlicePtr":[5,6],"Int32SliceOfPtr":[7,8],"Int64":-1,"Int64Ptr":3,"Int64Slice":[4,5,0],"Int64SlicePtr":[5,6],"Int64SliceOfPtr":[7,8],"Int":-1,"IntPtr":3,"IntSlice":[4,5,0],"IntSlicePtr":[5,6],"IntSliceOfPtr":[7,8],"Uint8":1,"Uint8Ptr":3,"Uint8Slice":"BAUA","Uint8SlicePtr":"BQY=","Uint8SliceOfPtr":[7,8],"Uint16":1,"Uint16Ptr":3,"Uint16Slice":[4,5,0],"Uint16SlicePtr":[5,6],"Uint16SliceOfPtr":[7,8],"Uint32":1,"Uint32Ptr":3,"Uint32Slice":[4,5,0],"Uint32SlicePtr":[5,6],"Uint32SliceOfPtr":[7,8],"Uint64":1,"Uint64Ptr":3,"Uint64Slice":[4,5,0],"Uint64SlicePtr":[5,6],"Uint64SliceOfPtr":[7,8],"Uint":1,"UintPtr":3,"UintSlice":[4,5,0],"UintSlicePtr":[5,6],"UintSliceOfPtr":[7,8],"Float32":-1.2,"Float32Ptr":1.5,"Float32Slice":[1,2.3,-0.3,0],"Float32SlicePtr":[-1.3,3],"Float32SliceOfPtr":[0,1.2],"Float64":-1.2,"Float64Ptr":1.5,"Float64Slice":[1,2.3,-0.3,0],"Float64SlicePtr":[-1.3,3],"Float64SliceOfPtr":[0,1.2],"Time":"2009-11-10T15:00:00Z","TimePtr":"2009-11-10T14:00:00Z","TimeSlice":["2009-11-10T14:00:00Z","2009-11-10T15:00:00Z"],"TimeSlicePtr":["2009-11-10T15:00:00Z","2009-11-10T16:00:00Z"],"TimeSliceOfPtr":["2009-11-10T17:00:00Z","2009-11-10T18:00:00Z"],"payloadA":"test","payloadB":[1,2,3],"SkipExported":"","StructWithoutTag":{"StructWithoutTag_name":"test1"},"StructWithTag":{"StructWithTag_name":"test2"},"embed_name":"test3","embed_name2":"test4"}`,
|
||||
},
|
||||
{
|
||||
name: "non-pointer struct",
|
||||
data: structData,
|
||||
dst: TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct uint value",
|
||||
data: map[string][]string{"uintTag": {"-1"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct int value",
|
||||
data: map[string][]string{"intTag": {"abc"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct bool value",
|
||||
data: map[string][]string{"boolTag": {"abc"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct float value",
|
||||
data: map[string][]string{"float64Tag": {"abc"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct TextUnmarshaler value",
|
||||
data: map[string][]string{"timeTag": {"123"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "custom tagKey",
|
||||
data: map[string][]string{
|
||||
"tag1": {"a"},
|
||||
"tag2": {"b"},
|
||||
"tag3": {"c"},
|
||||
"Item": {"d"},
|
||||
},
|
||||
dst: &struct {
|
||||
Item string `form:"tag1" query:"tag2" json:"tag2"`
|
||||
}{},
|
||||
tag: "query",
|
||||
result: `{"tag2":"b"}`,
|
||||
},
|
||||
{
|
||||
name: "custom prefix",
|
||||
data: map[string][]string{
|
||||
"test.A": {"1"},
|
||||
"A": {"2"},
|
||||
"test.alias": {"3"},
|
||||
},
|
||||
dst: &struct {
|
||||
A string
|
||||
B string `form:"alias"`
|
||||
}{},
|
||||
prefix: "test",
|
||||
result: `{"A":"1","B":"3"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := router.UnmarshalRequestData(s.data, s.dst, s.tag, s.prefix)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.error {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.error, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(s.dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(raw, []byte(s.result)) {
|
||||
t.Fatalf("Expected dst \n%s\ngot\n%s", s.result, raw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// note: extra unexported checks in addition to the above test as there
|
||||
// is no easy way to print nested structs with all their fields.
|
||||
func TestUnmarshalRequestDataUnexportedFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:all
|
||||
type TestStruct struct {
|
||||
Exported string
|
||||
|
||||
unexported string
|
||||
// to ensure that the reflection doesn't take tags with higher priority than the exported state
|
||||
unexportedWithTag string `form:"unexportedWithTag" json:"unexportedWithTag"`
|
||||
}
|
||||
|
||||
dst := &TestStruct{}
|
||||
|
||||
err := router.UnmarshalRequestData(map[string][]string{
|
||||
"Exported": {"test"}, // just for reference
|
||||
|
||||
"Unexported": {"test"},
|
||||
"unexported": {"test"},
|
||||
"UnexportedWithTag": {"test"},
|
||||
"unexportedWithTag": {"test"},
|
||||
}, dst, "", "")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if dst.Exported != "test" {
|
||||
t.Fatalf("Expected the Exported field to be %q, got %q", "test", dst.Exported)
|
||||
}
|
||||
|
||||
if dst.unexported != "" {
|
||||
t.Fatalf("Expected the unexported field to remain empty, got %q", dst.unexported)
|
||||
}
|
||||
|
||||
if dst.unexportedWithTag != "" {
|
||||
t.Fatalf("Expected the unexportedWithTag field to remain empty, got %q", dst.unexportedWithTag)
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FireAndForget executes `f()` in a new go routine and auto recovers if panic.
|
||||
// FireAndForget executes f() in a new go routine and auto recovers if panic.
|
||||
//
|
||||
// **Note:** Use this only if you are not interested in the result of `f()`
|
||||
// **Note:** Use this only if you are not interested in the result of f()
|
||||
// and don't want to block the parent go routine.
|
||||
func FireAndForget(f func(), wg ...*sync.WaitGroup) {
|
||||
if len(wg) > 0 && wg[0] != nil {
|
||||
|
||||
@@ -64,9 +64,10 @@ func (f FilterData) BuildExpr(
|
||||
}
|
||||
}
|
||||
|
||||
if parsedFilterData.Has(raw) {
|
||||
return buildParsedFilterExpr(parsedFilterData.Get(raw), fieldResolver)
|
||||
if data, ok := parsedFilterData.GetOk(raw); ok {
|
||||
return buildParsedFilterExpr(data, fieldResolver)
|
||||
}
|
||||
|
||||
data, err := fexpr.Parse(raw)
|
||||
if err != nil {
|
||||
// depending on the users demand we may allow empty expressions
|
||||
@@ -78,9 +79,11 @@ func (f FilterData) BuildExpr(
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// store in cache
|
||||
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
|
||||
parsedFilterData.SetIfLessThanLimit(raw, data, 500)
|
||||
|
||||
return buildParsedFilterExpr(data, fieldResolver)
|
||||
}
|
||||
|
||||
@@ -431,6 +434,8 @@ func mergeParams(params ...dbx.Params) dbx.Params {
|
||||
return result
|
||||
}
|
||||
|
||||
// @todo consider adding support for custom single character wildcard
|
||||
//
|
||||
// wrapLikeParams wraps each provided param value string with `%`
|
||||
// if the param doesn't contain an explicit wildcard (`%`) character already.
|
||||
func wrapLikeParams(params dbx.Params) dbx.Params {
|
||||
|
||||
@@ -5,16 +5,20 @@ import (
|
||||
"math"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// DefaultPerPage specifies the default returned search result items.
|
||||
const DefaultPerPage int = 30
|
||||
|
||||
// @todo consider making it configurable
|
||||
//
|
||||
// MaxPerPage specifies the maximum allowed search result items returned in a single page.
|
||||
const MaxPerPage int = 500
|
||||
const MaxPerPage int = 1000
|
||||
|
||||
// url search query params
|
||||
const (
|
||||
@@ -27,23 +31,23 @@ const (
|
||||
|
||||
// Result defines the returned search result structure.
|
||||
type Result struct {
|
||||
Items any `json:"items"`
|
||||
Page int `json:"page"`
|
||||
PerPage int `json:"perPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
Items any `json:"items"`
|
||||
}
|
||||
|
||||
// Provider represents a single configured search provider instance.
|
||||
type Provider struct {
|
||||
fieldResolver FieldResolver
|
||||
query *dbx.SelectQuery
|
||||
skipTotal bool
|
||||
countCol string
|
||||
page int
|
||||
perPage int
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
page int
|
||||
perPage int
|
||||
skipTotal bool
|
||||
}
|
||||
|
||||
// NewProvider creates and returns a new search provider.
|
||||
@@ -208,6 +212,14 @@ func (s *Provider) Exec(items any) (*Result, error) {
|
||||
return nil, err
|
||||
}
|
||||
if expr != "" {
|
||||
// ensure that _rowid_ expressions are always prefixed with the first FROM table
|
||||
if sortField.Name == rowidSortKey && !strings.Contains(expr, ".") {
|
||||
queryInfo := modelsQuery.Info()
|
||||
if len(queryInfo.From) > 0 {
|
||||
expr = "[[" + inflector.Columnify(queryInfo.From[0]) + "]]." + expr
|
||||
}
|
||||
}
|
||||
|
||||
modelsQuery.AndOrderBy(expr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,38 +180,39 @@ func TestProviderParse(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Page(initialPage).
|
||||
PerPage(initialPerPage).
|
||||
Sort(initialSort).
|
||||
Filter(initialFilter)
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.query), func(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).
|
||||
Page(initialPage).
|
||||
PerPage(initialPerPage).
|
||||
Sort(initialSort).
|
||||
Filter(initialFilter)
|
||||
|
||||
err := p.Parse(s.query)
|
||||
err := p.Parse(s.query)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if p.page != s.expectPage {
|
||||
t.Errorf("(%d) Expected page %v, got %v", i, s.expectPage, p.page)
|
||||
}
|
||||
if p.page != s.expectPage {
|
||||
t.Fatalf("Expected page %v, got %v", s.expectPage, p.page)
|
||||
}
|
||||
|
||||
if p.perPage != s.expectPerPage {
|
||||
t.Errorf("(%d) Expected perPage %v, got %v", i, s.expectPerPage, p.perPage)
|
||||
}
|
||||
if p.perPage != s.expectPerPage {
|
||||
t.Fatalf("Expected perPage %v, got %v", s.expectPerPage, p.perPage)
|
||||
}
|
||||
|
||||
encodedSort, _ := json.Marshal(p.sort)
|
||||
if string(encodedSort) != s.expectSort {
|
||||
t.Errorf("(%d) Expected sort %v, got \n%v", i, s.expectSort, string(encodedSort))
|
||||
}
|
||||
encodedSort, _ := json.Marshal(p.sort)
|
||||
if string(encodedSort) != s.expectSort {
|
||||
t.Fatalf("Expected sort %v, got \n%v", s.expectSort, string(encodedSort))
|
||||
}
|
||||
|
||||
encodedFilter, _ := json.Marshal(p.filter)
|
||||
if string(encodedFilter) != s.expectFilter {
|
||||
t.Errorf("(%d) Expected filter %v, got \n%v", i, s.expectFilter, string(encodedFilter))
|
||||
}
|
||||
encodedFilter, _ := json.Marshal(p.filter)
|
||||
if string(encodedFilter) != s.expectFilter {
|
||||
t.Fatalf("Expected filter %v, got \n%v", s.expectFilter, string(encodedFilter))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,7 +257,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{},
|
||||
false,
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":2,"totalPages":1,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":10,"totalItems":2,"totalPages":1}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10",
|
||||
@@ -270,7 +271,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{},
|
||||
false,
|
||||
false,
|
||||
`{"page":10,"perPage":30,"totalItems":2,"totalPages":1,"items":[]}`,
|
||||
`{"items":[],"page":10,"perPage":30,"totalItems":2,"totalPages":1}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30 OFFSET 270",
|
||||
@@ -306,7 +307,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{"test2 != null", "test1 >= 2"},
|
||||
false,
|
||||
false,
|
||||
`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"totalPages":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"totalPages":1}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2)",
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
|
||||
@@ -320,7 +321,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{"test2 != null", "test1 >= 2"},
|
||||
true,
|
||||
false,
|
||||
`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":-1,"totalPages":-1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":-1,"totalPages":-1}`,
|
||||
[]string{
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
|
||||
},
|
||||
@@ -333,7 +334,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{"test3 != ''"},
|
||||
false,
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":0,"totalPages":0,"items":[]}`,
|
||||
`{"items":[],"page":1,"perPage":10,"totalItems":0,"totalPages":0}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL)))",
|
||||
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
|
||||
@@ -347,7 +348,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{"test3 != ''"},
|
||||
true,
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":-1,"totalPages":-1,"items":[]}`,
|
||||
`{"items":[],"page":1,"perPage":10,"totalItems":-1,"totalPages":-1}`,
|
||||
[]string{
|
||||
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
|
||||
},
|
||||
@@ -360,7 +361,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{},
|
||||
false,
|
||||
false,
|
||||
`{"page":2,"perPage":1,"totalItems":2,"totalPages":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":2,"totalPages":2}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
@@ -374,7 +375,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
[]FilterData{},
|
||||
true,
|
||||
false,
|
||||
`{"page":2,"perPage":1,"totalItems":-1,"totalPages":-1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":-1,"totalPages":-1}`,
|
||||
[]string{
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
},
|
||||
@@ -449,7 +450,7 @@ func TestProviderParseAndExec(t *testing.T) {
|
||||
"no extra query params (aka. use the provider presets)",
|
||||
"",
|
||||
false,
|
||||
`{"page":2,"perPage":123,"totalItems":2,"totalPages":1,"items":[]}`,
|
||||
`{"items":[],"page":2,"perPage":123,"totalItems":2,"totalPages":1}`,
|
||||
},
|
||||
{
|
||||
"invalid query",
|
||||
@@ -491,62 +492,63 @@ func TestProviderParseAndExec(t *testing.T) {
|
||||
"page > existing",
|
||||
"page=3&perPage=9999",
|
||||
false,
|
||||
`{"page":3,"perPage":500,"totalItems":2,"totalPages":1,"items":[]}`,
|
||||
`{"items":[],"page":3,"perPage":1000,"totalItems":2,"totalPages":1}`,
|
||||
},
|
||||
{
|
||||
"valid query params",
|
||||
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3",
|
||||
false,
|
||||
`{"page":1,"perPage":500,"totalItems":1,"totalPages":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":1,"totalPages":1}`,
|
||||
},
|
||||
{
|
||||
"valid query params with skipTotal=1",
|
||||
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3&skipTotal=1",
|
||||
false,
|
||||
`{"page":1,"perPage":500,"totalItems":-1,"totalPages":-1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":-1,"totalPages":-1}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testDB.CalledQueries = []string{} // reset
|
||||
|
||||
testResolver := &testFieldResolver{}
|
||||
provider := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(2).
|
||||
PerPage(123).
|
||||
Sort([]SortField{{"test2", SortAsc}}).
|
||||
Filter([]FilterData{"test1 > 0"})
|
||||
testResolver := &testFieldResolver{}
|
||||
provider := NewProvider(testResolver).
|
||||
Query(query).
|
||||
Page(2).
|
||||
PerPage(123).
|
||||
Sort([]SortField{{"test2", SortAsc}}).
|
||||
Filter([]FilterData{"test1 > 0"})
|
||||
|
||||
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
|
||||
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Errorf("[%s] Expected resolver.Update to be called %d, got %d", s.name, 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
if testResolver.UpdateQueryCalls != 1 {
|
||||
t.Fatalf("Expected resolver.Update to be called %d, got %d", 1, testResolver.UpdateQueryCalls)
|
||||
}
|
||||
|
||||
expectedQueries := 2
|
||||
if provider.skipTotal {
|
||||
expectedQueries = 1
|
||||
}
|
||||
expectedQueries := 2
|
||||
if provider.skipTotal {
|
||||
expectedQueries = 1
|
||||
}
|
||||
|
||||
if len(testDB.CalledQueries) != expectedQueries {
|
||||
t.Errorf("[%s] Expected %d db queries, got %d: \n%v", s.name, expectedQueries, len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
if len(testDB.CalledQueries) != expectedQueries {
|
||||
t.Fatalf("Expected %d db queries, got %d: \n%v", expectedQueries, len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Errorf("[%s] Expected result %v, got \n%v", s.name, s.expectResult, string(encoded))
|
||||
}
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != s.expectResult {
|
||||
t.Fatalf("Expected result \n%v\ngot\n%v", s.expectResult, string(encoded))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
// Returns error if `field` is not in `r.allowedFields`.
|
||||
func (r *SimpleFieldResolver) Resolve(field string) (*ResolverResult, error) {
|
||||
if !list.ExistInSliceWithRegex(field, r.allowedFields) {
|
||||
return nil, fmt.Errorf("failed to resolve field %q", field)
|
||||
return nil, fmt.Errorf("Failed to resolve field %q.", field)
|
||||
}
|
||||
|
||||
parts := strings.Split(field, ".")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package search_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
@@ -23,22 +24,22 @@ func TestSimpleFieldResolverUpdateQuery(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
|
||||
db := dbx.NewFromDB(nil, "")
|
||||
query := db.Select("id").From("test")
|
||||
|
||||
r.Resolve(s.fieldName)
|
||||
r.Resolve(s.fieldName)
|
||||
|
||||
if err := r.UpdateQuery(nil); err != nil {
|
||||
t.Errorf("(%d) UpdateQuery failed with error %v", i, err)
|
||||
continue
|
||||
}
|
||||
if err := r.UpdateQuery(nil); err != nil {
|
||||
t.Fatalf("UpdateQuery failed with error %v", err)
|
||||
}
|
||||
|
||||
rawQuery := query.Build().SQL()
|
||||
// rawQuery := s.expectQuery
|
||||
rawQuery := query.Build().SQL()
|
||||
|
||||
if rawQuery != s.expectQuery {
|
||||
t.Errorf("(%d) Expected query %v, got \n%v", i, s.expectQuery, rawQuery)
|
||||
}
|
||||
if rawQuery != s.expectQuery {
|
||||
t.Fatalf("Expected query %v, got \n%v", s.expectQuery, rawQuery)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,25 +63,25 @@ func TestSimpleFieldResolverResolve(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
r, err := r.Resolve(s.fieldName)
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
|
||||
r, err := r.Resolve(s.fieldName)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
continue
|
||||
}
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Identifier != s.expectName {
|
||||
t.Errorf("(%d) Expected r.Identifier %q, got %q", i, s.expectName, r.Identifier)
|
||||
}
|
||||
if r.Identifier != s.expectName {
|
||||
t.Fatalf("Expected r.Identifier %q, got %q", s.expectName, r.Identifier)
|
||||
}
|
||||
|
||||
// params should be empty
|
||||
if len(r.Params) != 0 {
|
||||
t.Errorf("(%d) Expected 0 r.Params, got %v", i, r.Params)
|
||||
}
|
||||
if len(r.Params) != 0 {
|
||||
t.Fatalf("r.Params should be empty, got %v", r.Params)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,10 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const randomSortKey string = "@random"
|
||||
const (
|
||||
randomSortKey string = "@random"
|
||||
rowidSortKey string = "@rowid"
|
||||
)
|
||||
|
||||
// sort field directions
|
||||
const (
|
||||
@@ -26,6 +29,11 @@ func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
|
||||
return "RANDOM()", nil
|
||||
}
|
||||
|
||||
// special case for the builtin SQLite rowid column
|
||||
if s.Name == rowidSortKey {
|
||||
return fmt.Sprintf("[[_rowid_]] %s", s.Direction), nil
|
||||
}
|
||||
|
||||
result, err := fieldResolver.Resolve(s.Name)
|
||||
|
||||
// invalidate empty fields and non-column identifiers
|
||||
|
||||
+26
-18
@@ -2,6 +2,7 @@ package search_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
@@ -29,27 +30,30 @@ func TestSortFieldBuildExpr(t *testing.T) {
|
||||
{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
|
||||
// special @random field (ignore direction)
|
||||
{search.SortField{"@random", search.SortDesc}, false, "RANDOM()"},
|
||||
// special _rowid_ field
|
||||
{search.SortField{"@rowid", search.SortDesc}, false, "[[_rowid_]] DESC"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result, err := s.sortField.BuildExpr(resolver)
|
||||
for _, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%s_%s", s.sortField.Name, s.sortField.Name), func(t *testing.T) {
|
||||
result, err := s.sortField.BuildExpr(resolver)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if result != s.expectExpression {
|
||||
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectExpression, result)
|
||||
}
|
||||
if result != s.expectExpression {
|
||||
t.Fatalf("Expected expression %v, got %v", s.expectExpression, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortFromString(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
value string
|
||||
expectedJson string
|
||||
value string
|
||||
expected string
|
||||
}{
|
||||
{"", `[{"name":"","direction":"ASC"}]`},
|
||||
{"test", `[{"name":"test","direction":"ASC"}]`},
|
||||
@@ -57,14 +61,18 @@ func TestParseSortFromString(t *testing.T) {
|
||||
{"-test", `[{"name":"test","direction":"DESC"}]`},
|
||||
{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
|
||||
{"@random,-test", `[{"name":"@random","direction":"ASC"},{"name":"test","direction":"DESC"}]`},
|
||||
{"-@rowid,-test", `[{"name":"@rowid","direction":"DESC"},{"name":"test","direction":"DESC"}]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
result := search.ParseSortFromString(s.value)
|
||||
encoded, _ := json.Marshal(result)
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.value, func(t *testing.T) {
|
||||
result := search.ParseSortFromString(s.value)
|
||||
encoded, _ := json.Marshal(result)
|
||||
encodedStr := string(encoded)
|
||||
|
||||
if string(encoded) != s.expectedJson {
|
||||
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectedJson, string(encoded))
|
||||
}
|
||||
if encodedStr != s.expected {
|
||||
t.Fatalf("Expected expression %s, got %s", s.expected, encodedStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
@@ -17,30 +18,30 @@ func TestEncrypt(t *testing.T) {
|
||||
{"123", "abcdabcdabcdabcdabcdabcdabcdabcd", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Encrypt([]byte(scenario.data), scenario.key)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.data), func(t *testing.T) {
|
||||
result, err := security.Encrypt([]byte(s.data), s.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
hasErr := err != nil
|
||||
|
||||
if scenario.expectError && result != "" {
|
||||
t.Errorf("(%d) Expected empty string, got %q", i, result)
|
||||
}
|
||||
if !scenario.expectError && result == "" {
|
||||
t.Errorf("(%d) Expected non empty encrypted result string", i)
|
||||
}
|
||||
|
||||
// try to decrypt
|
||||
if result != "" {
|
||||
decrypted, _ := security.Decrypt(result, scenario.key)
|
||||
if string(decrypted) != scenario.data {
|
||||
t.Errorf("(%d) Expected decrypted value to match with the data input, got %q", i, decrypted)
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
if result != "" {
|
||||
t.Fatalf("Expected empty Encrypt result on error, got %q", result)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// try to decrypt
|
||||
decrypted, err := security.Decrypt(result, s.key)
|
||||
if err != nil || string(decrypted) != s.data {
|
||||
t.Fatalf("Expected decrypted value to match with the data input, got %q (%v)", decrypted, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,19 +58,23 @@ func TestDecrypt(t *testing.T) {
|
||||
{"8kcEqilvv+YKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", false, "123"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Decrypt(scenario.cipher, scenario.key)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
|
||||
result, err := security.Decrypt(s.cipher, s.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
hasErr := err != nil
|
||||
|
||||
resultStr := string(result)
|
||||
if resultStr != scenario.expectedData {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expectedData, resultStr)
|
||||
}
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if str := string(result); str != s.expectedData {
|
||||
t.Fatalf("Expected %q, got %q", s.expectedData, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+3
-12
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
// @todo update to v5
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
@@ -43,11 +44,9 @@ func ParseJWT(token string, verificationKey string) (jwt.MapClaims, error) {
|
||||
}
|
||||
|
||||
// NewJWT generates and returns new HS256 signed JWT.
|
||||
func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
|
||||
seconds := time.Duration(secondsDuration) * time.Second
|
||||
|
||||
func NewJWT(payload jwt.MapClaims, signingKey string, duration time.Duration) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"exp": time.Now().Add(seconds).Unix(),
|
||||
"exp": time.Now().Add(duration).Unix(),
|
||||
}
|
||||
|
||||
for k, v := range payload {
|
||||
@@ -56,11 +55,3 @@ func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (st
|
||||
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(signingKey))
|
||||
}
|
||||
|
||||
// Deprecated:
|
||||
// Consider replacing with NewJWT().
|
||||
//
|
||||
// NewToken is a legacy alias for NewJWT that generates a HS256 signed JWT.
|
||||
func NewToken(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
|
||||
return NewJWT(payload, signingKey, secondsDuration)
|
||||
}
|
||||
|
||||
+61
-54
@@ -1,7 +1,10 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
@@ -102,26 +105,30 @@ func TestParseJWT(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.ParseJWT(scenario.token, scenario.secret)
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
if len(result) != len(scenario.expectClaims) {
|
||||
t.Errorf("(%d) Expected %v got %v", i, scenario.expectClaims, result)
|
||||
}
|
||||
for k, v := range scenario.expectClaims {
|
||||
v2, ok := result[k]
|
||||
if !ok {
|
||||
t.Errorf("(%d) Missing expected claim %q", i, k)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.token), func(t *testing.T) {
|
||||
result, err := security.ParseJWT(s.token, s.secret)
|
||||
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
if v != v2 {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, v, k, v2)
|
||||
|
||||
if len(result) != len(s.expectClaims) {
|
||||
t.Fatalf("Expected %v claims got %v", s.expectClaims, result)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range s.expectClaims {
|
||||
v2, ok := result[k]
|
||||
if !ok {
|
||||
t.Fatalf("Missing expected claim %q", k)
|
||||
}
|
||||
if v != v2 {
|
||||
t.Fatalf("Expected %v for %q claim, got %v", v, k, v2)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,51 +136,51 @@ func TestNewJWT(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
claims jwt.MapClaims
|
||||
key string
|
||||
duration int64
|
||||
duration time.Duration
|
||||
expectError bool
|
||||
}{
|
||||
// empty, zero duration
|
||||
{jwt.MapClaims{}, "", 0, true},
|
||||
// empty, 10 seconds duration
|
||||
{jwt.MapClaims{}, "", 10, false},
|
||||
{jwt.MapClaims{}, "", 10 * time.Second, false},
|
||||
// non-empty, 10 seconds duration
|
||||
{jwt.MapClaims{"name": "test"}, "test", 10, false},
|
||||
{jwt.MapClaims{"name": "test"}, "test", 10 * time.Second, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
|
||||
if tokenErr != nil {
|
||||
t.Errorf("(%d) Expected NewJWT to succeed, got error %v", i, tokenErr)
|
||||
continue
|
||||
}
|
||||
|
||||
claims, parseErr := security.ParseJWT(token, scenario.key)
|
||||
|
||||
hasParseErr := parseErr != nil
|
||||
if hasParseErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasParseErr to be %v, got %v (%v)", i, scenario.expectError, hasParseErr, parseErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
t.Errorf("(%d) Missing required claim exp, got %v", i, claims)
|
||||
}
|
||||
|
||||
// clear exp claim to match with the scenario ones
|
||||
delete(claims, "exp")
|
||||
|
||||
if len(claims) != len(scenario.claims) {
|
||||
t.Errorf("(%d) Expected %v claims, got %v", i, scenario.claims, claims)
|
||||
}
|
||||
|
||||
for j, k := range claims {
|
||||
if claims[j] != scenario.claims[j] {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, claims[j], k, scenario.claims[j])
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
|
||||
if tokenErr != nil {
|
||||
t.Fatalf("Expected NewJWT to succeed, got error %v", tokenErr)
|
||||
}
|
||||
}
|
||||
|
||||
claims, parseErr := security.ParseJWT(token, scenario.key)
|
||||
|
||||
hasParseErr := parseErr != nil
|
||||
if hasParseErr != scenario.expectError {
|
||||
t.Fatalf("Expected hasParseErr to be %v, got %v (%v)", scenario.expectError, hasParseErr, parseErr)
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
t.Fatalf("Missing required claim exp, got %v", claims)
|
||||
}
|
||||
|
||||
// clear exp claim to match with the scenario ones
|
||||
delete(claims, "exp")
|
||||
|
||||
if len(claims) != len(scenario.claims) {
|
||||
t.Fatalf("Expected %v claims, got %v", scenario.claims, claims)
|
||||
}
|
||||
|
||||
for j, k := range claims {
|
||||
if claims[j] != scenario.claims[j] {
|
||||
t.Fatalf("Expected %v for %q claim, got %v", claims[j], k, scenario.claims[j])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,11 @@ package security
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"math/big"
|
||||
mathRand "math/rand"
|
||||
"time"
|
||||
mathRand "math/rand" // @todo replace with rand/v2?
|
||||
)
|
||||
|
||||
const defaultRandomAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
func init() {
|
||||
mathRand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// RandomString generates a cryptographically random string with the specified length.
|
||||
//
|
||||
// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding.
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"regexp/syntax"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultMaxRepeat = 6
|
||||
|
||||
var anyCharNotNLPairs = []rune{'A', 'Z', 'a', 'z', '0', '9'}
|
||||
|
||||
// RandomStringByRegex generates a random string matching the regex pattern.
|
||||
// If optFlags is not set, fallbacks to [syntax.Perl].
|
||||
//
|
||||
// NB! While the source of the randomness comes from [crypto/rand] this method
|
||||
// is not recommended to be used on its own in critical secure contexts because
|
||||
// the generated length could vary too much on the used pattern and may not be
|
||||
// as secure as simply calling [security.RandomString].
|
||||
// If you still insist on using it for such purposes, consider at least
|
||||
// a large enough minimum length for the generated string, e.g. `[a-z0-9]{30}`.
|
||||
//
|
||||
// This function is inspired by github.com/pipe01/revregexp, github.com/lucasjones/reggen and other similar packages.
|
||||
func RandomStringByRegex(pattern string, optFlags ...syntax.Flags) (string, error) {
|
||||
var flags syntax.Flags
|
||||
if len(optFlags) == 0 {
|
||||
flags = syntax.Perl
|
||||
} else {
|
||||
for _, f := range optFlags {
|
||||
flags |= f
|
||||
}
|
||||
}
|
||||
|
||||
r, err := syntax.Parse(pattern, flags)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var sb = new(strings.Builder)
|
||||
|
||||
err = writeRandomStringByRegex(r, sb)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func writeRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder) error {
|
||||
// https://pkg.go.dev/regexp/syntax#Op
|
||||
switch r.Op {
|
||||
case syntax.OpCharClass:
|
||||
c, err := randomRuneFromPairs(r.Rune)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sb.WriteRune(c)
|
||||
return err
|
||||
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
|
||||
c, err := randomRuneFromPairs(anyCharNotNLPairs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sb.WriteRune(c)
|
||||
return err
|
||||
case syntax.OpAlternate:
|
||||
idx, err := randomNumber(len(r.Sub))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeRandomStringByRegex(r.Sub[idx], sb)
|
||||
case syntax.OpConcat:
|
||||
var err error
|
||||
for _, sub := range r.Sub {
|
||||
err = writeRandomStringByRegex(sub, sb)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return err
|
||||
case syntax.OpRepeat:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, r.Min, r.Max)
|
||||
case syntax.OpQuest:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 0, 1)
|
||||
case syntax.OpPlus:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 1, -1)
|
||||
case syntax.OpStar:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 0, -1)
|
||||
case syntax.OpCapture:
|
||||
return writeRandomStringByRegex(r.Sub[0], sb)
|
||||
case syntax.OpLiteral:
|
||||
_, err := sb.WriteString(string(r.Rune))
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("unsupported pattern operator %d", r.Op)
|
||||
}
|
||||
}
|
||||
|
||||
func repeatRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder, min int, max int) error {
|
||||
if max < 0 {
|
||||
max = defaultMaxRepeat
|
||||
}
|
||||
|
||||
if max < min {
|
||||
max = min
|
||||
}
|
||||
|
||||
n := min
|
||||
if max != min {
|
||||
randRange, err := randomNumber(max - min)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n += randRange
|
||||
}
|
||||
|
||||
var err error
|
||||
for i := 0; i < n; i++ {
|
||||
err = writeRandomStringByRegex(r, sb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomRuneFromPairs(pairs []rune) (rune, error) {
|
||||
idx, err := randomNumber(len(pairs) / 2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return randomRuneFromRange(pairs[idx*2], pairs[idx*2+1])
|
||||
}
|
||||
|
||||
func randomRuneFromRange(min rune, max rune) (rune, error) {
|
||||
offset, err := randomNumber(int(max - min + 1))
|
||||
if err != nil {
|
||||
return min, err
|
||||
}
|
||||
|
||||
return min + rune(offset), nil
|
||||
}
|
||||
|
||||
func randomNumber(maxSoft int) (int, error) {
|
||||
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(maxSoft)))
|
||||
|
||||
return int(randRange.Int64()), err
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"regexp/syntax"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestRandomStringByRegex(t *testing.T) {
|
||||
generated := []string{}
|
||||
|
||||
scenarios := []struct {
|
||||
pattern string
|
||||
flags []syntax.Flags
|
||||
expectError bool
|
||||
}{
|
||||
{``, nil, true},
|
||||
{`test`, nil, false},
|
||||
{`\d+`, []syntax.Flags{syntax.POSIX}, true},
|
||||
{`\d+`, nil, false},
|
||||
{`\d*`, nil, false},
|
||||
{`\d{1,10}`, nil, false},
|
||||
{`\d{3}`, nil, false},
|
||||
{`\d{0,}-abc`, nil, false},
|
||||
{`[a-zA-Z]*`, nil, false},
|
||||
{`[^a-zA-Z]{5,30}`, nil, false},
|
||||
{`\w+_abc`, nil, false},
|
||||
{`[a-zA-Z_]*`, nil, false},
|
||||
{`[2-9]{5}-\w+`, nil, false},
|
||||
{`(a|b|c)`, nil, false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, s.pattern), func(t *testing.T) {
|
||||
str, err := security.RandomStringByRegex(s.pattern, s.flags...)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
r, err := regexp.Compile(s.pattern)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !r.Match([]byte(str)) {
|
||||
t.Fatalf("Expected %q to match pattern %v", str, s.pattern)
|
||||
}
|
||||
|
||||
if slices.Contains(generated, str) {
|
||||
t.Fatalf("The generated string %q already exists in\n%v", str, generated)
|
||||
}
|
||||
|
||||
generated = append(generated, str)
|
||||
})
|
||||
}
|
||||
}
|
||||
+115
-24
@@ -1,11 +1,18 @@
|
||||
package store
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// @todo remove after https://github.com/golang/go/issues/20135
|
||||
const ShrinkThreshold = 200 // the number is arbitrary chosen
|
||||
|
||||
// Store defines a concurrent safe in memory key-value data store.
|
||||
type Store[T any] struct {
|
||||
data map[string]T
|
||||
mux sync.RWMutex
|
||||
data map[string]T
|
||||
mu sync.RWMutex
|
||||
deleted int64
|
||||
}
|
||||
|
||||
// New creates a new Store[T] instance with a shallow copy of the provided data (if any).
|
||||
@@ -20,8 +27,8 @@ func New[T any](data map[string]T) *Store[T] {
|
||||
// Reset clears the store and replaces the store data with a
|
||||
// shallow copy of the provided newData.
|
||||
func (s *Store[T]) Reset(newData map[string]T) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if len(newData) > 0 {
|
||||
s.data = make(map[string]T, len(newData))
|
||||
@@ -31,38 +38,50 @@ func (s *Store[T]) Reset(newData map[string]T) {
|
||||
} else {
|
||||
s.data = make(map[string]T)
|
||||
}
|
||||
|
||||
s.deleted = 0
|
||||
}
|
||||
|
||||
// Length returns the current number of elements in the store.
|
||||
func (s *Store[T]) Length() int {
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return len(s.data)
|
||||
}
|
||||
|
||||
// RemoveAll removes all the existing store entries.
|
||||
func (s *Store[T]) RemoveAll() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.data = make(map[string]T)
|
||||
s.Reset(nil)
|
||||
}
|
||||
|
||||
// Remove removes a single entry from the store.
|
||||
//
|
||||
// Remove does nothing if key doesn't exist in the store.
|
||||
func (s *Store[T]) Remove(key string) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.data, key)
|
||||
s.deleted++
|
||||
|
||||
// reassign to a new map so that the old one can be gc-ed because it doesn't shrink
|
||||
//
|
||||
// @todo remove after https://github.com/golang/go/issues/20135
|
||||
if s.deleted >= ShrinkThreshold {
|
||||
newData := make(map[string]T, len(s.data))
|
||||
for k, v := range s.data {
|
||||
newData[k] = v
|
||||
}
|
||||
s.data = newData
|
||||
s.deleted = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Has checks if element with the specified key exist or not.
|
||||
func (s *Store[T]) Has(key string) bool {
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
_, ok := s.data[key]
|
||||
|
||||
@@ -73,16 +92,26 @@ func (s *Store[T]) Has(key string) bool {
|
||||
//
|
||||
// If key is not set, the zero T value is returned.
|
||||
func (s *Store[T]) Get(key string) T {
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
return s.data[key]
|
||||
}
|
||||
|
||||
// GetOk is similar to Get but returns also a boolean indicating whether the key exists or not.
|
||||
func (s *Store[T]) GetOk(key string) (T, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
v, ok := s.data[key]
|
||||
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// GetAll returns a shallow copy of the current store data.
|
||||
func (s *Store[T]) GetAll() map[string]T {
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var clone = make(map[string]T, len(s.data))
|
||||
|
||||
@@ -93,10 +122,24 @@ func (s *Store[T]) GetAll() map[string]T {
|
||||
return clone
|
||||
}
|
||||
|
||||
// Values returns a slice with all of the current store values.
|
||||
func (s *Store[T]) Values() []T {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var values = make([]T, 0, len(s.data))
|
||||
|
||||
for _, v := range s.data {
|
||||
values = append(values, v)
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
// Set sets (or overwrite if already exist) a new value for key.
|
||||
func (s *Store[T]) Set(key string, value T) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[string]T)
|
||||
@@ -105,16 +148,34 @@ func (s *Store[T]) Set(key string, value T) {
|
||||
s.data[key] = value
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a single existing value for the provided key
|
||||
// or stores a new one if it doesn't exist.
|
||||
func (s *Store[T]) GetOrSet(key string, setFunc func() T) T {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[string]T)
|
||||
}
|
||||
|
||||
v, ok := s.data[key]
|
||||
if !ok {
|
||||
v = setFunc()
|
||||
s.data[key] = v
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// SetIfLessThanLimit sets (or overwrite if already exist) a new value for key.
|
||||
//
|
||||
// This method is similar to Set() but **it will skip adding new elements**
|
||||
// to the store if the store length has reached the specified limit.
|
||||
// false is returned if maxAllowedElements limit is reached.
|
||||
func (s *Store[T]) SetIfLessThanLimit(key string, value T, maxAllowedElements int) bool {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// init map if not already
|
||||
if s.data == nil {
|
||||
s.data = make(map[string]T)
|
||||
}
|
||||
@@ -132,3 +193,33 @@ func (s *Store[T]) SetIfLessThanLimit(key string, value T, maxAllowedElements in
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler] and imports the
|
||||
// provided JSON data into the store.
|
||||
//
|
||||
// The store entries that match with the ones from the data will be overwritten with the new value.
|
||||
func (s *Store[T]) UnmarshalJSON(data []byte) error {
|
||||
raw := map[string]T{}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[string]T)
|
||||
}
|
||||
|
||||
for k, v := range raw {
|
||||
s.data[k] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler] and export the current
|
||||
// store data into valid JSON.
|
||||
func (s *Store[T]) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(s.GetAll())
|
||||
}
|
||||
|
||||
+163
-5
@@ -3,6 +3,8 @@ package store_test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
@@ -137,11 +139,41 @@ func TestGet(t *testing.T) {
|
||||
{"missing", 0}, // should auto fallback to the zero value
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
val := s.Get(scenario.key)
|
||||
if val != scenario.expect {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expect, val)
|
||||
}
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.key, func(t *testing.T) {
|
||||
val := s.Get(scenario.key)
|
||||
if val != scenario.expect {
|
||||
t.Fatalf("Expected %v, got %v", scenario.expect, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOk(t *testing.T) {
|
||||
s := store.New(map[string]int{"test1": 0, "test2": 1})
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
expectValue int
|
||||
expectOk bool
|
||||
}{
|
||||
{"test1", 0, true},
|
||||
{"test2", 1, true},
|
||||
{"missing", 0, false}, // should auto fallback to the zero value
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.key, func(t *testing.T) {
|
||||
val, ok := s.GetOk(scenario.key)
|
||||
|
||||
if ok != scenario.expectOk {
|
||||
t.Fatalf("Expected ok %v, got %v", scenario.expectOk, ok)
|
||||
}
|
||||
|
||||
if val != scenario.expectValue {
|
||||
t.Fatalf("Expected %v, got %v", scenario.expectValue, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,6 +205,27 @@ func TestGetAll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
data := map[string]int{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
}
|
||||
|
||||
values := store.New(data).Values()
|
||||
|
||||
expected := []int{1, 2}
|
||||
|
||||
if len(values) != len(expected) {
|
||||
t.Fatalf("Expected %d values, got %d", len(expected), len(values))
|
||||
}
|
||||
|
||||
for _, v := range expected {
|
||||
if !slices.Contains(values, v) {
|
||||
t.Fatalf("Missing value %v in\n%v", v, values)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
s := store.Store[int]{}
|
||||
|
||||
@@ -196,6 +249,37 @@ func TestSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOrSet(t *testing.T) {
|
||||
s := store.New(map[string]int{
|
||||
"test1": 0,
|
||||
"test2": 1,
|
||||
"test3": 3,
|
||||
})
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
value int
|
||||
expected int
|
||||
}{
|
||||
{"test2", 20, 1},
|
||||
{"test3", 2, 3},
|
||||
{"test_new", 20, 20},
|
||||
{"test_new", 50, 20}, // should return the previously inserted value
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.key, func(t *testing.T) {
|
||||
result := s.GetOrSet(scenario.key, func() int {
|
||||
return scenario.value
|
||||
})
|
||||
|
||||
if result != scenario.expected {
|
||||
t.Fatalf("Expected %v, got %v", scenario.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetIfLessThanLimit(t *testing.T) {
|
||||
s := store.Store[int]{}
|
||||
|
||||
@@ -230,3 +314,77 @@ func TestSetIfLessThanLimit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJSON(t *testing.T) {
|
||||
s := store.Store[string]{}
|
||||
s.Set("b", "old") // should be overwritten
|
||||
s.Set("c", "test3") // ensures that the old values are not removed
|
||||
|
||||
raw := []byte(`{"a":"test1", "b":"test2"}`)
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := s.Get("a"); v != "test1" {
|
||||
t.Fatalf("Expected store.a to be %q, got %q", "test1", v)
|
||||
}
|
||||
|
||||
if v := s.Get("b"); v != "test2" {
|
||||
t.Fatalf("Expected store.b to be %q, got %q", "test2", v)
|
||||
}
|
||||
|
||||
if v := s.Get("c"); v != "test3" {
|
||||
t.Fatalf("Expected store.c to be %q, got %q", "test3", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalJSON(t *testing.T) {
|
||||
s := &store.Store[string]{}
|
||||
s.Set("a", "test1")
|
||||
s.Set("b", "test2")
|
||||
|
||||
expected := []byte(`{"a":"test1", "b":"test2"}`)
|
||||
|
||||
result, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if bytes.Equal(result, expected) {
|
||||
t.Fatalf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShrink(t *testing.T) {
|
||||
s := &store.Store[int]{}
|
||||
|
||||
total := 1000
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
s.Set(strconv.Itoa(i), i)
|
||||
}
|
||||
|
||||
if s.Length() != total {
|
||||
t.Fatalf("Expected %d items, got %d", total, s.Length())
|
||||
}
|
||||
|
||||
// trigger map "shrink"
|
||||
for i := 0; i < store.ShrinkThreshold; i++ {
|
||||
s.Remove(strconv.Itoa(i))
|
||||
}
|
||||
|
||||
// ensure that after the deletion, the new map was copied properly
|
||||
if s.Length() != total-store.ShrinkThreshold {
|
||||
t.Fatalf("Expected %d items, got %d", total-store.ShrinkThreshold, s.Length())
|
||||
}
|
||||
|
||||
for k := range s.GetAll() {
|
||||
kInt, err := strconv.Atoi(k)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert %s into int: %v", k, err)
|
||||
}
|
||||
if kInt < store.ShrinkThreshold {
|
||||
t.Fatalf("Key %q should have been deleted", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,47 +2,41 @@ package subscriptions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
// Broker defines a struct for managing subscriptions clients.
|
||||
type Broker struct {
|
||||
clients map[string]Client
|
||||
mux sync.RWMutex
|
||||
store *store.Store[Client]
|
||||
}
|
||||
|
||||
// NewBroker initializes and returns a new Broker instance.
|
||||
func NewBroker() *Broker {
|
||||
return &Broker{
|
||||
clients: make(map[string]Client),
|
||||
store: store.New[Client](nil),
|
||||
}
|
||||
}
|
||||
|
||||
// Clients returns a shallow copy of all registered clients indexed
|
||||
// with their connection id.
|
||||
func (b *Broker) Clients() map[string]Client {
|
||||
b.mux.RLock()
|
||||
defer b.mux.RUnlock()
|
||||
return b.store.GetAll()
|
||||
}
|
||||
|
||||
copy := make(map[string]Client, len(b.clients))
|
||||
|
||||
for id, c := range b.clients {
|
||||
copy[id] = c
|
||||
}
|
||||
|
||||
return copy
|
||||
// ChunkedClients splits the current clients into a chunked slice.
|
||||
func (b *Broker) ChunkedClients(chunkSize int) [][]Client {
|
||||
return list.ToChunks(b.store.Values(), chunkSize)
|
||||
}
|
||||
|
||||
// ClientById finds a registered client by its id.
|
||||
//
|
||||
// Returns non-nil error when client with clientId is not registered.
|
||||
func (b *Broker) ClientById(clientId string) (Client, error) {
|
||||
b.mux.RLock()
|
||||
defer b.mux.RUnlock()
|
||||
|
||||
client, ok := b.clients[clientId]
|
||||
client, ok := b.store.GetOk(clientId)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("No client associated with connection ID %q", clientId)
|
||||
return nil, fmt.Errorf("no client associated with connection ID %q", clientId)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
@@ -50,21 +44,17 @@ func (b *Broker) ClientById(clientId string) (Client, error) {
|
||||
|
||||
// Register adds a new client to the broker instance.
|
||||
func (b *Broker) Register(client Client) {
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
|
||||
b.clients[client.Id()] = client
|
||||
b.store.Set(client.Id(), client)
|
||||
}
|
||||
|
||||
// Unregister removes a single client by its id.
|
||||
//
|
||||
// If client with clientId doesn't exist, this method does nothing.
|
||||
func (b *Broker) Unregister(clientId string) {
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
|
||||
if client, ok := b.clients[clientId]; ok {
|
||||
client.Discard()
|
||||
delete(b.clients, clientId)
|
||||
client := b.store.Get(clientId)
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
client.Discard()
|
||||
b.store.Remove(clientId)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user