merge v0.23.0-rc changes

This commit is contained in:
Gani Georgiev
2024-09-29 19:23:19 +03:00
parent ad92992324
commit 844f18cac3
753 changed files with 85141 additions and 63396 deletions
+5 -7
View File
@@ -3,6 +3,7 @@ package archive
import (
"archive/zip"
"compress/flate"
"errors"
"io"
"io/fs"
"os"
@@ -23,24 +24,21 @@ func Create(src string, dest string, skipPaths ...string) error {
if err != nil {
return err
}
defer zf.Close()
zw := zip.NewWriter(zf)
defer zw.Close()
// register a custom Deflate compressor
zw.RegisterCompressor(zip.Deflate, func(out io.Writer) (io.WriteCloser, error) {
return flate.NewWriter(out, flate.BestSpeed)
})
if err := zipAddFS(zw, os.DirFS(src), skipPaths...); err != nil {
err = zipAddFS(zw, os.DirFS(src), skipPaths...)
if err != nil {
// try to cleanup at least the created zip file
os.Remove(dest)
return err
return errors.Join(err, zw.Close(), zf.Close(), os.Remove(dest))
}
return nil
return errors.Join(zw.Close(), zf.Close())
}
// note remove after similar method is added in the std lib (https://github.com/golang/go/issues/54898)
+14 -10
View File
@@ -18,6 +18,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameApple] = wrapFactory(NewAppleProvider)
}
var _ Provider = (*Apple)(nil)
// NameApple is the unique name of the Apple provider.
@@ -27,23 +31,23 @@ const NameApple string = "apple"
//
// [OIDC differences]: https://bitbucket.org/openid/connect/src/master/How-Sign-in-with-Apple-differs-from-OpenID-Connect.md
type Apple struct {
*baseProvider
BaseProvider
jwksUrl string
jwksURL string
}
// NewAppleProvider creates a new Apple provider instance with some defaults.
func NewAppleProvider() *Apple {
return &Apple{
baseProvider: &baseProvider{
BaseProvider: BaseProvider{
ctx: context.Background(),
displayName: "Apple",
pkce: true,
scopes: []string{"name", "email"},
authUrl: "https://appleid.apple.com/auth/authorize",
tokenUrl: "https://appleid.apple.com/auth/token",
authURL: "https://appleid.apple.com/auth/authorize",
tokenURL: "https://appleid.apple.com/auth/token",
},
jwksUrl: "https://appleid.apple.com/auth/keys",
jwksURL: "https://appleid.apple.com/auth/keys",
}
}
@@ -51,7 +55,7 @@ func NewAppleProvider() *Apple {
//
// API reference: https://developer.apple.com/documentation/sign_in_with_apple/tokenresponse.
func (p *Apple) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -98,11 +102,11 @@ func (p *Apple) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
return user, nil
}
// FetchRawUserData implements Provider.FetchRawUserData interface.
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface.
//
// Apple doesn't have a UserInfo endpoint and claims about users
// are instead included in the "id_token" (https://openid.net/specs/openid-connect-core-1_0.html#id_tokenExample)
func (p *Apple) FetchRawUserData(token *oauth2.Token) ([]byte, error) {
func (p *Apple) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
idToken, _ := token.Extra("id_token").(string)
claims, err := p.parseAndVerifyIdToken(idToken)
@@ -209,7 +213,7 @@ type jwk struct {
}
func (p *Apple) fetchJWK(kid string) (*jwk, error) {
req, err := http.NewRequestWithContext(p.ctx, "GET", p.jwksUrl, nil)
req, err := http.NewRequestWithContext(p.ctx, "GET", p.jwksURL, nil)
if err != nil {
return nil, err
}
+73 -87
View File
@@ -2,6 +2,7 @@ package auth
import (
"context"
"encoding/json"
"errors"
"net/http"
@@ -9,17 +10,22 @@ import (
"golang.org/x/oauth2"
)
// AuthUser defines a standardized oauth2 user data structure.
type AuthUser struct {
Id string `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`
AvatarUrl string `json:"avatarUrl"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
Expiry types.DateTime `json:"expiry"`
RawUser map[string]any `json:"rawUser"`
// ProviderFactoryFunc defines a function for initializing a new OAuth2 provider.
type ProviderFactoryFunc func() Provider
// Providers defines a map with all of the available OAuth2 providers.
//
// To register a new provider append a new entry in the map.
var Providers = map[string]ProviderFactoryFunc{}
// NewProviderByName returns a new preconfigured provider instance by its name identifier.
func NewProviderByName(name string) (Provider, error) {
factory, ok := Providers[name]
if !ok {
return nil, errors.New("missing provider " + name)
}
return factory(), nil
}
// Provider defines a common interface for an OAuth2 client.
@@ -61,104 +67,84 @@ type Provider interface {
// SetClientSecret sets the provider client's app secret.
SetClientSecret(secret string)
// RedirectUrl returns the end address to redirect the user
// RedirectURL returns the end address to redirect the user
// going through the OAuth flow.
RedirectUrl() string
RedirectURL() string
// SetRedirectUrl sets the provider's RedirectUrl.
SetRedirectUrl(url string)
// SetRedirectURL sets the provider's RedirectURL.
SetRedirectURL(url string)
// AuthUrl returns the provider's authorization service url.
AuthUrl() string
// AuthURL returns the provider's authorization service url.
AuthURL() string
// SetAuthUrl sets the provider's AuthUrl.
SetAuthUrl(url string)
// SetAuthURL sets the provider's AuthURL.
SetAuthURL(url string)
// TokenUrl returns the provider's token exchange service url.
TokenUrl() string
// TokenURL returns the provider's token exchange service url.
TokenURL() string
// SetTokenUrl sets the provider's TokenUrl.
SetTokenUrl(url string)
// SetTokenURL sets the provider's TokenURL.
SetTokenURL(url string)
// UserApiUrl returns the provider's user info api url.
UserApiUrl() string
// UserInfoURL returns the provider's user info api url.
UserInfoURL() string
// SetUserApiUrl sets the provider's UserApiUrl.
SetUserApiUrl(url string)
// SetUserInfoURL sets the provider's UserInfoURL.
SetUserInfoURL(url string)
// Client returns an http client using the provided token.
Client(token *oauth2.Token) *http.Client
// BuildAuthUrl returns a URL to the provider's consent page
// BuildAuthURL returns a URL to the provider's consent page
// that asks for permissions for the required scopes explicitly.
BuildAuthUrl(state string, opts ...oauth2.AuthCodeOption) string
BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string
// FetchToken converts an authorization code to token.
FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
// FetchRawUserData requests and marshalizes into `result` the
// FetchRawUserInfo requests and marshalizes into `result` the
// the OAuth user api response.
FetchRawUserData(token *oauth2.Token) ([]byte, error)
FetchRawUserInfo(token *oauth2.Token) ([]byte, error)
// FetchAuthUser is similar to FetchRawUserData, but normalizes and
// FetchAuthUser is similar to FetchRawUserInfo, but normalizes and
// marshalizes the user api response into a standardized AuthUser struct.
FetchAuthUser(token *oauth2.Token) (user *AuthUser, err error)
}
// NewProviderByName returns a new preconfigured provider instance by its name identifier.
func NewProviderByName(name string) (Provider, error) {
switch name {
case NameGoogle:
return NewGoogleProvider(), nil
case NameFacebook:
return NewFacebookProvider(), nil
case NameGithub:
return NewGithubProvider(), nil
case NameGitlab:
return NewGitlabProvider(), nil
case NameDiscord:
return NewDiscordProvider(), nil
case NameTwitter:
return NewTwitterProvider(), nil
case NameMicrosoft:
return NewMicrosoftProvider(), nil
case NameSpotify:
return NewSpotifyProvider(), nil
case NameKakao:
return NewKakaoProvider(), nil
case NameTwitch:
return NewTwitchProvider(), nil
case NameStrava:
return NewStravaProvider(), nil
case NameGitee:
return NewGiteeProvider(), nil
case NameLivechat:
return NewLivechatProvider(), nil
case NameGitea:
return NewGiteaProvider(), nil
case NameOIDC:
return NewOIDCProvider(), nil
case NameOIDC + "2":
return NewOIDCProvider(), nil
case NameOIDC + "3":
return NewOIDCProvider(), nil
case NameApple:
return NewAppleProvider(), nil
case NameInstagram:
return NewInstagramProvider(), nil
case NameVK:
return NewVKProvider(), nil
case NameYandex:
return NewYandexProvider(), nil
case NamePatreon:
return NewPatreonProvider(), nil
case NameMailcow:
return NewMailcowProvider(), nil
case NameBitbucket:
return NewBitbucketProvider(), nil
case NamePlanningcenter:
return NewPlanningcenterProvider(), nil
default:
return nil, errors.New("Missing provider " + name)
// wrapFactory is a helper that wraps a Provider specific factory
// function and returns its result as Provider interface.
func wrapFactory[T Provider](factory func() T) ProviderFactoryFunc {
return func() Provider {
return factory()
}
}
// AuthUser defines a standardized OAuth2 user data structure.
type AuthUser struct {
Expiry types.DateTime `json:"expiry"`
RawUser map[string]any `json:"rawUser"`
Id string `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`
AvatarURL string `json:"avatarURL"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
// @todo
// deprecated: use AvatarURL instead
// AvatarUrl will be removed after dropping v0.22 support
AvatarUrl string `json:"avatarUrl"`
}
// MarshalJSON implements the [json.Marshaler] interface.
//
// @todo remove after dropping v0.22 support
func (au AuthUser) MarshalJSON() ([]byte, error) {
type alias AuthUser // prevent recursion
au2 := alias(au)
au2.AvatarURL = au.AvatarURL // ensure that the legacy field is populated
return json.Marshal(au2)
}
+8
View File
@@ -6,6 +6,14 @@ import (
"github.com/pocketbase/pocketbase/tools/auth"
)
func TestProvidersCount(t *testing.T) {
expected := 25
if total := len(auth.Providers); total != expected {
t.Fatalf("Expected %d providers, got %d", expected, total)
}
}
func TestNewProviderByName(t *testing.T) {
var err error
var p auth.Provider
+57 -57
View File
@@ -9,147 +9,147 @@ import (
"golang.org/x/oauth2"
)
// baseProvider defines common fields and methods used by OAuth2 client providers.
type baseProvider struct {
// BaseProvider defines common fields and methods used by OAuth2 client providers.
type BaseProvider struct {
ctx context.Context
clientId string
clientSecret string
displayName string
redirectUrl string
authUrl string
tokenUrl string
userApiUrl string
redirectURL string
authURL string
tokenURL string
userInfoURL string
scopes []string
pkce bool
}
// Context implements Provider.Context() interface method.
func (p *baseProvider) Context() context.Context {
func (p *BaseProvider) Context() context.Context {
return p.ctx
}
// SetContext implements Provider.SetContext() interface method.
func (p *baseProvider) SetContext(ctx context.Context) {
func (p *BaseProvider) SetContext(ctx context.Context) {
p.ctx = ctx
}
// PKCE implements Provider.PKCE() interface method.
func (p *baseProvider) PKCE() bool {
func (p *BaseProvider) PKCE() bool {
return p.pkce
}
// SetPKCE implements Provider.SetPKCE() interface method.
func (p *baseProvider) SetPKCE(enable bool) {
func (p *BaseProvider) SetPKCE(enable bool) {
p.pkce = enable
}
// DisplayName implements Provider.DisplayName() interface method.
func (p *baseProvider) DisplayName() string {
func (p *BaseProvider) DisplayName() string {
return p.displayName
}
// SetDisplayName implements Provider.SetDisplayName() interface method.
func (p *baseProvider) SetDisplayName(displayName string) {
func (p *BaseProvider) SetDisplayName(displayName string) {
p.displayName = displayName
}
// Scopes implements Provider.Scopes() interface method.
func (p *baseProvider) Scopes() []string {
func (p *BaseProvider) Scopes() []string {
return p.scopes
}
// SetScopes implements Provider.SetScopes() interface method.
func (p *baseProvider) SetScopes(scopes []string) {
func (p *BaseProvider) SetScopes(scopes []string) {
p.scopes = scopes
}
// ClientId implements Provider.ClientId() interface method.
func (p *baseProvider) ClientId() string {
func (p *BaseProvider) ClientId() string {
return p.clientId
}
// SetClientId implements Provider.SetClientId() interface method.
func (p *baseProvider) SetClientId(clientId string) {
func (p *BaseProvider) SetClientId(clientId string) {
p.clientId = clientId
}
// ClientSecret implements Provider.ClientSecret() interface method.
func (p *baseProvider) ClientSecret() string {
func (p *BaseProvider) ClientSecret() string {
return p.clientSecret
}
// SetClientSecret implements Provider.SetClientSecret() interface method.
func (p *baseProvider) SetClientSecret(secret string) {
func (p *BaseProvider) SetClientSecret(secret string) {
p.clientSecret = secret
}
// RedirectUrl implements Provider.RedirectUrl() interface method.
func (p *baseProvider) RedirectUrl() string {
return p.redirectUrl
// RedirectURL implements Provider.RedirectURL() interface method.
func (p *BaseProvider) RedirectURL() string {
return p.redirectURL
}
// SetRedirectUrl implements Provider.SetRedirectUrl() interface method.
func (p *baseProvider) SetRedirectUrl(url string) {
p.redirectUrl = url
// SetRedirectURL implements Provider.SetRedirectURL() interface method.
func (p *BaseProvider) SetRedirectURL(url string) {
p.redirectURL = url
}
// AuthUrl implements Provider.AuthUrl() interface method.
func (p *baseProvider) AuthUrl() string {
return p.authUrl
// AuthURL implements Provider.AuthURL() interface method.
func (p *BaseProvider) AuthURL() string {
return p.authURL
}
// SetAuthUrl implements Provider.SetAuthUrl() interface method.
func (p *baseProvider) SetAuthUrl(url string) {
p.authUrl = url
// SetAuthURL implements Provider.SetAuthURL() interface method.
func (p *BaseProvider) SetAuthURL(url string) {
p.authURL = url
}
// TokenUrl implements Provider.TokenUrl() interface method.
func (p *baseProvider) TokenUrl() string {
return p.tokenUrl
// TokenURL implements Provider.TokenURL() interface method.
func (p *BaseProvider) TokenURL() string {
return p.tokenURL
}
// SetTokenUrl implements Provider.SetTokenUrl() interface method.
func (p *baseProvider) SetTokenUrl(url string) {
p.tokenUrl = url
// SetTokenURL implements Provider.SetTokenURL() interface method.
func (p *BaseProvider) SetTokenURL(url string) {
p.tokenURL = url
}
// UserApiUrl implements Provider.UserApiUrl() interface method.
func (p *baseProvider) UserApiUrl() string {
return p.userApiUrl
// UserInfoURL implements Provider.UserInfoURL() interface method.
func (p *BaseProvider) UserInfoURL() string {
return p.userInfoURL
}
// SetUserApiUrl implements Provider.SetUserApiUrl() interface method.
func (p *baseProvider) SetUserApiUrl(url string) {
p.userApiUrl = url
// SetUserInfoURL implements Provider.SetUserInfoURL() interface method.
func (p *BaseProvider) SetUserInfoURL(url string) {
p.userInfoURL = url
}
// BuildAuthUrl implements Provider.BuildAuthUrl() interface method.
func (p *baseProvider) BuildAuthUrl(state string, opts ...oauth2.AuthCodeOption) string {
// BuildAuthURL implements Provider.BuildAuthURL() interface method.
func (p *BaseProvider) BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string {
return p.oauth2Config().AuthCodeURL(state, opts...)
}
// FetchToken implements Provider.FetchToken() interface method.
func (p *baseProvider) FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
func (p *BaseProvider) FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.oauth2Config().Exchange(p.ctx, code, opts...)
}
// Client implements Provider.Client() interface method.
func (p *baseProvider) Client(token *oauth2.Token) *http.Client {
func (p *BaseProvider) Client(token *oauth2.Token) *http.Client {
return p.oauth2Config().Client(p.ctx, token)
}
// FetchRawUserData implements Provider.FetchRawUserData() interface method.
func (p *baseProvider) FetchRawUserData(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userApiUrl, nil)
// FetchRawUserInfo implements Provider.FetchRawUserInfo() interface method.
func (p *BaseProvider) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
if err != nil {
return nil, err
}
return p.sendRawUserDataRequest(req, token)
return p.sendRawUserInfoRequest(req, token)
}
// sendRawUserDataRequest sends the specified user data request and return its raw response body.
func (p *baseProvider) sendRawUserDataRequest(req *http.Request, token *oauth2.Token) ([]byte, error) {
// sendRawUserInfoRequest sends the specified user info request and return its raw response body.
func (p *BaseProvider) sendRawUserInfoRequest(req *http.Request, token *oauth2.Token) ([]byte, error) {
client := p.Client(token)
res, err := client.Do(req)
@@ -167,7 +167,7 @@ func (p *baseProvider) sendRawUserDataRequest(req *http.Request, token *oauth2.T
if res.StatusCode >= 400 {
return nil, fmt.Errorf(
"failed to fetch OAuth2 user profile via %s (%d):\n%s",
p.userApiUrl,
p.userInfoURL,
res.StatusCode,
string(result),
)
@@ -177,15 +177,15 @@ func (p *baseProvider) sendRawUserDataRequest(req *http.Request, token *oauth2.T
}
// oauth2Config constructs a oauth2.Config instance based on the provider settings.
func (p *baseProvider) oauth2Config() *oauth2.Config {
func (p *BaseProvider) oauth2Config() *oauth2.Config {
return &oauth2.Config{
RedirectURL: p.redirectUrl,
RedirectURL: p.redirectURL,
ClientID: p.clientId,
ClientSecret: p.clientSecret,
Scopes: p.scopes,
Endpoint: oauth2.Endpoint{
AuthURL: p.authUrl,
TokenURL: p.tokenUrl,
AuthURL: p.authURL,
TokenURL: p.tokenURL,
},
}
}
+52 -52
View File
@@ -8,7 +8,7 @@ import (
)
func TestContext(t *testing.T) {
b := baseProvider{}
b := BaseProvider{}
before := b.Scopes()
if before != nil {
@@ -24,7 +24,7 @@ func TestContext(t *testing.T) {
}
func TestDisplayName(t *testing.T) {
b := baseProvider{}
b := BaseProvider{}
before := b.DisplayName()
if before != "" {
@@ -40,7 +40,7 @@ func TestDisplayName(t *testing.T) {
}
func TestPKCE(t *testing.T) {
b := baseProvider{}
b := BaseProvider{}
before := b.PKCE()
if before != false {
@@ -56,7 +56,7 @@ func TestPKCE(t *testing.T) {
}
func TestScopes(t *testing.T) {
b := baseProvider{}
b := BaseProvider{}
before := b.Scopes()
if len(before) != 0 {
@@ -72,7 +72,7 @@ func TestScopes(t *testing.T) {
}
func TestClientId(t *testing.T) {
b := baseProvider{}
b := BaseProvider{}
before := b.ClientId()
if before != "" {
@@ -88,7 +88,7 @@ func TestClientId(t *testing.T) {
}
func TestClientSecret(t *testing.T) {
b := baseProvider{}
b := BaseProvider{}
before := b.ClientSecret()
if before != "" {
@@ -103,82 +103,82 @@ func TestClientSecret(t *testing.T) {
}
}
func TestRedirectUrl(t *testing.T) {
b := baseProvider{}
func TestRedirectURL(t *testing.T) {
b := BaseProvider{}
before := b.RedirectUrl()
before := b.RedirectURL()
if before != "" {
t.Fatalf("Expected RedirectUrl to be empty, got %v", before)
t.Fatalf("Expected RedirectURL to be empty, got %v", before)
}
b.SetRedirectUrl("test")
b.SetRedirectURL("test")
after := b.RedirectUrl()
after := b.RedirectURL()
if after != "test" {
t.Fatalf("Expected RedirectUrl to be 'test', got %v", after)
t.Fatalf("Expected RedirectURL to be 'test', got %v", after)
}
}
func TestAuthUrl(t *testing.T) {
b := baseProvider{}
func TestAuthURL(t *testing.T) {
b := BaseProvider{}
before := b.AuthUrl()
before := b.AuthURL()
if before != "" {
t.Fatalf("Expected authUrl to be empty, got %v", before)
t.Fatalf("Expected authURL to be empty, got %v", before)
}
b.SetAuthUrl("test")
b.SetAuthURL("test")
after := b.AuthUrl()
after := b.AuthURL()
if after != "test" {
t.Fatalf("Expected authUrl to be 'test', got %v", after)
t.Fatalf("Expected authURL to be 'test', got %v", after)
}
}
func TestTokenUrl(t *testing.T) {
b := baseProvider{}
func TestTokenURL(t *testing.T) {
b := BaseProvider{}
before := b.TokenUrl()
before := b.TokenURL()
if before != "" {
t.Fatalf("Expected tokenUrl to be empty, got %v", before)
t.Fatalf("Expected tokenURL to be empty, got %v", before)
}
b.SetTokenUrl("test")
b.SetTokenURL("test")
after := b.TokenUrl()
after := b.TokenURL()
if after != "test" {
t.Fatalf("Expected tokenUrl to be 'test', got %v", after)
t.Fatalf("Expected tokenURL to be 'test', got %v", after)
}
}
func TestUserApiUrl(t *testing.T) {
b := baseProvider{}
func TestUserInfoURL(t *testing.T) {
b := BaseProvider{}
before := b.UserApiUrl()
before := b.UserInfoURL()
if before != "" {
t.Fatalf("Expected userApiUrl to be empty, got %v", before)
t.Fatalf("Expected userInfoURL to be empty, got %v", before)
}
b.SetUserApiUrl("test")
b.SetUserInfoURL("test")
after := b.UserApiUrl()
after := b.UserInfoURL()
if after != "test" {
t.Fatalf("Expected userApiUrl to be 'test', got %v", after)
t.Fatalf("Expected userInfoURL to be 'test', got %v", after)
}
}
func TestBuildAuthUrl(t *testing.T) {
b := baseProvider{
authUrl: "authUrl_test",
tokenUrl: "tokenUrl_test",
redirectUrl: "redirectUrl_test",
func TestBuildAuthURL(t *testing.T) {
b := BaseProvider{
authURL: "authURL_test",
tokenURL: "tokenURL_test",
redirectURL: "redirectURL_test",
clientId: "clientId_test",
clientSecret: "clientSecret_test",
scopes: []string{"test_scope"},
}
expected := "authUrl_test?access_type=offline&client_id=clientId_test&prompt=consent&redirect_uri=redirectUrl_test&response_type=code&scope=test_scope&state=state_test"
result := b.BuildAuthUrl("state_test", oauth2.AccessTypeOffline, oauth2.ApprovalForce)
expected := "authURL_test?access_type=offline&client_id=clientId_test&prompt=consent&redirect_uri=redirectURL_test&response_type=code&scope=test_scope&state=state_test"
result := b.BuildAuthURL("state_test", oauth2.AccessTypeOffline, oauth2.ApprovalForce)
if result != expected {
t.Errorf("Expected auth url %q, got %q", expected, result)
@@ -186,7 +186,7 @@ func TestBuildAuthUrl(t *testing.T) {
}
func TestClient(t *testing.T) {
b := baseProvider{}
b := BaseProvider{}
result := b.Client(&oauth2.Token{})
if result == nil {
@@ -195,10 +195,10 @@ func TestClient(t *testing.T) {
}
func TestOauth2Config(t *testing.T) {
b := baseProvider{
authUrl: "authUrl_test",
tokenUrl: "tokenUrl_test",
redirectUrl: "redirectUrl_test",
b := BaseProvider{
authURL: "authURL_test",
tokenURL: "tokenURL_test",
redirectURL: "redirectURL_test",
clientId: "clientId_test",
clientSecret: "clientSecret_test",
scopes: []string{"test"},
@@ -206,8 +206,8 @@ func TestOauth2Config(t *testing.T) {
result := b.oauth2Config()
if result.RedirectURL != b.RedirectUrl() {
t.Errorf("Expected redirectUrl %s, got %s", b.RedirectUrl(), result.RedirectURL)
if result.RedirectURL != b.RedirectURL() {
t.Errorf("Expected redirectURL %s, got %s", b.RedirectURL(), result.RedirectURL)
}
if result.ClientID != b.ClientId() {
@@ -218,12 +218,12 @@ func TestOauth2Config(t *testing.T) {
t.Errorf("Expected clientSecret %s, got %s", b.ClientSecret(), result.ClientSecret)
}
if result.Endpoint.AuthURL != b.AuthUrl() {
t.Errorf("Expected authUrl %s, got %s", b.AuthUrl(), result.Endpoint.AuthURL)
if result.Endpoint.AuthURL != b.AuthURL() {
t.Errorf("Expected authURL %s, got %s", b.AuthURL(), result.Endpoint.AuthURL)
}
if result.Endpoint.TokenURL != b.TokenUrl() {
t.Errorf("Expected authUrl %s, got %s", b.TokenUrl(), result.Endpoint.TokenURL)
if result.Endpoint.TokenURL != b.TokenURL() {
t.Errorf("Expected authURL %s, got %s", b.TokenURL(), result.Endpoint.TokenURL)
}
if len(result.Scopes) != len(b.Scopes()) || result.Scopes[0] != b.Scopes()[0] {
+12 -8
View File
@@ -10,6 +10,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameBitbucket] = wrapFactory(NewBitbucketProvider)
}
var _ Provider = (*Bitbucket)(nil)
// NameBitbucket is the unique name of the Bitbucket provider.
@@ -17,19 +21,19 @@ const NameBitbucket = "bitbucket"
// Bitbucket is an auth provider for Bitbucket.
type Bitbucket struct {
*baseProvider
BaseProvider
}
// NewBitbucketProvider creates a new Bitbucket provider instance with some defaults.
func NewBitbucketProvider() *Bitbucket {
return &Bitbucket{&baseProvider{
return &Bitbucket{BaseProvider{
ctx: context.Background(),
displayName: "Bitbucket",
pkce: false,
scopes: []string{"account"},
authUrl: "https://bitbucket.org/site/oauth2/authorize",
tokenUrl: "https://bitbucket.org/site/oauth2/access_token",
userApiUrl: "https://api.bitbucket.org/2.0/user",
authURL: "https://bitbucket.org/site/oauth2/authorize",
tokenURL: "https://bitbucket.org/site/oauth2/access_token",
userInfoURL: "https://api.bitbucket.org/2.0/user",
}}
}
@@ -37,7 +41,7 @@ func NewBitbucketProvider() *Bitbucket {
//
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-get
func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -76,7 +80,7 @@ func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name: extracted.DisplayName,
Username: extracted.Username,
Email: email,
AvatarUrl: extracted.Links.Avatar.Href,
AvatarURL: extracted.Links.Avatar.Href,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
@@ -95,7 +99,7 @@ func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
//
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-emails-get
func (p *Bitbucket) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
response, err := p.Client(token).Get(p.userApiUrl + "/emails")
response, err := p.Client(token).Get(p.userInfoURL + "/emails")
if err != nil {
return "", err
}
+12 -8
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameDiscord] = wrapFactory(NewDiscordProvider)
}
var _ Provider = (*Discord)(nil)
// NameDiscord is the unique name of the Discord provider.
@@ -16,21 +20,21 @@ const NameDiscord string = "discord"
// Discord allows authentication via Discord OAuth2.
type Discord struct {
*baseProvider
BaseProvider
}
// NewDiscordProvider creates a new Discord provider instance with some defaults.
func NewDiscordProvider() *Discord {
// https://discord.com/developers/docs/topics/oauth2
// https://discord.com/developers/docs/resources/user#get-current-user
return &Discord{&baseProvider{
return &Discord{BaseProvider{
ctx: context.Background(),
displayName: "Discord",
pkce: true,
scopes: []string{"identify", "email"},
authUrl: "https://discord.com/api/oauth2/authorize",
tokenUrl: "https://discord.com/api/oauth2/token",
userApiUrl: "https://discord.com/api/users/@me",
authURL: "https://discord.com/api/oauth2/authorize",
tokenURL: "https://discord.com/api/oauth2/token",
userInfoURL: "https://discord.com/api/users/@me",
}}
}
@@ -38,7 +42,7 @@ func NewDiscordProvider() *Discord {
//
// API reference: https://discord.com/developers/docs/resources/user#user-object
func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -62,7 +66,7 @@ func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
// Build a full avatar URL using the avatar hash provided in the API response
// https://discord.com/developers/docs/reference#image-formatting
avatarUrl := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", extracted.Id, extracted.Avatar)
avatarURL := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", extracted.Id, extracted.Avatar)
// Concatenate the user's username and discriminator into a single username string
username := fmt.Sprintf("%s#%s", extracted.Username, extracted.Discriminator)
@@ -71,7 +75,7 @@ func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id: extracted.Id,
Name: username,
Username: extracted.Username,
AvatarUrl: avatarUrl,
AvatarURL: avatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+11 -7
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2/facebook"
)
func init() {
Providers[NameFacebook] = wrapFactory(NewFacebookProvider)
}
var _ Provider = (*Facebook)(nil)
// NameFacebook is the unique name of the Facebook provider.
@@ -16,19 +20,19 @@ const NameFacebook string = "facebook"
// Facebook allows authentication via Facebook OAuth2.
type Facebook struct {
*baseProvider
BaseProvider
}
// NewFacebookProvider creates new Facebook provider instance with some defaults.
func NewFacebookProvider() *Facebook {
return &Facebook{&baseProvider{
return &Facebook{BaseProvider{
ctx: context.Background(),
displayName: "Facebook",
pkce: true,
scopes: []string{"email"},
authUrl: facebook.Endpoint.AuthURL,
tokenUrl: facebook.Endpoint.TokenURL,
userApiUrl: "https://graph.facebook.com/me?fields=name,email,picture.type(large)",
authURL: facebook.Endpoint.AuthURL,
tokenURL: facebook.Endpoint.TokenURL,
userInfoURL: "https://graph.facebook.com/me?fields=name,email,picture.type(large)",
}}
}
@@ -36,7 +40,7 @@ func NewFacebookProvider() *Facebook {
//
// API reference: https://developers.facebook.com/docs/graph-api/reference/user/
func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -62,7 +66,7 @@ func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id: extracted.Id,
Name: extracted.Name,
Email: extracted.Email,
AvatarUrl: extracted.Picture.Data.Url,
AvatarURL: extracted.Picture.Data.Url,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+12 -8
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameGitea] = wrapFactory(NewGiteaProvider)
}
var _ Provider = (*Gitea)(nil)
// NameGitea is the unique name of the Gitea provider.
@@ -16,19 +20,19 @@ const NameGitea string = "gitea"
// Gitea allows authentication via Gitea OAuth2.
type Gitea struct {
*baseProvider
BaseProvider
}
// NewGiteaProvider creates new Gitea provider instance with some defaults.
func NewGiteaProvider() *Gitea {
return &Gitea{&baseProvider{
return &Gitea{BaseProvider{
ctx: context.Background(),
displayName: "Gitea",
pkce: true,
scopes: []string{"read:user", "user:email"},
authUrl: "https://gitea.com/login/oauth/authorize",
tokenUrl: "https://gitea.com/login/oauth/access_token",
userApiUrl: "https://gitea.com/api/v1/user",
authURL: "https://gitea.com/login/oauth/authorize",
tokenURL: "https://gitea.com/login/oauth/access_token",
userInfoURL: "https://gitea.com/api/v1/user",
}}
}
@@ -36,7 +40,7 @@ func NewGiteaProvider() *Gitea {
//
// API reference: https://try.gitea.io/api/swagger#/user/userGetCurrent
func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -51,7 +55,7 @@ func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name string `json:"full_name"`
Username string `json:"login"`
Email string `json:"email"`
AvatarUrl string `json:"avatar_url"`
AvatarURL string `json:"avatar_url"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
@@ -62,7 +66,7 @@ func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name: extracted.Name,
Username: extracted.Username,
Email: extracted.Email,
AvatarUrl: extracted.AvatarUrl,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+12 -8
View File
@@ -11,6 +11,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameGitee] = wrapFactory(NewGiteeProvider)
}
var _ Provider = (*Gitee)(nil)
// NameGitee is the unique name of the Gitee provider.
@@ -18,19 +22,19 @@ const NameGitee string = "gitee"
// Gitee allows authentication via Gitee OAuth2.
type Gitee struct {
*baseProvider
BaseProvider
}
// NewGiteeProvider creates new Gitee provider instance with some defaults.
func NewGiteeProvider() *Gitee {
return &Gitee{&baseProvider{
return &Gitee{BaseProvider{
ctx: context.Background(),
displayName: "Gitee",
pkce: true,
scopes: []string{"user_info", "emails"},
authUrl: "https://gitee.com/oauth/authorize",
tokenUrl: "https://gitee.com/oauth/token",
userApiUrl: "https://gitee.com/api/v5/user",
authURL: "https://gitee.com/oauth/authorize",
tokenURL: "https://gitee.com/oauth/token",
userInfoURL: "https://gitee.com/api/v5/user",
}}
}
@@ -38,7 +42,7 @@ func NewGiteeProvider() *Gitee {
//
// API reference: https://gitee.com/api/v5/swagger#/getV5User
func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -53,7 +57,7 @@ func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarUrl string `json:"avatar_url"`
AvatarURL string `json:"avatar_url"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
@@ -63,7 +67,7 @@ func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id: strconv.Itoa(extracted.Id),
Name: extracted.Name,
Username: extracted.Login,
AvatarUrl: extracted.AvatarUrl,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+13 -9
View File
@@ -11,6 +11,10 @@ import (
"golang.org/x/oauth2/github"
)
func init() {
Providers[NameGithub] = wrapFactory(NewGithubProvider)
}
var _ Provider = (*Github)(nil)
// NameGithub is the unique name of the Github provider.
@@ -18,19 +22,19 @@ const NameGithub string = "github"
// Github allows authentication via Github OAuth2.
type Github struct {
*baseProvider
BaseProvider
}
// NewGithubProvider creates new Github provider instance with some defaults.
func NewGithubProvider() *Github {
return &Github{&baseProvider{
return &Github{BaseProvider{
ctx: context.Background(),
displayName: "GitHub",
pkce: true, // technically is not supported yet but it is safe as the PKCE params are just ignored
scopes: []string{"read:user", "user:email"},
authUrl: github.Endpoint.AuthURL,
tokenUrl: github.Endpoint.TokenURL,
userApiUrl: "https://api.github.com/user",
authURL: github.Endpoint.AuthURL,
tokenURL: github.Endpoint.TokenURL,
userInfoURL: "https://api.github.com/user",
}}
}
@@ -38,7 +42,7 @@ func NewGithubProvider() *Github {
//
// API reference: https://docs.github.com/en/rest/reference/users#get-the-authenticated-user
func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -53,7 +57,7 @@ func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarUrl string `json:"avatar_url"`
AvatarURL string `json:"avatar_url"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
@@ -64,7 +68,7 @@ func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name: extracted.Name,
Username: extracted.Login,
Email: extracted.Email,
AvatarUrl: extracted.AvatarUrl,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
@@ -95,7 +99,7 @@ func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
func (p *Github) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
client := p.Client(token)
response, err := client.Get(p.userApiUrl + "/emails")
response, err := client.Get(p.userInfoURL + "/emails")
if err != nil {
return "", err
}
+12 -8
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameGitlab] = wrapFactory(NewGitlabProvider)
}
var _ Provider = (*Gitlab)(nil)
// NameGitlab is the unique name of the Gitlab provider.
@@ -16,19 +20,19 @@ const NameGitlab string = "gitlab"
// Gitlab allows authentication via Gitlab OAuth2.
type Gitlab struct {
*baseProvider
BaseProvider
}
// NewGitlabProvider creates new Gitlab provider instance with some defaults.
func NewGitlabProvider() *Gitlab {
return &Gitlab{&baseProvider{
return &Gitlab{BaseProvider{
ctx: context.Background(),
displayName: "GitLab",
pkce: true,
scopes: []string{"read_user"},
authUrl: "https://gitlab.com/oauth/authorize",
tokenUrl: "https://gitlab.com/oauth/token",
userApiUrl: "https://gitlab.com/api/v4/user",
authURL: "https://gitlab.com/oauth/authorize",
tokenURL: "https://gitlab.com/oauth/token",
userInfoURL: "https://gitlab.com/api/v4/user",
}}
}
@@ -36,7 +40,7 @@ func NewGitlabProvider() *Gitlab {
//
// API reference: https://docs.gitlab.com/ee/api/users.html#for-admin
func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -51,7 +55,7 @@ func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`
AvatarUrl string `json:"avatar_url"`
AvatarURL string `json:"avatar_url"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
@@ -62,7 +66,7 @@ func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name: extracted.Name,
Username: extracted.Username,
Email: extracted.Email,
AvatarUrl: extracted.AvatarUrl,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+11 -7
View File
@@ -8,6 +8,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameGoogle] = wrapFactory(NewGoogleProvider)
}
var _ Provider = (*Google)(nil)
// NameGoogle is the unique name of the Google provider.
@@ -15,12 +19,12 @@ const NameGoogle string = "google"
// Google allows authentication via Google OAuth2.
type Google struct {
*baseProvider
BaseProvider
}
// NewGoogleProvider creates new Google provider instance with some defaults.
func NewGoogleProvider() *Google {
return &Google{&baseProvider{
return &Google{BaseProvider{
ctx: context.Background(),
displayName: "Google",
pkce: true,
@@ -28,15 +32,15 @@ func NewGoogleProvider() *Google {
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/userinfo.email",
},
authUrl: "https://accounts.google.com/o/oauth2/auth",
tokenUrl: "https://accounts.google.com/o/oauth2/token",
userApiUrl: "https://www.googleapis.com/oauth2/v1/userinfo",
authURL: "https://accounts.google.com/o/oauth2/auth",
tokenURL: "https://accounts.google.com/o/oauth2/token",
userInfoURL: "https://www.googleapis.com/oauth2/v1/userinfo",
}}
}
// FetchAuthUser returns an AuthUser instance based the Google's user api.
func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -60,7 +64,7 @@ func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
AvatarUrl: extracted.Picture,
AvatarURL: extracted.Picture,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+10 -6
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2/instagram"
)
func init() {
Providers[NameInstagram] = wrapFactory(NewInstagramProvider)
}
var _ Provider = (*Instagram)(nil)
// NameInstagram is the unique name of the Instagram provider.
@@ -16,19 +20,19 @@ const NameInstagram string = "instagram"
// Instagram allows authentication via Instagram OAuth2.
type Instagram struct {
*baseProvider
BaseProvider
}
// NewInstagramProvider creates new Instagram provider instance with some defaults.
func NewInstagramProvider() *Instagram {
return &Instagram{&baseProvider{
return &Instagram{BaseProvider{
ctx: context.Background(),
displayName: "Instagram",
pkce: true,
scopes: []string{"user_profile"},
authUrl: instagram.Endpoint.AuthURL,
tokenUrl: instagram.Endpoint.TokenURL,
userApiUrl: "https://graph.instagram.com/me?fields=id,username,account_type",
authURL: instagram.Endpoint.AuthURL,
tokenURL: instagram.Endpoint.TokenURL,
userInfoURL: "https://graph.instagram.com/me?fields=id,username,account_type",
}}
}
@@ -36,7 +40,7 @@ func NewInstagramProvider() *Instagram {
//
// API reference: https://developers.facebook.com/docs/instagram-basic-display-api/reference/user#fields
func (p *Instagram) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
+12 -8
View File
@@ -10,6 +10,10 @@ import (
"golang.org/x/oauth2/kakao"
)
func init() {
Providers[NameKakao] = wrapFactory(NewKakaoProvider)
}
var _ Provider = (*Kakao)(nil)
// NameKakao is the unique name of the Kakao provider.
@@ -17,19 +21,19 @@ const NameKakao string = "kakao"
// Kakao allows authentication via Kakao OAuth2.
type Kakao struct {
*baseProvider
BaseProvider
}
// NewKakaoProvider creates a new Kakao provider instance with some defaults.
func NewKakaoProvider() *Kakao {
return &Kakao{&baseProvider{
return &Kakao{BaseProvider{
ctx: context.Background(),
displayName: "Kakao",
pkce: true,
scopes: []string{"account_email", "profile_nickname", "profile_image"},
authUrl: kakao.Endpoint.AuthURL,
tokenUrl: kakao.Endpoint.TokenURL,
userApiUrl: "https://kapi.kakao.com/v2/user/me",
authURL: kakao.Endpoint.AuthURL,
tokenURL: kakao.Endpoint.TokenURL,
userInfoURL: "https://kapi.kakao.com/v2/user/me",
}}
}
@@ -37,7 +41,7 @@ func NewKakaoProvider() *Kakao {
//
// API reference: https://developers.kakao.com/docs/latest/en/kakaologin/rest-api#req-user-info-response
func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -51,7 +55,7 @@ func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id int `json:"id"`
Profile struct {
Nickname string `json:"nickname"`
ImageUrl string `json:"profile_image"`
ImageURL string `json:"profile_image"`
} `json:"properties"`
KakaoAccount struct {
Email string `json:"email"`
@@ -66,7 +70,7 @@ func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
user := &AuthUser{
Id: strconv.Itoa(extracted.Id),
Username: extracted.Profile.Nickname,
AvatarUrl: extracted.Profile.ImageUrl,
AvatarURL: extracted.Profile.ImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+12 -8
View File
@@ -8,6 +8,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameLivechat] = wrapFactory(NewLivechatProvider)
}
var _ Provider = (*Livechat)(nil)
// NameLivechat is the unique name of the Livechat provider.
@@ -15,19 +19,19 @@ const NameLivechat = "livechat"
// Livechat allows authentication via Livechat OAuth2.
type Livechat struct {
*baseProvider
BaseProvider
}
// NewLivechatProvider creates new Livechat provider instance with some defaults.
func NewLivechatProvider() *Livechat {
return &Livechat{&baseProvider{
return &Livechat{BaseProvider{
ctx: context.Background(),
displayName: "LiveChat",
pkce: true,
scopes: []string{}, // default scopes are specified from the provider dashboard
authUrl: "https://accounts.livechat.com/",
tokenUrl: "https://accounts.livechat.com/token",
userApiUrl: "https://accounts.livechat.com/v2/accounts/me",
authURL: "https://accounts.livechat.com/",
tokenURL: "https://accounts.livechat.com/token",
userInfoURL: "https://accounts.livechat.com/v2/accounts/me",
}}
}
@@ -35,7 +39,7 @@ func NewLivechatProvider() *Livechat {
//
// API reference: https://developers.livechat.com/docs/authorization
func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -50,7 +54,7 @@ func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name string `json:"name"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
AvatarUrl string `json:"avatar_url"`
AvatarURL string `json:"avatar_url"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
@@ -59,7 +63,7 @@ func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
AvatarUrl: extracted.AvatarUrl,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+7 -3
View File
@@ -10,6 +10,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameMailcow] = wrapFactory(NewMailcowProvider)
}
var _ Provider = (*Mailcow)(nil)
// NameMailcow is the unique name of the mailcow provider.
@@ -17,12 +21,12 @@ const NameMailcow string = "mailcow"
// Mailcow allows authentication via mailcow OAuth2.
type Mailcow struct {
*baseProvider
BaseProvider
}
// NewMailcowProvider creates a new mailcow provider instance with some defaults.
func NewMailcowProvider() *Mailcow {
return &Mailcow{&baseProvider{
return &Mailcow{BaseProvider{
ctx: context.Background(),
displayName: "mailcow",
pkce: true,
@@ -34,7 +38,7 @@ func NewMailcowProvider() *Mailcow {
//
// API reference: https://github.com/mailcow/mailcow-dockerized/blob/master/data/web/oauth/profile.php
func (p *Mailcow) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
+10 -6
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2/microsoft"
)
func init() {
Providers[NameMicrosoft] = wrapFactory(NewMicrosoftProvider)
}
var _ Provider = (*Microsoft)(nil)
// NameMicrosoft is the unique name of the Microsoft provider.
@@ -16,20 +20,20 @@ const NameMicrosoft string = "microsoft"
// Microsoft allows authentication via AzureADEndpoint OAuth2.
type Microsoft struct {
*baseProvider
BaseProvider
}
// NewMicrosoftProvider creates new Microsoft AD provider instance with some defaults.
func NewMicrosoftProvider() *Microsoft {
endpoints := microsoft.AzureADEndpoint("")
return &Microsoft{&baseProvider{
return &Microsoft{BaseProvider{
ctx: context.Background(),
displayName: "Microsoft",
pkce: true,
scopes: []string{"User.Read"},
authUrl: endpoints.AuthURL,
tokenUrl: endpoints.TokenURL,
userApiUrl: "https://graph.microsoft.com/v1.0/me",
authURL: endpoints.AuthURL,
tokenURL: endpoints.TokenURL,
userInfoURL: "https://graph.microsoft.com/v1.0/me",
}}
}
@@ -38,7 +42,7 @@ func NewMicrosoftProvider() *Microsoft {
// API reference: https://learn.microsoft.com/en-us/azure/active-directory/develop/userinfo
// Graph explorer: https://developer.microsoft.com/en-us/graph/graph-explorer
func (p *Microsoft) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
+12 -4
View File
@@ -8,6 +8,12 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameOIDC] = wrapFactory(NewOIDCProvider)
Providers[NameOIDC+"2"] = wrapFactory(NewOIDCProvider)
Providers[NameOIDC+"3"] = wrapFactory(NewOIDCProvider)
}
var _ Provider = (*OIDC)(nil)
// NameOIDC is the unique name of the OpenID Connect (OIDC) provider.
@@ -15,12 +21,12 @@ const NameOIDC string = "oidc"
// OIDC allows authentication via OpenID Connect (OIDC) OAuth2 provider.
type OIDC struct {
*baseProvider
BaseProvider
}
// NewOIDCProvider creates new OpenID Connect (OIDC) provider instance with some defaults.
func NewOIDCProvider() *OIDC {
return &OIDC{&baseProvider{
return &OIDC{BaseProvider{
ctx: context.Background(),
displayName: "OIDC",
pkce: true,
@@ -35,8 +41,10 @@ func NewOIDCProvider() *OIDC {
// FetchAuthUser returns an AuthUser instance based the provider's user api.
//
// API reference: https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
//
// @todo consider adding support for reading the user data from the id_token.
func (p *OIDC) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -62,7 +70,7 @@ func (p *OIDC) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id: extracted.Id,
Name: extracted.Name,
Username: extracted.Username,
AvatarUrl: extracted.Picture,
AvatarURL: extracted.Picture,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+12 -8
View File
@@ -8,6 +8,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NamePatreon] = wrapFactory(NewPatreonProvider)
}
var _ Provider = (*Patreon)(nil)
// NamePatreon is the unique name of the Patreon provider.
@@ -15,19 +19,19 @@ const NamePatreon string = "patreon"
// Patreon allows authentication via Patreon OAuth2.
type Patreon struct {
*baseProvider
BaseProvider
}
// NewPatreonProvider creates new Patreon provider instance with some defaults.
func NewPatreonProvider() *Patreon {
return &Patreon{&baseProvider{
return &Patreon{BaseProvider{
ctx: context.Background(),
displayName: "Patreon",
pkce: true,
scopes: []string{"identity", "identity[email]"},
authUrl: "https://www.patreon.com/oauth2/authorize",
tokenUrl: "https://www.patreon.com/api/oauth2/token",
userApiUrl: "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=full_name,email,vanity,image_url,is_email_verified",
authURL: "https://www.patreon.com/oauth2/authorize",
tokenURL: "https://www.patreon.com/api/oauth2/token",
userInfoURL: "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=full_name,email,vanity,image_url,is_email_verified",
}}
}
@@ -37,7 +41,7 @@ func NewPatreonProvider() *Patreon {
// https://docs.patreon.com/#get-api-oauth2-v2-identity
// https://docs.patreon.com/#user-v2
func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -54,7 +58,7 @@ func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Email string `json:"email"`
Name string `json:"full_name"`
Username string `json:"vanity"`
AvatarUrl string `json:"image_url"`
AvatarURL string `json:"image_url"`
IsEmailVerified bool `json:"is_email_verified"`
} `json:"attributes"`
} `json:"data"`
@@ -67,7 +71,7 @@ func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id: extracted.Data.Id,
Username: extracted.Data.Attributes.Username,
Name: extracted.Data.Attributes.Name,
AvatarUrl: extracted.Data.Attributes.AvatarUrl,
AvatarURL: extracted.Data.Attributes.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+12 -8
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NamePlanningcenter] = wrapFactory(NewPlanningcenterProvider)
}
var _ Provider = (*Planningcenter)(nil)
// NamePlanningcenter is the unique name of the Planningcenter provider.
@@ -16,19 +20,19 @@ const NamePlanningcenter string = "planningcenter"
// Planningcenter allows authentication via Planningcenter OAuth2.
type Planningcenter struct {
*baseProvider
BaseProvider
}
// NewPlanningcenterProvider creates a new Planningcenter provider instance with some defaults.
func NewPlanningcenterProvider() *Planningcenter {
return &Planningcenter{&baseProvider{
return &Planningcenter{BaseProvider{
ctx: context.Background(),
displayName: "Planning Center",
pkce: true,
scopes: []string{"people"},
authUrl: "https://api.planningcenteronline.com/oauth/authorize",
tokenUrl: "https://api.planningcenteronline.com/oauth/token",
userApiUrl: "https://api.planningcenteronline.com/people/v2/me",
authURL: "https://api.planningcenteronline.com/oauth/authorize",
tokenURL: "https://api.planningcenteronline.com/oauth/token",
userInfoURL: "https://api.planningcenteronline.com/people/v2/me",
}}
}
@@ -36,7 +40,7 @@ func NewPlanningcenterProvider() *Planningcenter {
//
// API reference: https://developer.planning.center/docs/#/overview/authentication
func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -52,7 +56,7 @@ func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Attributes struct {
Status string `json:"status"`
Name string `json:"name"`
AvatarUrl string `json:"avatar"`
AvatarURL string `json:"avatar"`
// don't map the email because users can have multiple assigned
// and it's not clear if they are verified
}
@@ -69,7 +73,7 @@ func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
user := &AuthUser{
Id: extracted.Data.Id,
Name: extracted.Data.Attributes.Name,
AvatarUrl: extracted.Data.Attributes.AvatarUrl,
AvatarURL: extracted.Data.Attributes.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+12 -8
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2/spotify"
)
func init() {
Providers[NameSpotify] = wrapFactory(NewSpotifyProvider)
}
var _ Provider = (*Spotify)(nil)
// NameSpotify is the unique name of the Spotify provider.
@@ -16,12 +20,12 @@ const NameSpotify string = "spotify"
// Spotify allows authentication via Spotify OAuth2.
type Spotify struct {
*baseProvider
BaseProvider
}
// NewSpotifyProvider creates a new Spotify provider instance with some defaults.
func NewSpotifyProvider() *Spotify {
return &Spotify{&baseProvider{
return &Spotify{BaseProvider{
ctx: context.Background(),
displayName: "Spotify",
pkce: true,
@@ -30,9 +34,9 @@ func NewSpotifyProvider() *Spotify {
// currently Spotify doesn't return information whether the email is verified or not
// "user-read-email",
},
authUrl: spotify.Endpoint.AuthURL,
tokenUrl: spotify.Endpoint.TokenURL,
userApiUrl: "https://api.spotify.com/v1/me",
authURL: spotify.Endpoint.AuthURL,
tokenURL: spotify.Endpoint.TokenURL,
userInfoURL: "https://api.spotify.com/v1/me",
}}
}
@@ -40,7 +44,7 @@ func NewSpotifyProvider() *Spotify {
//
// API reference: https://developer.spotify.com/documentation/web-api/reference/#/operations/get-current-users-profile
func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -54,7 +58,7 @@ func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id string `json:"id"`
Name string `json:"display_name"`
Images []struct {
Url string `json:"url"`
URL string `json:"url"`
} `json:"images"`
// don't map the email because per the official docs
// the email field is "unverified" and there is no proof
@@ -76,7 +80,7 @@ func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if len(extracted.Images) > 0 {
user.AvatarUrl = extracted.Images[0].Url
user.AvatarURL = extracted.Images[0].URL
}
return user, nil
+12 -8
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameStrava] = wrapFactory(NewStravaProvider)
}
var _ Provider = (*Strava)(nil)
// NameStrava is the unique name of the Strava provider.
@@ -16,21 +20,21 @@ const NameStrava string = "strava"
// Strava allows authentication via Strava OAuth2.
type Strava struct {
*baseProvider
BaseProvider
}
// NewStravaProvider creates new Strava provider instance with some defaults.
func NewStravaProvider() *Strava {
return &Strava{&baseProvider{
return &Strava{BaseProvider{
ctx: context.Background(),
displayName: "Strava",
pkce: true,
scopes: []string{
"profile:read_all",
},
authUrl: "https://www.strava.com/oauth/authorize",
tokenUrl: "https://www.strava.com/api/v3/oauth/token",
userApiUrl: "https://www.strava.com/api/v3/athlete",
authURL: "https://www.strava.com/oauth/authorize",
tokenURL: "https://www.strava.com/api/v3/oauth/token",
userInfoURL: "https://www.strava.com/api/v3/athlete",
}}
}
@@ -38,7 +42,7 @@ func NewStravaProvider() *Strava {
//
// API reference: https://developers.strava.com/docs/authentication/
func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -53,7 +57,7 @@ func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
FirstName string `json:"firstname"`
LastName string `json:"lastname"`
Username string `json:"username"`
ProfileImageUrl string `json:"profile"`
ProfileImageURL string `json:"profile"`
// At the time of writing, Strava OAuth2 doesn't support returning the user email address
// Email string `json:"email"`
@@ -65,7 +69,7 @@ func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
user := &AuthUser{
Name: extracted.FirstName + " " + extracted.LastName,
Username: extracted.Username,
AvatarUrl: extracted.ProfileImageUrl,
AvatarURL: extracted.ProfileImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+16 -12
View File
@@ -11,6 +11,10 @@ import (
"golang.org/x/oauth2/twitch"
)
func init() {
Providers[NameTwitch] = wrapFactory(NewTwitchProvider)
}
var _ Provider = (*Twitch)(nil)
// NameTwitch is the unique name of the Twitch provider.
@@ -18,19 +22,19 @@ const NameTwitch string = "twitch"
// Twitch allows authentication via Twitch OAuth2.
type Twitch struct {
*baseProvider
BaseProvider
}
// NewTwitchProvider creates new Twitch provider instance with some defaults.
func NewTwitchProvider() *Twitch {
return &Twitch{&baseProvider{
return &Twitch{BaseProvider{
ctx: context.Background(),
displayName: "Twitch",
pkce: true,
scopes: []string{"user:read:email"},
authUrl: twitch.Endpoint.AuthURL,
tokenUrl: twitch.Endpoint.TokenURL,
userApiUrl: "https://api.twitch.tv/helix/users",
authURL: twitch.Endpoint.AuthURL,
tokenURL: twitch.Endpoint.TokenURL,
userInfoURL: "https://api.twitch.tv/helix/users",
}}
}
@@ -38,7 +42,7 @@ func NewTwitchProvider() *Twitch {
//
// API reference: https://dev.twitch.tv/docs/api/reference#get-users
func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -54,7 +58,7 @@ func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Login string `json:"login"`
DisplayName string `json:"display_name"`
Email string `json:"email"`
ProfileImageUrl string `json:"profile_image_url"`
ProfileImageURL string `json:"profile_image_url"`
} `json:"data"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
@@ -70,7 +74,7 @@ func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Name: extracted.Data[0].DisplayName,
Username: extracted.Data[0].Login,
Email: extracted.Data[0].Email,
AvatarUrl: extracted.Data[0].ProfileImageUrl,
AvatarURL: extracted.Data[0].ProfileImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
@@ -81,16 +85,16 @@ func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
return user, nil
}
// FetchRawUserData implements Provider.FetchRawUserData interface.
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface.
//
// This differ from baseProvider because Twitch requires the `Client-Id` header.
func (p *Twitch) FetchRawUserData(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequest("GET", p.userApiUrl, nil)
func (p *Twitch) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequest("GET", p.userInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Client-Id", p.clientId)
return p.sendRawUserDataRequest(req, token)
return p.sendRawUserInfoRequest(req, token)
}
+12 -8
View File
@@ -8,6 +8,10 @@ import (
"golang.org/x/oauth2"
)
func init() {
Providers[NameTwitter] = wrapFactory(NewTwitterProvider)
}
var _ Provider = (*Twitter)(nil)
// NameTwitter is the unique name of the Twitter provider.
@@ -15,12 +19,12 @@ const NameTwitter string = "twitter"
// Twitter allows authentication via Twitter OAuth2.
type Twitter struct {
*baseProvider
BaseProvider
}
// NewTwitterProvider creates new Twitter provider instance with some defaults.
func NewTwitterProvider() *Twitter {
return &Twitter{&baseProvider{
return &Twitter{BaseProvider{
ctx: context.Background(),
displayName: "Twitter",
pkce: true,
@@ -31,9 +35,9 @@ func NewTwitterProvider() *Twitter {
// (see https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me)
"tweet.read",
},
authUrl: "https://twitter.com/i/oauth2/authorize",
tokenUrl: "https://api.twitter.com/2/oauth2/token",
userApiUrl: "https://api.twitter.com/2/users/me?user.fields=id,name,username,profile_image_url",
authURL: "https://twitter.com/i/oauth2/authorize",
tokenURL: "https://api.twitter.com/2/oauth2/token",
userInfoURL: "https://api.twitter.com/2/users/me?user.fields=id,name,username,profile_image_url",
}}
}
@@ -41,7 +45,7 @@ func NewTwitterProvider() *Twitter {
//
// API reference: https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me
func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -56,7 +60,7 @@ func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id string `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
ProfileImageUrl string `json:"profile_image_url"`
ProfileImageURL string `json:"profile_image_url"`
// NB! At the time of writing, Twitter OAuth2 doesn't support returning the user email address
// (see https://twittercommunity.com/t/which-api-to-get-user-after-oauth2-authorization/162417/33)
@@ -71,7 +75,7 @@ func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id: extracted.Data.Id,
Name: extracted.Data.Name,
Username: extracted.Data.Username,
AvatarUrl: extracted.Data.ProfileImageUrl,
AvatarURL: extracted.Data.ProfileImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+12 -8
View File
@@ -13,6 +13,10 @@ import (
"golang.org/x/oauth2/vk"
)
func init() {
Providers[NameVK] = wrapFactory(NewVKProvider)
}
var _ Provider = (*VK)(nil)
// NameVK is the unique name of the VK provider.
@@ -20,21 +24,21 @@ const NameVK string = "vk"
// VK allows authentication via VK OAuth2.
type VK struct {
*baseProvider
BaseProvider
}
// NewVKProvider creates new VK provider instance with some defaults.
//
// Docs: https://dev.vk.com/api/oauth-parameters
func NewVKProvider() *VK {
return &VK{&baseProvider{
return &VK{BaseProvider{
ctx: context.Background(),
displayName: "ВКонтакте",
pkce: false, // VK currently doesn't support PKCE and throws an error if PKCE params are send
scopes: []string{"email"},
authUrl: vk.Endpoint.AuthURL,
tokenUrl: vk.Endpoint.TokenURL,
userApiUrl: "https://api.vk.com/method/users.get?fields=photo_max,screen_name&v=5.131",
authURL: vk.Endpoint.AuthURL,
tokenURL: vk.Endpoint.TokenURL,
userInfoURL: "https://api.vk.com/method/users.get?fields=photo_max,screen_name&v=5.131",
}}
}
@@ -42,7 +46,7 @@ func NewVKProvider() *VK {
//
// API reference: https://dev.vk.com/method/users.get
func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -58,7 +62,7 @@ func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
Username string `json:"screen_name"`
AvatarUrl string `json:"photo_max"`
AvatarURL string `json:"photo_max"`
} `json:"response"`
}{}
@@ -74,7 +78,7 @@ func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
Id: strconv.Itoa(extracted.Response[0].Id),
Name: strings.TrimSpace(extracted.Response[0].FirstName + " " + extracted.Response[0].LastName),
Username: extracted.Response[0].Username,
AvatarUrl: extracted.Response[0].AvatarUrl,
AvatarURL: extracted.Response[0].AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
+11 -7
View File
@@ -9,6 +9,10 @@ import (
"golang.org/x/oauth2/yandex"
)
func init() {
Providers[NameYandex] = wrapFactory(NewYandexProvider)
}
var _ Provider = (*Yandex)(nil)
// NameYandex is the unique name of the Yandex provider.
@@ -16,21 +20,21 @@ const NameYandex string = "yandex"
// Yandex allows authentication via Yandex OAuth2.
type Yandex struct {
*baseProvider
BaseProvider
}
// NewYandexProvider creates new Yandex provider instance with some defaults.
//
// Docs: https://yandex.ru/dev/id/doc/en/
func NewYandexProvider() *Yandex {
return &Yandex{&baseProvider{
return &Yandex{BaseProvider{
ctx: context.Background(),
displayName: "Yandex",
pkce: true,
scopes: []string{"login:email", "login:avatar", "login:info"},
authUrl: yandex.Endpoint.AuthURL,
tokenUrl: yandex.Endpoint.TokenURL,
userApiUrl: "https://login.yandex.ru/info",
authURL: yandex.Endpoint.AuthURL,
tokenURL: yandex.Endpoint.TokenURL,
userInfoURL: "https://login.yandex.ru/info",
}}
}
@@ -38,7 +42,7 @@ func NewYandexProvider() *Yandex {
//
// API reference: https://yandex.ru/dev/id/doc/en/user-information#response-format
func (p *Yandex) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserData(token)
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
@@ -73,7 +77,7 @@ func (p *Yandex) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if !extracted.IsAvatarEmpty {
user.AvatarUrl = "https://avatars.yandex.net/get-yapic/" + extracted.AvatarId + "/islands-200"
user.AvatarURL = "https://avatars.yandex.net/get-yapic/" + extracted.AvatarId + "/islands-200"
}
return user, nil
+24 -25
View File
@@ -26,10 +26,9 @@ type Cron struct {
ticker *time.Ticker
startTimer *time.Timer
jobs map[string]*job
interval time.Duration
tickerDone chan bool
sync.RWMutex
interval time.Duration
mux sync.RWMutex
}
// New create a new Cron struct with default tick interval of 1 minute
@@ -50,10 +49,10 @@ func New() *Cron {
// (it usually should be >= 1 minute).
func (c *Cron) SetInterval(d time.Duration) {
// update interval
c.Lock()
c.mux.Lock()
wasStarted := c.ticker != nil
c.interval = d
c.Unlock()
c.mux.Unlock()
// restart the ticker
if wasStarted {
@@ -63,8 +62,8 @@ func (c *Cron) SetInterval(d time.Duration) {
// SetTimezone changes the current cron tick timezone.
func (c *Cron) SetTimezone(l *time.Location) {
c.Lock()
defer c.Unlock()
c.mux.Lock()
defer c.mux.Unlock()
c.timezone = l
}
@@ -88,8 +87,8 @@ func (c *Cron) Add(jobId string, cronExpr string, run func()) error {
return errors.New("failed to add new cron job: run must be non-nil function")
}
c.Lock()
defer c.Unlock()
c.mux.Lock()
defer c.mux.Unlock()
schedule, err := NewSchedule(cronExpr)
if err != nil {
@@ -106,24 +105,24 @@ func (c *Cron) Add(jobId string, cronExpr string, run func()) error {
// Remove removes a single cron job by its id.
func (c *Cron) Remove(jobId string) {
c.Lock()
defer c.Unlock()
c.mux.Lock()
defer c.mux.Unlock()
delete(c.jobs, jobId)
}
// RemoveAll removes all registered cron jobs.
func (c *Cron) RemoveAll() {
c.Lock()
defer c.Unlock()
c.mux.Lock()
defer c.mux.Unlock()
c.jobs = map[string]*job{}
}
// Total returns the current total number of registered cron jobs.
func (c *Cron) Total() int {
c.RLock()
defer c.RUnlock()
c.mux.RLock()
defer c.mux.RUnlock()
return len(c.jobs)
}
@@ -132,8 +131,8 @@ func (c *Cron) Total() int {
//
// You can resume the ticker by calling Start().
func (c *Cron) Stop() {
c.Lock()
defer c.Unlock()
c.mux.Lock()
defer c.mux.Unlock()
if c.startTimer != nil {
c.startTimer.Stop()
@@ -160,11 +159,11 @@ func (c *Cron) Start() {
next := now.Add(c.interval).Truncate(c.interval)
delay := next.Sub(now)
c.Lock()
c.mux.Lock()
c.startTimer = time.AfterFunc(delay, func() {
c.Lock()
c.mux.Lock()
c.ticker = time.NewTicker(c.interval)
c.Unlock()
c.mux.Unlock()
// run immediately at 00
c.runDue(time.Now())
@@ -181,21 +180,21 @@ func (c *Cron) Start() {
}
}()
})
c.Unlock()
c.mux.Unlock()
}
// HasStarted checks whether the current Cron ticker has been started.
func (c *Cron) HasStarted() bool {
c.RLock()
defer c.RUnlock()
c.mux.RLock()
defer c.mux.RUnlock()
return c.ticker != nil
}
// runDue runs all registered jobs that are scheduled for the provided time.
func (c *Cron) runDue(t time.Time) {
c.RLock()
defer c.RUnlock()
c.mux.RLock()
defer c.mux.RUnlock()
moment := NewMoment(t.In(c.timezone))
+29 -24
View File
@@ -2,6 +2,7 @@ package cron_test
import (
"encoding/json"
"fmt"
"testing"
"time"
@@ -252,26 +253,28 @@ func TestNewSchedule(t *testing.T) {
}
for _, s := range scenarios {
schedule, err := cron.NewSchedule(s.cronExpr)
t.Run(s.cronExpr, func(t *testing.T) {
schedule, err := cron.NewSchedule(s.cronExpr)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("[%s] Expected hasErr to be %v, got %v (%v)", s.cronExpr, s.expectError, hasErr, err)
}
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
continue
}
if hasErr {
return
}
encoded, err := json.Marshal(schedule)
if err != nil {
t.Fatalf("[%s] Failed to marshalize the result schedule: %v", s.cronExpr, err)
}
encodedStr := string(encoded)
encoded, err := json.Marshal(schedule)
if err != nil {
t.Fatalf("Failed to marshalize the result schedule: %v", err)
}
encodedStr := string(encoded)
if encodedStr != s.expectSchedule {
t.Fatalf("[%s] Expected \n%s, \ngot \n%s", s.cronExpr, s.expectSchedule, encodedStr)
}
if encodedStr != s.expectSchedule {
t.Fatalf("Expected \n%s, \ngot \n%s", s.expectSchedule, encodedStr)
}
})
}
}
@@ -390,15 +393,17 @@ func TestScheduleIsDue(t *testing.T) {
}
for i, s := range scenarios {
schedule, err := cron.NewSchedule(s.cronExpr)
if err != nil {
t.Fatalf("[%d-%s] Unexpected cron error: %v", i, s.cronExpr, err)
}
t.Run(fmt.Sprintf("%d-%s", i, s.cronExpr), func(t *testing.T) {
schedule, err := cron.NewSchedule(s.cronExpr)
if err != nil {
t.Fatalf("Unexpected cron error: %v", err)
}
result := schedule.IsDue(s.moment)
result := schedule.IsDue(s.moment)
if result != s.expected {
t.Fatalf("[%d-%s] Expected %v, got %v", i, s.cronExpr, s.expected, result)
}
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
+6 -6
View File
@@ -5,31 +5,31 @@ import (
"strings"
)
// JsonEach returns JSON_EACH SQLite string expression with
// JSONEach returns JSON_EACH SQLite string expression with
// some normalizations for non-json columns.
func JsonEach(column string) string {
func JSONEach(column string) string {
return fmt.Sprintf(
`json_each(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE json_array([[%s]]) END)`,
column, column, column,
)
}
// JsonArrayLength returns JSON_ARRAY_LENGTH SQLite string expression
// JSONArrayLength returns JSON_ARRAY_LENGTH SQLite string expression
// with some normalizations for non-json columns.
//
// It works with both json and non-json column values.
//
// Returns 0 for empty string or NULL column values.
func JsonArrayLength(column string) string {
func JSONArrayLength(column string) string {
return fmt.Sprintf(
`json_array_length(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE (CASE WHEN [[%s]] = '' OR [[%s]] IS NULL THEN json_array() ELSE json_array([[%s]]) END) END)`,
column, column, column, column, column,
)
}
// JsonExtract returns a JSON_EXTRACT SQLite string expression with
// JSONExtract returns a JSON_EXTRACT SQLite string expression with
// some normalizations for non-json columns.
func JsonExtract(column string, path string) string {
func JSONExtract(column string, path string) string {
// prefix the path with dot if it is not starting with array notation
if path != "" && !strings.HasPrefix(path, "[") {
path = "." + path
+6 -7
View File
@@ -6,8 +6,8 @@ import (
"github.com/pocketbase/pocketbase/tools/dbutils"
)
func TestJsonEach(t *testing.T) {
result := dbutils.JsonEach("a.b")
func TestJSONEach(t *testing.T) {
result := dbutils.JSONEach("a.b")
expected := "json_each(CASE WHEN json_valid([[a.b]]) THEN [[a.b]] ELSE json_array([[a.b]]) END)"
@@ -16,8 +16,8 @@ func TestJsonEach(t *testing.T) {
}
}
func TestJsonArrayLength(t *testing.T) {
result := dbutils.JsonArrayLength("a.b")
func TestJSONArrayLength(t *testing.T) {
result := dbutils.JSONArrayLength("a.b")
expected := "json_array_length(CASE WHEN json_valid([[a.b]]) THEN [[a.b]] ELSE (CASE WHEN [[a.b]] = '' OR [[a.b]] IS NULL THEN json_array() ELSE json_array([[a.b]]) END) END)"
@@ -26,7 +26,7 @@ func TestJsonArrayLength(t *testing.T) {
}
}
func TestJsonExtract(t *testing.T) {
func TestJSONExtract(t *testing.T) {
scenarios := []struct {
name string
column string
@@ -55,12 +55,11 @@ func TestJsonExtract(t *testing.T) {
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := dbutils.JsonExtract(s.column, s.path)
result := dbutils.JSONExtract(s.column, s.path)
if result != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, result)
}
})
}
}
+21 -9
View File
@@ -27,10 +27,20 @@ type FileReader interface {
//
// The file could be from a local path, multipart/form-data header, etc.
type File struct {
Reader FileReader
Name string
OriginalName string
Size int64
Reader FileReader `form:"-" json:"-" xml:"-"`
Name string `form:"name" json:"name" xml:"name"`
OriginalName string `form:"originalName" json:"originalName" xml:"originalName"`
Size int64 `form:"size" json:"size" xml:"size"`
}
// AsMap implements [core.mapExtractor] and returns a value suitable
// to be used in an API rule expression.
func (f *File) AsMap() map[string]any {
return map[string]any{
"name": f.Name,
"originalName": f.OriginalName,
"size": f.Size,
}
}
// NewFileFromPath creates a new File instance from the provided local file path.
@@ -79,7 +89,7 @@ func NewFileFromMultipart(mh *multipart.FileHeader) (*File, error) {
return f, nil
}
// NewFileFromUrl creates a new File from the provided url by
// NewFileFromURL creates a new File from the provided url by
// downloading the resource and load it as BytesReader.
//
// Example
@@ -87,8 +97,8 @@ func NewFileFromMultipart(mh *multipart.FileHeader) (*File, error) {
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
//
// file, err := filesystem.NewFileFromUrl(ctx, "https://example.com/image.png")
func NewFileFromUrl(ctx context.Context, url string) (*File, error) {
// file, err := filesystem.NewFileFromURL(ctx, "https://example.com/image.png")
func NewFileFromURL(ctx context.Context, url string) (*File, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
@@ -168,6 +178,8 @@ func (r *bytesReadSeekCloser) Close() error {
var extInvalidCharsRegex = regexp.MustCompile(`[^\w\.\*\-\+\=\#]+`)
const randomAlphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
func normalizeName(fr FileReader, name string) string {
// extension
// ---
@@ -187,7 +199,7 @@ func normalizeName(fr FileReader, name string) string {
cleanName := inflector.Snakecase(strings.TrimSuffix(name, originalExt))
if length := len(cleanName); length < 3 {
// the name is too short so we concatenate an additional random part
cleanName += security.RandomString(10)
cleanName += security.RandomStringWithAlphabet(10, randomAlphabet)
} else if length > 100 {
// keep only the first 100 characters (it is multibyte safe after Snakecase)
cleanName = cleanName[:100]
@@ -196,7 +208,7 @@ func normalizeName(fr FileReader, name string) string {
return fmt.Sprintf(
"%s_%s%s",
cleanName,
security.RandomString(10), // ensure that there is always a random part
security.RandomStringWithAlphabet(10, randomAlphabet), // ensure that there is always a random part
cleanExt,
)
}
+30 -6
View File
@@ -12,11 +12,35 @@ import (
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/filesystem"
)
func TestFileAsMap(t *testing.T) {
file, err := filesystem.NewFileFromBytes([]byte("test"), "test123.txt")
if err != nil {
t.Fatal(err)
}
result := file.AsMap()
if len(result) != 3 {
t.Fatalf("Expected map with %d keys, got\n%v", 3, result)
}
if result["size"] != int64(4) {
t.Fatalf("Expected size %d, got %#v", 4, result["size"])
}
if str, ok := result["name"].(string); !ok || !strings.HasPrefix(str, "test123") {
t.Fatalf("Expected name to have prefix %q, got %#v", "test123", result["name"])
}
if result["originalName"] != "test123.txt" {
t.Fatalf("Expected originalName %q, got %#v", "test123.txt", result["originalName"])
}
}
func TestNewFileFromPath(t *testing.T) {
testDir := createTestDir(t)
defer os.RemoveAll(testDir)
@@ -83,7 +107,7 @@ func TestNewFileFromMultipart(t *testing.T) {
}
req := httptest.NewRequest("", "/", formData)
req.Header.Set(echo.HeaderContentType, mp.FormDataContentType())
req.Header.Set("Content-Type", mp.FormDataContentType())
req.ParseMultipartForm(32 << 20)
_, mh, err := req.FormFile("test")
@@ -115,7 +139,7 @@ func TestNewFileFromMultipart(t *testing.T) {
}
}
func TestNewFileFromUrlTimeout(t *testing.T) {
func TestNewFileFromURLTimeout(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/error" {
w.WriteHeader(http.StatusInternalServerError)
@@ -129,7 +153,7 @@ func TestNewFileFromUrlTimeout(t *testing.T) {
{
ctx, cancel := context.WithCancel(context.Background())
cancel()
f, err := filesystem.NewFileFromUrl(ctx, srv.URL+"/cancel")
f, err := filesystem.NewFileFromURL(ctx, srv.URL+"/cancel")
if err == nil {
t.Fatal("[ctx_cancel] Expected error, got nil")
}
@@ -140,7 +164,7 @@ func TestNewFileFromUrlTimeout(t *testing.T) {
// error response
{
f, err := filesystem.NewFileFromUrl(context.Background(), srv.URL+"/error")
f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/error")
if err == nil {
t.Fatal("[error_status] Expected error, got nil")
}
@@ -154,7 +178,7 @@ func TestNewFileFromUrlTimeout(t *testing.T) {
originalName := "image_! noext"
normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".txt")
f, err := filesystem.NewFileFromUrl(context.Background(), srv.URL+"/"+originalName)
f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/"+originalName)
if err != nil {
t.Fatalf("[valid] Unexpected error %v", err)
}
+84 -25
View File
@@ -20,13 +20,17 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/disintegration/imaging"
"github.com/gabriel-vasile/mimetype"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3lite"
"github.com/pocketbase/pocketbase/tools/list"
"gocloud.dev/blob"
"gocloud.dev/blob/fileblob"
"gocloud.dev/gcerrors"
)
var gcpIgnoreHeaders = []string{"Accept-Encoding"}
var ErrNotFound = errors.New("blob not found")
type System struct {
ctx context.Context
bucket *blob.Bucket
@@ -47,25 +51,23 @@ func NewS3(
cred := credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
cfg, err := config.LoadDefaultConfig(ctx,
cfg, err := config.LoadDefaultConfig(
ctx,
config.WithCredentialsProvider(cred),
config.WithRegion(region),
config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
// ensure that the endpoint has url scheme for
// backward compatibility with v1 of the aws sdk
prefixedEndpoint := endpoint
if !strings.Contains(endpoint, "://") {
prefixedEndpoint = "https://" + endpoint
}
return aws.Endpoint{URL: prefixedEndpoint, SigningRegion: region}, nil
})),
)
if err != nil {
return nil, err
}
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
// ensure that the endpoint has url scheme for
// backward compatibility with v1 of the aws sdk
if !strings.Contains(endpoint, "://") {
endpoint = "https://" + endpoint
}
o.BaseEndpoint = aws.String(endpoint)
o.UsePathStyle = s3ForcePathStyle
// Google Cloud Storage alters the Accept-Encoding header,
@@ -76,7 +78,7 @@ func NewS3(
}
})
bucket, err := OpenBucketV2(ctx, client, bucketName, nil)
bucket, err := s3lite.OpenBucketV2(ctx, client, bucketName, nil)
if err != nil {
return nil, err
}
@@ -116,32 +118,59 @@ func (s *System) Close() error {
}
// Exists checks if file with fileKey path exists or not.
//
// If the file doesn't exist returns false and ErrNotFound.
func (s *System) Exists(fileKey string) (bool, error) {
return s.bucket.Exists(s.ctx, fileKey)
exists, err := s.bucket.Exists(s.ctx, fileKey)
if gcerrors.Code(err) == gcerrors.NotFound {
err = ErrNotFound
}
return exists, err
}
// Attributes returns the attributes for the file with fileKey path.
//
// If the file doesn't exist it returns ErrNotFound.
func (s *System) Attributes(fileKey string) (*blob.Attributes, error) {
return s.bucket.Attributes(s.ctx, fileKey)
attrs, err := s.bucket.Attributes(s.ctx, fileKey)
if gcerrors.Code(err) == gcerrors.NotFound {
err = ErrNotFound
}
return attrs, err
}
// GetFile returns a file content reader for the given fileKey.
//
// NB! Make sure to call `Close()` after you are done working with it.
// NB! Make sure to call Close() on the file after you are done working with it.
//
// If the file doesn't exist returns ErrNotFound.
func (s *System) GetFile(fileKey string) (*blob.Reader, error) {
br, err := s.bucket.NewReader(s.ctx, fileKey, nil)
if err != nil {
return nil, err
if gcerrors.Code(err) == gcerrors.NotFound {
err = ErrNotFound
}
return br, nil
return br, err
}
// Copy copies the file stored at srcKey to dstKey.
//
// If srcKey file doesn't exist, it returns ErrNotFound.
//
// If dstKey file already exists, it is overwritten.
func (s *System) Copy(srcKey, dstKey string) error {
return s.bucket.Copy(s.ctx, dstKey, srcKey, nil)
err := s.bucket.Copy(s.ctx, dstKey, srcKey, nil)
if gcerrors.Code(err) == gcerrors.NotFound {
err = ErrNotFound
}
return err
}
// List returns a flat list with info for all files under the specified prefix.
@@ -178,14 +207,13 @@ func (s *System) Upload(content []byte, fileKey string) error {
}
if _, err := w.Write(content); err != nil {
w.Close()
return err
return errors.Join(err, w.Close())
}
return w.Close()
}
// UploadFile uploads the provided multipart file to the fileKey location.
// UploadFile uploads the provided File to the fileKey location.
func (s *System) UploadFile(file *File, fileKey string) error {
f, err := file.Reader.Open()
if err != nil {
@@ -270,8 +298,16 @@ func (s *System) UploadMultipart(fh *multipart.FileHeader, fileKey string) error
}
// Delete deletes stored file at fileKey location.
//
// If the file doesn't exist returns ErrNotFound.
func (s *System) Delete(fileKey string) error {
return s.bucket.Delete(s.ctx, fileKey)
err := s.bucket.Delete(s.ctx, fileKey)
if gcerrors.Code(err) == gcerrors.NotFound {
return ErrNotFound
}
return err
}
// DeletePrefix deletes everything starting with the specified prefix.
@@ -345,6 +381,26 @@ func (s *System) DeletePrefix(prefix string) []error {
return failed
}
// Checks if the provided dir prefix doesn't have any files.
//
// A trailing slash will be appended to a non-empty dir string argument
// to ensure that the checked prefix is a "directory".
//
// Returns "false" in case the has at least one file, otherwise - "true".
func (s *System) IsEmptyDir(dir string) bool {
if dir != "" && !strings.HasSuffix(dir, "/") {
dir += "/"
}
iter := s.bucket.List(&blob.ListOptions{
Prefix: dir,
})
_, err := iter.Next(s.ctx)
return err == io.EOF
}
var inlineServeContentTypes = []string{
// image
"image/png", "image/jpg", "image/jpeg", "image/gif", "image/webp", "image/x-icon", "image/bmp",
@@ -371,8 +427,11 @@ const forceAttachmentParam = "download"
//
// If the `download` query parameter is used the file will be always served for
// download no matter of its type (aka. with "Content-Disposition: attachment").
//
// Internally this method uses [http.ServeContent] so Range requests,
// If-Match, If-Unmodified-Since, etc. headers are handled transparently.
func (s *System) Serve(res http.ResponseWriter, req *http.Request, fileKey string, name string) error {
br, readErr := s.bucket.NewReader(s.ctx, fileKey, nil)
br, readErr := s.GetFile(fileKey)
if readErr != nil {
return readErr
}
@@ -444,7 +503,7 @@ func (s *System) CreateThumb(originalKey string, thumbKey, thumbSize string) err
}
// fetch the original
r, readErr := s.bucket.NewReader(s.ctx, originalKey, nil)
r, readErr := s.GetFile(originalKey)
if readErr != nil {
return readErr
}
+182 -111
View File
@@ -2,6 +2,7 @@ package filesystem_test
import (
"bytes"
"errors"
"image"
"image/png"
"mime/multipart"
@@ -19,11 +20,11 @@ func TestFileSystemExists(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
scenarios := []struct {
file string
@@ -35,12 +36,18 @@ func TestFileSystemExists(t *testing.T) {
{"image.png", true},
}
for i, scenario := range scenarios {
exists, _ := fs.Exists(scenario.file)
for _, s := range scenarios {
t.Run(s.file, func(t *testing.T) {
exists, err := fsys.Exists(s.file)
if exists != scenario.exists {
t.Errorf("(%d) Expected %v, got %v", i, scenario.exists, exists)
}
if err != nil {
t.Fatal(err)
}
if exists != s.exists {
t.Fatalf("Expected exists %v, got %v", s.exists, exists)
}
})
}
}
@@ -48,11 +55,11 @@ func TestFileSystemAttributes(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
scenarios := []struct {
file string
@@ -65,20 +72,24 @@ func TestFileSystemAttributes(t *testing.T) {
{"image.png", false, "image/png"},
}
for i, scenario := range scenarios {
attr, err := fs.Attributes(scenario.file)
for _, s := range scenarios {
t.Run(s.file, func(t *testing.T) {
attr, err := fsys.Attributes(s.file)
if err == nil && scenario.expectError {
t.Errorf("(%d) Expected error, got nil", i)
}
hasErr := err != nil
if err != nil && !scenario.expectError {
t.Errorf("(%d) Expected nil, got error, %v", i, err)
}
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
if err == nil && attr.ContentType != scenario.expectContentType {
t.Errorf("(%d) Expected attr.ContentType to be %q, got %q", i, scenario.expectContentType, attr.ContentType)
}
if hasErr && !errors.Is(err, filesystem.ErrNotFound) {
t.Fatalf("Expected ErrNotFound err, got %q", err)
}
if !hasErr && attr.ContentType != s.expectContentType {
t.Fatalf("Expected attr.ContentType to be %q, got %q", s.expectContentType, attr.ContentType)
}
})
}
}
@@ -86,17 +97,17 @@ func TestFileSystemDelete(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
if err := fs.Delete("missing.txt"); err == nil {
t.Fatal("Expected error, got nil")
if err := fsys.Delete("missing.txt"); err == nil || !errors.Is(err, filesystem.ErrNotFound) {
t.Fatalf("Expected ErrNotFound error, got %v", err)
}
if err := fs.Delete("image.png"); err != nil {
if err := fsys.Delete("image.png"); err != nil {
t.Fatalf("Expected nil, got error %v", err)
}
}
@@ -105,29 +116,29 @@ func TestFileSystemDeletePrefixWithoutTrailingSlash(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
if errs := fs.DeletePrefix(""); len(errs) == 0 {
if errs := fsys.DeletePrefix(""); len(errs) == 0 {
t.Fatal("Expected error, got nil", errs)
}
if errs := fs.DeletePrefix("missing"); len(errs) != 0 {
if errs := fsys.DeletePrefix("missing"); len(errs) != 0 {
t.Fatalf("Not existing prefix shouldn't error, got %v", errs)
}
if errs := fs.DeletePrefix("test"); len(errs) != 0 {
if errs := fsys.DeletePrefix("test"); len(errs) != 0 {
t.Fatalf("Expected nil, got errors %v", errs)
}
// ensure that the test/* files are deleted
if exists, _ := fs.Exists("test/sub1.txt"); exists {
if exists, _ := fsys.Exists("test/sub1.txt"); exists {
t.Fatalf("Expected test/sub1.txt to be deleted")
}
if exists, _ := fs.Exists("test/sub2.txt"); exists {
if exists, _ := fsys.Exists("test/sub2.txt"); exists {
t.Fatalf("Expected test/sub2.txt to be deleted")
}
@@ -141,25 +152,25 @@ func TestFileSystemDeletePrefixWithTrailingSlash(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
if errs := fs.DeletePrefix("missing/"); len(errs) != 0 {
if errs := fsys.DeletePrefix("missing/"); len(errs) != 0 {
t.Fatalf("Not existing prefix shouldn't error, got %v", errs)
}
if errs := fs.DeletePrefix("test/"); len(errs) != 0 {
if errs := fsys.DeletePrefix("test/"); len(errs) != 0 {
t.Fatalf("Expected nil, got errors %v", errs)
}
// ensure that the test/* files are deleted
if exists, _ := fs.Exists("test/sub1.txt"); exists {
if exists, _ := fsys.Exists("test/sub1.txt"); exists {
t.Fatalf("Expected test/sub1.txt to be deleted")
}
if exists, _ := fs.Exists("test/sub2.txt"); exists {
if exists, _ := fsys.Exists("test/sub2.txt"); exists {
t.Fatalf("Expected test/sub2.txt to be deleted")
}
@@ -169,6 +180,41 @@ func TestFileSystemDeletePrefixWithTrailingSlash(t *testing.T) {
}
}
func TestFileSystemIsEmptyDir(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fsys.Close()
scenarios := []struct {
dir string
expected bool
}{
{"", false}, // special case that shouldn't be suffixed with delimiter to search for any files within the bucket
{"/", true},
{"missing", true},
{"missing/", true},
{"test", false},
{"test/", false},
{"empty", true},
{"empty/", true},
}
for _, s := range scenarios {
t.Run(s.dir, func(t *testing.T) {
result := fsys.IsEmptyDir(s.dir)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestFileSystemUploadMultipart(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
@@ -193,24 +239,24 @@ func TestFileSystemUploadMultipart(t *testing.T) {
defer file.Close()
// ---
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
fileKey := "newdir/newkey.txt"
uploadErr := fs.UploadMultipart(fh, fileKey)
uploadErr := fsys.UploadMultipart(fh, fileKey)
if uploadErr != nil {
t.Fatal(uploadErr)
}
if exists, _ := fs.Exists(fileKey); !exists {
if exists, _ := fsys.Exists(fileKey); !exists {
t.Fatalf("Expected %q to exist", fileKey)
}
attrs, err := fs.Attributes(fileKey)
attrs, err := fsys.Attributes(fileKey)
if err != nil {
t.Fatalf("Failed to fetch file attributes: %v", err)
}
@@ -223,11 +269,11 @@ func TestFileSystemUploadFile(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
fileKey := "newdir/newkey.txt"
@@ -238,16 +284,16 @@ func TestFileSystemUploadFile(t *testing.T) {
file.OriginalName = "test.txt"
uploadErr := fs.UploadFile(file, fileKey)
uploadErr := fsys.UploadFile(file, fileKey)
if uploadErr != nil {
t.Fatal(uploadErr)
}
if exists, _ := fs.Exists(fileKey); !exists {
if exists, _ := fsys.Exists(fileKey); !exists {
t.Fatalf("Expected %q to exist", fileKey)
}
attrs, err := fs.Attributes(fileKey)
attrs, err := fsys.Attributes(fileKey)
if err != nil {
t.Fatalf("Failed to fetch file attributes: %v", err)
}
@@ -260,20 +306,20 @@ func TestFileSystemUpload(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
fileKey := "newdir/newkey.txt"
uploadErr := fs.Upload([]byte("demo"), fileKey)
uploadErr := fsys.Upload([]byte("demo"), fileKey)
if uploadErr != nil {
t.Fatal(uploadErr)
}
if exists, _ := fs.Exists(fileKey); !exists {
if exists, _ := fsys.Exists(fileKey); !exists {
t.Fatalf("Expected %s to exist", fileKey)
}
}
@@ -282,11 +328,11 @@ func TestFileSystemServe(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
csp := "default-src 'none'; media-src 'self'; style-src 'unsafe-inline'; sandbox"
cacheControl := "max-age=2592000, stale-while-revalidate=86400"
@@ -409,39 +455,41 @@ func TestFileSystemServe(t *testing.T) {
}
for _, s := range scenarios {
res := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
t.Run(s.path, func(t *testing.T) {
res := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
query := req.URL.Query()
for k, v := range s.query {
query.Set(k, v)
}
req.URL.RawQuery = query.Encode()
for k, v := range s.headers {
res.Header().Set(k, v)
}
err := fs.Serve(res, req, s.path, s.name)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("(%s) Expected hasError %v, got %v (%v)", s.path, s.expectError, hasErr, err)
continue
}
if s.expectError {
continue
}
result := res.Result()
for hName, hValue := range s.expectHeaders {
v := result.Header.Get(hName)
if v != hValue {
t.Errorf("(%s) Expected value %q for header %q, got %q", s.path, hValue, hName, v)
query := req.URL.Query()
for k, v := range s.query {
query.Set(k, v)
}
}
req.URL.RawQuery = query.Encode()
for k, v := range s.headers {
res.Header().Set(k, v)
}
err := fsys.Serve(res, req, s.path, s.name)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasError %v, got %v (%v)", s.expectError, hasErr, err)
}
if s.expectError {
return
}
result := res.Result()
defer result.Body.Close()
for hName, hValue := range s.expectHeaders {
v := result.Header.Get(hName)
if v != hValue {
t.Errorf("Expected value %q for header %q, got %q", hValue, hName, v)
}
}
})
}
}
@@ -449,20 +497,38 @@ func TestFileSystemGetFile(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
f, err := fs.GetFile("image.png")
if err != nil {
t.Fatal(err)
scenarios := []struct {
file string
expectError bool
}{
{"missing.png", true},
{"image.png", false},
}
defer f.Close()
if f == nil {
t.Fatal("File is supposed to be found")
for _, s := range scenarios {
t.Run(s.file, func(t *testing.T) {
f, err := fsys.GetFile(s.file)
hasErr := err != nil
if !hasErr {
defer f.Close()
}
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
if hasErr && !errors.Is(err, filesystem.ErrNotFound) {
t.Fatalf("Expected ErrNotFound error, got %v", err)
}
})
}
}
@@ -470,25 +536,26 @@ func TestFileSystemCopy(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
src := "image.png"
dst := "image.png_copy"
// copy missing file
if err := fs.Copy(dst, src); err == nil {
if err := fsys.Copy(dst, src); err == nil {
t.Fatalf("Expected to fail copying %q to %q, got nil", dst, src)
}
// copy existing file
if err := fs.Copy(src, dst); err != nil {
if err := fsys.Copy(src, dst); err != nil {
t.Fatalf("Failed to copy %q to %q: %v", src, dst, err)
}
f, err := fs.GetFile(dst)
f, err := fsys.GetFile(dst)
//nolint
defer f.Close()
if err != nil {
t.Fatalf("Missing copied file %q: %v", dst, err)
@@ -502,11 +569,11 @@ func TestFileSystemList(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
scenarios := []struct {
prefix string
@@ -537,7 +604,7 @@ func TestFileSystemList(t *testing.T) {
}
for _, s := range scenarios {
objs, err := fs.List(s.prefix)
objs, err := fsys.List(s.prefix)
if err != nil {
t.Fatalf("[%s] %v", s.prefix, err)
}
@@ -563,17 +630,17 @@ func TestFileSystemServeSingleRange(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
res := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
req.Header.Add("Range", "bytes=0-20")
if err := fs.Serve(res, req, "image.png", "image.png"); err != nil {
if err := fsys.Serve(res, req, "image.png", "image.png"); err != nil {
t.Fatal(err)
}
@@ -597,17 +664,17 @@ func TestFileSystemServeMultiRange(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
res := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)
req.Header.Add("Range", "bytes=0-20, 25-30")
if err := fs.Serve(res, req, "image.png", "image.png"); err != nil {
if err := fsys.Serve(res, req, "image.png", "image.png"); err != nil {
t.Fatal(err)
}
@@ -626,11 +693,11 @@ func TestFileSystemCreateThumb(t *testing.T) {
dir := createTestDir(t)
defer os.RemoveAll(dir)
fs, err := filesystem.NewLocal(dir)
fsys, err := filesystem.NewLocal(dir)
if err != nil {
t.Fatal(err)
}
defer fs.Close()
defer fsys.Close()
scenarios := []struct {
file string
@@ -651,7 +718,7 @@ func TestFileSystemCreateThumb(t *testing.T) {
}
for i, scenario := range scenarios {
err := fs.CreateThumb(scenario.file, scenario.thumb, "100x100")
err := fsys.CreateThumb(scenario.file, scenario.thumb, "100x100")
hasErr := err != nil
if hasErr != scenario.expectError {
@@ -663,7 +730,7 @@ func TestFileSystemCreateThumb(t *testing.T) {
continue
}
if exists, _ := fs.Exists(scenario.thumb); !exists {
if exists, _ := fsys.Exists(scenario.thumb); !exists {
t.Errorf("(%d) Couldn't find %q thumb", i, scenario.thumb)
}
}
@@ -677,6 +744,10 @@ func createTestDir(t *testing.T) string {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "empty"), os.ModePerm); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "test"), os.ModePerm); err != nil {
t.Fatal(err)
}
@@ -66,7 +66,7 @@
// (V1) *s3.PutObjectInput; (V2) *s3v2.PutObjectInput, when Options.Method == http.MethodPut, or
// (V1) *s3.DeleteObjectInput; (V2) [not supported] when Options.Method == http.MethodDelete
package filesystem
package s3lite
import (
"context"
@@ -82,7 +82,6 @@ import (
"strings"
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
awsv2cfg "github.com/aws/aws-sdk-go-v2/config"
s3managerv2 "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
s3v2 "github.com/aws/aws-sdk-go-v2/service/s3"
typesv2 "github.com/aws/aws-sdk-go-v2/service/s3/types"
@@ -244,116 +243,8 @@ func URLUnescape(s string) string {
// -------------------------------------------------------------------
// UseV2 returns true iff the URL parameters indicate that the provider
// should use the AWS SDK v2.
//
// "awssdk=v1" will force V1.
// "awssdk=v2" will force V2.
// No "awssdk" parameter (or any other value) will return the default, currently V1.
// Note that the default may change in the future.
func UseV2(q url.Values) bool {
if values, ok := q["awssdk"]; ok {
if values[0] == "v2" || values[0] == "V2" {
return true
}
}
return false
}
// NewDefaultV2Config returns a aws.Config for AWS SDK v2, using the default options.
func NewDefaultV2Config(ctx context.Context) (awsv2.Config, error) {
return awsv2cfg.LoadDefaultConfig(ctx)
}
// V2ConfigFromURLParams returns an aws.Config for AWS SDK v2 initialized based on the URL
// parameters in q. It is intended to be used by URLOpeners for AWS services if
// UseV2 returns true.
//
// https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#Config
//
// It returns an error if q contains any unknown query parameters; callers
// should remove any query parameters they know about from q before calling
// V2ConfigFromURLParams.
//
// The following query options are supported:
// - region: The AWS region for requests; sets WithRegion.
// - profile: The shared config profile to use; sets SharedConfigProfile.
// - endpoint: The AWS service endpoint to send HTTP request.
func V2ConfigFromURLParams(ctx context.Context, q url.Values) (awsv2.Config, error) {
var opts []func(*awsv2cfg.LoadOptions) error
for param, values := range q {
value := values[0]
switch param {
case "region":
opts = append(opts, awsv2cfg.WithRegion(value))
case "endpoint":
customResolver := awsv2.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (awsv2.Endpoint, error) {
return awsv2.Endpoint{
PartitionID: "aws",
URL: value,
SigningRegion: region,
}, nil
})
opts = append(opts, awsv2cfg.WithEndpointResolverWithOptions(customResolver))
case "profile":
opts = append(opts, awsv2cfg.WithSharedConfigProfile(value))
case "awssdk":
// ignore, should be handled before this
default:
return awsv2.Config{}, fmt.Errorf("unknown query parameter %q", param)
}
}
return awsv2cfg.LoadDefaultConfig(ctx, opts...)
}
// -------------------------------------------------------------------
const defaultPageSize = 1000
func init() {
blob.DefaultURLMux().RegisterBucket(Scheme, new(urlSessionOpener))
}
type urlSessionOpener struct{}
func (o *urlSessionOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
opener := &URLOpener{UseV2: true}
return opener.OpenBucketURL(ctx, u)
}
// Scheme is the URL scheme s3blob registers its URLOpener under on
// blob.DefaultMux.
const Scheme = "s3"
// URLOpener opens S3 URLs like "s3://mybucket".
//
// The URL host is used as the bucket name.
//
// Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2,
// or anything else to accept the default.
//
// For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters
// for overriding the aws.Session from the URL.
// For V2, see gocloud.dev/aws/V2ConfigFromURLParams.
type URLOpener struct {
// UseV2 indicates whether the AWS SDK V2 should be used.
UseV2 bool
// Options specifies the options to pass to OpenBucket.
Options Options
}
// OpenBucketURL opens a blob.Bucket based on u.
func (o *URLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
cfg, err := V2ConfigFromURLParams(ctx, u.Query())
if err != nil {
return nil, fmt.Errorf("open bucket %v: %v", u, err)
}
clientV2 := s3v2.NewFromConfig(cfg)
return OpenBucketV2(ctx, clientV2, u.Host, &o.Options)
}
// Options sets options for constructing a *blob.Bucket backed by fileblob.
type Options struct {
// UseLegacyList forces the use of ListObjects instead of ListObjectsV2.
@@ -676,64 +567,6 @@ func (b *bucket) listObjectsV2(ctx context.Context, in *s3v2.ListObjectsV2Input,
}, nil
}
// func (b *bucket) listObjects(ctx context.Context, in *s3.ListObjectsV2Input, opts *driver.ListOptions) (*s3.ListObjectsV2Output, error) {
// if !b.useLegacyList {
// if opts.BeforeList != nil {
// asFunc := func(i interface{}) bool {
// if p, ok := i.(**s3.ListObjectsV2Input); ok {
// *p = in
// return true
// }
// return false
// }
// if err := opts.BeforeList(asFunc); err != nil {
// return nil, err
// }
// }
// return b.client.ListObjectsV2WithContext(ctx, in)
// }
// // Use the legacy ListObjects request.
// legacyIn := &s3.ListObjectsInput{
// Bucket: in.Bucket,
// Delimiter: in.Delimiter,
// EncodingType: in.EncodingType,
// Marker: in.ContinuationToken,
// MaxKeys: in.MaxKeys,
// Prefix: in.Prefix,
// RequestPayer: in.RequestPayer,
// }
// if opts.BeforeList != nil {
// asFunc := func(i interface{}) bool {
// p, ok := i.(**s3.ListObjectsInput)
// if !ok {
// return false
// }
// *p = legacyIn
// return true
// }
// if err := opts.BeforeList(asFunc); err != nil {
// return nil, err
// }
// }
// legacyResp, err := b.client.ListObjectsWithContext(ctx, legacyIn)
// if err != nil {
// return nil, err
// }
// var nextContinuationToken *string
// if legacyResp.NextMarker != nil {
// nextContinuationToken = legacyResp.NextMarker
// } else if awsv2.ToBool(legacyResp.IsTruncated) {
// nextContinuationToken = awsv2.String(awsv2.ToString(legacyResp.Contents[len(legacyResp.Contents)-1].Key))
// }
// return &s3.ListObjectsV2Output{
// CommonPrefixes: legacyResp.CommonPrefixes,
// Contents: legacyResp.Contents,
// NextContinuationToken: nextContinuationToken,
// }, nil
// }
// As implements driver.As.
func (b *bucket) As(i interface{}) bool {
p, ok := i.(**s3v2.Client)
+45
View File
@@ -0,0 +1,45 @@
package hook
// Resolver defines a common interface for a Hook event (see [Event]).
type Resolver interface {
// Next triggers the next handler in the hook's chain (if any).
Next() error
// note: kept only for the generic interface; may get removed in the future
nextFunc() func() error
setNextFunc(f func() error)
}
var _ Resolver = (*Event)(nil)
// Event implements [Resolver] and it is intended to be used as a base
// Hook event that you can embed in your custom typed event structs.
//
// Example:
//
// type CustomEvent struct {
// hook.Event
//
// SomeField int
// }
type Event struct {
next func() error
}
// Next calls the next hook handler.
func (e *Event) Next() error {
if e.next != nil {
return e.next()
}
return nil
}
// nextFunc returns the function that Next calls.
func (e *Event) nextFunc() func() error {
return e.next
}
// setNextFunc sets the function that Next calls.
func (e *Event) setNextFunc(f func() error) {
e.next = f
}
+29
View File
@@ -0,0 +1,29 @@
package hook
import "testing"
func TestEventNext(t *testing.T) {
calls := 0
e := Event{}
if e.nextFunc() != nil {
t.Fatalf("Expected nextFunc to be nil")
}
e.setNextFunc(func() error {
calls++
return nil
})
if e.nextFunc() == nil {
t.Fatalf("Expected nextFunc to be non-nil")
}
e.Next()
e.Next()
if calls != 2 {
t.Fatalf("Expected %d calls, got %d", 2, calls)
}
}
+138 -85
View File
@@ -1,126 +1,179 @@
package hook
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/pocketbase/pocketbase/tools/security"
)
var StopPropagation = errors.New("Event hook propagation stopped")
// HandlerFunc defines a hook handler function.
type HandlerFunc[T Resolver] func(e T) error
// Handler defines a hook handler function.
type Handler[T any] func(e T) error
// Handler defines a single Hook handler.
// Multiple handlers can share the same id.
// If Id is not explicitly set it will be autogenerated by Hook.Add and Hook.AddHandler.
type Handler[T Resolver] struct {
// Func defines the handler function to execute.
//
// Note that users need to call e.Next() in order to proceed with
// the execution of the hook chain.
Func HandlerFunc[T]
// handlerPair defines a pair of string id and Handler.
type handlerPair[T any] struct {
id string
handler Handler[T]
// Id is the unique identifier of the handler.
//
// It could be used later to remove the handler from a hook via [Hook.Remove].
//
// If missing, an autogenerated value will be assigned when adding
// the handler to a hook.
Id string
// Priority allows changing the default exec priority of the handler within a hook.
//
// If 0, the handler will be executed in the same order it was registered.
Priority int
}
// Hook defines a concurrent safe structure for handling event hooks
// (aka. callbacks propagation).
type Hook[T any] struct {
mux sync.RWMutex
handlers []*handlerPair[T]
}
// PreAdd registers a new handler to the hook by prepending it to the existing queue.
// Hook defines a generic concurrent safe structure for managing event hooks.
//
// Returns an autogenerated hook id that could be used later to remove the hook with Hook.Remove(id).
func (h *Hook[T]) PreAdd(fn Handler[T]) string {
h.mux.Lock()
defer h.mux.Unlock()
id := generateHookId()
// minimize allocations by shifting the slice
h.handlers = append(h.handlers, nil)
copy(h.handlers[1:], h.handlers)
h.handlers[0] = &handlerPair[T]{id, fn}
return id
}
// Add registers a new handler to the hook by appending it to the existing queue.
// When using custom a event it must embed the base [hook.Event].
//
// Returns an autogenerated hook id that could be used later to remove the hook with Hook.Remove(id).
func (h *Hook[T]) Add(fn Handler[T]) string {
h.mux.Lock()
defer h.mux.Unlock()
id := generateHookId()
h.handlers = append(h.handlers, &handlerPair[T]{id, fn})
return id
// Example:
//
// type CustomEvent struct {
// hook.Event
// SomeField int
// }
//
// h := Hook[*CustomEvent]{}
//
// h.BindFunc(func(e *CustomEvent) error {
// println(e.SomeField)
//
// return e.Next()
// })
//
// h.Trigger(&CustomEvent{ SomeField: 123 })
type Hook[T Resolver] struct {
handlers []*Handler[T]
mu sync.RWMutex
}
// Remove removes a single hook handler by its id.
func (h *Hook[T]) Remove(id string) {
h.mux.Lock()
defer h.mux.Unlock()
// Bind registers the provided handler to the current hooks queue.
//
// If handler.Id is empty it is updated with autogenerated value.
//
// If a handler from the current hook list has Id matching handler.Id
// then the old handler is replaced with the new one.
func (h *Hook[T]) Bind(handler *Handler[T]) string {
h.mu.Lock()
defer h.mu.Unlock()
var exists bool
if handler.Id == "" {
handler.Id = generateHookId()
// ensure that it doesn't exist
DUPLICATE_CHECK:
for _, existing := range h.handlers {
if existing.Id == handler.Id {
handler.Id = generateHookId()
goto DUPLICATE_CHECK
}
}
} else {
// replace existing
for i, existing := range h.handlers {
if existing.Id == handler.Id {
h.handlers[i] = handler
exists = true
break
}
}
}
// append new
if !exists {
h.handlers = append(h.handlers, handler)
}
// sort handlers by Priority, preserving the original order of equal items
sort.SliceStable(h.handlers, func(i, j int) bool {
return h.handlers[i].Priority < h.handlers[j].Priority
})
return handler.Id
}
// BindFunc is similar to Bind but registers a new handler from just the provided function.
//
// The registered handler is added with a default 0 priority and the id will be autogenerated.
//
// If you want to register a handler with custom priority or id use the [Hook.Bind] method.
func (h *Hook[T]) BindFunc(fn HandlerFunc[T]) string {
return h.Bind(&Handler[T]{Func: fn})
}
// Unbind removes a single hook handler by its id.
func (h *Hook[T]) Unbind(id string) {
h.mu.Lock()
defer h.mu.Unlock()
for i := len(h.handlers) - 1; i >= 0; i-- {
if h.handlers[i].id == id {
if h.handlers[i].Id == id {
h.handlers = append(h.handlers[:i], h.handlers[i+1:]...)
return
break // for now stop on the first occurrence since we don't allow handlers with duplicated ids
}
}
}
// RemoveAll removes all registered handlers.
func (h *Hook[T]) RemoveAll() {
h.mux.Lock()
defer h.mux.Unlock()
// UnbindAll removes all registered handlers.
func (h *Hook[T]) UnbindAll() {
h.mu.Lock()
defer h.mu.Unlock()
h.handlers = nil
}
// Length returns to total number of registered hook handlers.
func (h *Hook[T]) Length() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.handlers)
}
// Trigger executes all registered hook handlers one by one
// with the specified `data` as an argument.
// with the specified event as an argument.
//
// Optionally, this method allows also to register additional one off
// handlers that will be temporary appended to the handlers queue.
//
// The execution stops when:
// - hook.StopPropagation is returned in one of the handlers
// - any non-nil error is returned in one of the handlers
func (h *Hook[T]) Trigger(data T, oneOffHandlers ...Handler[T]) error {
h.mux.RLock()
// NB! Each hook handler must call event.Next() in order the hook chain to proceed.
func (h *Hook[T]) Trigger(event T, oneOffHandlers ...HandlerFunc[T]) error {
h.mu.RLock()
handlers := make([]HandlerFunc[T], 0, len(h.handlers)+len(oneOffHandlers))
for _, handler := range h.handlers {
handlers = append(handlers, handler.Func)
}
handlers = append(handlers, oneOffHandlers...)
h.mu.RUnlock()
handlers := make([]*handlerPair[T], 0, len(h.handlers)+len(oneOffHandlers))
handlers = append(handlers, h.handlers...)
event.setNextFunc(nil) // reset in case the event is being reused
// append the one off handlers
for i, oneOff := range oneOffHandlers {
handlers = append(handlers, &handlerPair[T]{
id: fmt.Sprintf("@%d", i),
handler: oneOff,
for i := len(handlers) - 1; i >= 0; i-- {
i := i
old := event.nextFunc()
event.setNextFunc(func() error {
event.setNextFunc(old)
return handlers[i](event)
})
}
// unlock is not deferred to avoid deadlocks in case Trigger
// is called recursively by the handlers
h.mux.RUnlock()
for _, item := range handlers {
err := item.handler(data)
if err == nil {
continue
}
if errors.Is(err, StopPropagation) {
return nil
}
return err
}
return nil
return event.Next()
}
func generateHookId() string {
return security.PseudorandomString(8)
return security.PseudorandomString(20)
}
+106 -124
View File
@@ -5,175 +5,157 @@ import (
"testing"
)
func TestHookAddAndPreAdd(t *testing.T) {
h := Hook[int]{}
func TestHookAddHandlerAndAdd(t *testing.T) {
calls := ""
if total := len(h.handlers); total != 0 {
t.Fatalf("Expected no handlers, found %d", total)
h := Hook[*Event]{}
h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
h3Id := h.BindFunc(func(e *Event) error { calls += "3"; return e.Next() })
h.Bind(&Handler[*Event]{
Id: h3Id, // should replace 3
Func: func(e *Event) error { calls += "3'"; return e.Next() },
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "4"; return e.Next() },
Priority: -2,
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "5"; return e.Next() },
Priority: -1,
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "6"; return e.Next() },
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "7"; e.Next(); return errors.New("test") }, // error shouldn't stop the chain
})
h.Trigger(
&Event{},
func(e *Event) error { calls += "8"; return e.Next() },
func(e *Event) error { calls += "9"; return nil }, // skip next
func(e *Event) error { calls += "10"; return e.Next() },
)
if total := len(h.handlers); total != 7 {
t.Fatalf("Expected %d handlers, found %d", 7, total)
}
triggerSequence := ""
expectedCalls := "45123'6789"
f1 := func(data int) error { triggerSequence += "f1"; return nil }
f2 := func(data int) error { triggerSequence += "f2"; return nil }
f3 := func(data int) error { triggerSequence += "f3"; return nil }
f4 := func(data int) error { triggerSequence += "f4"; return nil }
h.Add(f1)
h.Add(f2)
h.PreAdd(f3)
h.PreAdd(f4)
h.Trigger(1)
if total := len(h.handlers); total != 4 {
t.Fatalf("Expected %d handlers, found %d", 4, total)
}
expectedTriggerSequence := "f4f3f1f2"
if triggerSequence != expectedTriggerSequence {
t.Fatalf("Expected trigger sequence %s, got %s", expectedTriggerSequence, triggerSequence)
if calls != expectedCalls {
t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
}
}
func TestHookRemove(t *testing.T) {
h := Hook[int]{}
func TestHookLength(t *testing.T) {
h := Hook[*Event]{}
h1Called := false
h2Called := false
if l := h.Length(); l != 0 {
t.Fatalf("Expected 0 hook handlers, got %d", l)
}
id1 := h.Add(func(data int) error { h1Called = true; return nil })
h.Add(func(data int) error { h2Called = true; return nil })
h.BindFunc(func(e *Event) error { return e.Next() })
h.BindFunc(func(e *Event) error { return e.Next() })
h.Remove("missing") // should do nothing and not panic
if l := h.Length(); l != 2 {
t.Fatalf("Expected 2 hook handlers, got %d", l)
}
}
func TestHookUnbind(t *testing.T) {
h := Hook[*Event]{}
calls := ""
id1 := h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "3"; return e.Next() },
})
h.Unbind("missing") // should do nothing and not panic
if total := len(h.handlers); total != 3 {
t.Fatalf("Expected %d handlers, got %d", 3, total)
}
h.Unbind(id1)
if total := len(h.handlers); total != 2 {
t.Fatalf("Expected %d handlers, got %d", 2, total)
}
h.Remove(id1)
if total := len(h.handlers); total != 1 {
t.Fatalf("Expected %d handlers, got %d", 1, total)
}
if err := h.Trigger(1); err != nil {
err := h.Trigger(&Event{}, func(e *Event) error { calls += "4"; return e.Next() })
if err != nil {
t.Fatal(err)
}
if h1Called {
t.Fatalf("Expected hook 1 to be removed and not called")
}
expectedCalls := "234"
if !h2Called {
t.Fatalf("Expected hook 2 to be called")
if calls != expectedCalls {
t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
}
}
func TestHookRemoveAll(t *testing.T) {
h := Hook[int]{}
func TestHookUnbindAll(t *testing.T) {
h := Hook[*Event]{}
h.RemoveAll() // should do nothing and not panic
h.UnbindAll() // should do nothing and not panic
h.Add(func(data int) error { return nil })
h.Add(func(data int) error { return nil })
h.BindFunc(func(e *Event) error { return nil })
h.BindFunc(func(e *Event) error { return nil })
if total := len(h.handlers); total != 2 {
t.Fatalf("Expected 2 handlers before RemoveAll, found %d", total)
t.Fatalf("Expected %d handlers before UnbindAll, found %d", 2, total)
}
h.RemoveAll()
h.UnbindAll()
if total := len(h.handlers); total != 0 {
t.Fatalf("Expected no handlers after RemoveAll, found %d", total)
t.Fatalf("Expected no handlers after UnbindAll, found %d", total)
}
}
func TestHookTrigger(t *testing.T) {
err1 := errors.New("demo")
err2 := errors.New("demo")
func TestHookTriggerErrorPropagation(t *testing.T) {
err := errors.New("test")
scenarios := []struct {
handlers []Handler[int]
name string
handlers []HandlerFunc[*Event]
expectedError error
}{
{
[]Handler[int]{
func(data int) error { return nil },
func(data int) error { return nil },
"without error",
[]HandlerFunc[*Event]{
func(e *Event) error { return e.Next() },
func(e *Event) error { return e.Next() },
},
nil,
},
{
[]Handler[int]{
func(data int) error { return nil },
func(data int) error { return err1 },
func(data int) error { return err2 },
"with error",
[]HandlerFunc[*Event]{
func(e *Event) error { return e.Next() },
func(e *Event) error { e.Next(); return err },
func(e *Event) error { return e.Next() },
},
err1,
err,
},
}
for i, scenario := range scenarios {
h := Hook[int]{}
for _, handler := range scenario.handlers {
h.Add(handler)
}
result := h.Trigger(1)
if result != scenario.expectedError {
t.Fatalf("(%d) Expected %v, got %v", i, scenario.expectedError, result)
}
}
}
func TestHookTriggerStopPropagation(t *testing.T) {
called1 := false
f1 := func(data int) error { called1 = true; return nil }
called2 := false
f2 := func(data int) error { called2 = true; return nil }
called3 := false
f3 := func(data int) error { called3 = true; return nil }
called4 := false
f4 := func(data int) error { called4 = true; return StopPropagation }
called5 := false
f5 := func(data int) error { called5 = true; return nil }
called6 := false
f6 := func(data int) error { called6 = true; return nil }
h := Hook[int]{}
h.Add(f1)
h.Add(f2)
result := h.Trigger(123, f3, f4, f5, f6)
if result != nil {
t.Fatalf("Expected nil after StopPropagation, got %v", result)
}
// ensure that the trigger handler were not persisted
if total := len(h.handlers); total != 2 {
t.Fatalf("Expected 2 handlers, found %d", total)
}
scenarios := []struct {
called bool
expected bool
}{
{called1, true},
{called2, true},
{called3, true},
{called4, true}, // StopPropagation
{called5, false},
{called6, false},
}
for i, scenario := range scenarios {
if scenario.called != scenario.expected {
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, scenario.called)
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
h := Hook[*Event]{}
for _, handler := range s.handlers {
h.BindFunc(handler)
}
result := h.Trigger(&Event{})
if result != s.expectedError {
t.Fatalf("Expected %v, got %v", s.expectedError, result)
}
})
}
}
+23 -13
View File
@@ -7,6 +7,8 @@ import (
// Tagger defines an interface for event data structs that support tags/groups/categories/etc.
// Usually used together with TaggedHook.
type Tagger interface {
Resolver
Tags() []string
}
@@ -33,12 +35,14 @@ type TaggedHook[T Tagger] struct {
// CanTriggerOn checks if the current TaggedHook can be triggered with
// the provided event data tags.
func (h *TaggedHook[T]) CanTriggerOn(tags []string) bool {
//
// It returns always true if the hook doens't have any tags.
func (h *TaggedHook[T]) CanTriggerOn(tagsToCheck []string) bool {
if len(h.tags) == 0 {
return true // match all
}
for _, t := range tags {
for _, t := range tagsToCheck {
if list.ExistInSlice(t, h.tags) {
return true
}
@@ -47,28 +51,34 @@ func (h *TaggedHook[T]) CanTriggerOn(tags []string) bool {
return false
}
// PreAdd registers a new handler to the hook by prepending it to the existing queue.
// Bind registers the provided handler to the current hooks queue.
//
// The fn handler will be called only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) PreAdd(fn Handler[T]) string {
return h.mainHook.PreAdd(func(e T) error {
// It is similar to [Hook.Bind] with the difference that the handler
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) Bind(handler *Handler[T]) string {
fn := handler.Func
handler.Func = func(e T) error {
if h.CanTriggerOn(e.Tags()) {
return fn(e)
}
return nil
})
return e.Next()
}
return h.mainHook.Bind(handler)
}
// Add registers a new handler to the hook by appending it to the existing queue.
// BindFunc registers a new handler with the specified function.
//
// The fn handler will be called only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) Add(fn Handler[T]) string {
return h.mainHook.Add(func(e T) error {
// It is similar to [Hook.Bind] with the difference that the handler
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) BindFunc(fn HandlerFunc[T]) string {
return h.mainHook.BindFunc(func(e T) error {
if h.CanTriggerOn(e.Tags()) {
return fn(e)
}
return nil
return e.Next()
})
}
+43 -28
View File
@@ -1,69 +1,84 @@
package hook
import "testing"
import (
"strings"
"testing"
)
type mockTagsData struct {
type mockTagsEvent struct {
Event
tags []string
}
func (m mockTagsData) Tags() []string {
func (m mockTagsEvent) Tags() []string {
return m.tags
}
func TestTaggedHook(t *testing.T) {
triggerSequence := ""
calls := ""
base := &Hook[mockTagsData]{}
base.Add(func(data mockTagsData) error { triggerSequence += "f0"; return nil })
base := &Hook[*mockTagsEvent]{}
base.BindFunc(func(e *mockTagsEvent) error { calls += "f0"; return e.Next() })
hA := NewTaggedHook(base)
hA.Add(func(data mockTagsData) error { triggerSequence += "a1"; return nil })
hA.PreAdd(func(data mockTagsData) error { triggerSequence += "a2"; return nil })
hA.BindFunc(func(e *mockTagsEvent) error { calls += "a1"; return e.Next() })
hA.Bind(&Handler[*mockTagsEvent]{
Func: func(e *mockTagsEvent) error { calls += "a2"; return e.Next() },
Priority: -1,
})
hB := NewTaggedHook(base, "b1", "b2")
hB.Add(func(data mockTagsData) error { triggerSequence += "b1"; return nil })
hB.PreAdd(func(data mockTagsData) error { triggerSequence += "b2"; return nil })
hB.BindFunc(func(e *mockTagsEvent) error { calls += "b1"; return e.Next() })
hB.Bind(&Handler[*mockTagsEvent]{
Func: func(e *mockTagsEvent) error { calls += "b2"; return e.Next() },
Priority: -2,
})
hC := NewTaggedHook(base, "c1", "c2")
hC.Add(func(data mockTagsData) error { triggerSequence += "c1"; return nil })
hC.PreAdd(func(data mockTagsData) error { triggerSequence += "c2"; return nil })
hC.BindFunc(func(e *mockTagsEvent) error { calls += "c1"; return e.Next() })
hC.Bind(&Handler[*mockTagsEvent]{
Func: func(e *mockTagsEvent) error { calls += "c2"; return e.Next() },
Priority: -3,
})
scenarios := []struct {
data mockTagsData
expectedSequence string
event *mockTagsEvent
expectedCalls string
}{
{
mockTagsData{},
&mockTagsEvent{},
"a2f0a1",
},
{
mockTagsData{[]string{"missing"}},
&mockTagsEvent{tags: []string{"missing"}},
"a2f0a1",
},
{
mockTagsData{[]string{"b2"}},
&mockTagsEvent{tags: []string{"b2"}},
"b2a2f0a1b1",
},
{
mockTagsData{[]string{"c1"}},
&mockTagsEvent{tags: []string{"c1"}},
"c2a2f0a1c1",
},
{
mockTagsData{[]string{"b1", "c2"}},
&mockTagsEvent{tags: []string{"b1", "c2"}},
"c2b2a2f0a1b1c1",
},
}
for i, s := range scenarios {
triggerSequence = "" // reset
for _, s := range scenarios {
t.Run(strings.Join(s.event.tags, "_"), func(t *testing.T) {
calls = "" // reset
err := hA.Trigger(s.data)
if err != nil {
t.Fatalf("[%d] Unexpected trigger error: %v", i, err)
}
err := base.Trigger(s.event)
if err != nil {
t.Fatalf("Unexpected trigger error: %v", err)
}
if triggerSequence != s.expectedSequence {
t.Fatalf("[%d] Expected trigger sequence %s, got %s", i, s.expectedSequence, triggerSequence)
}
if calls != s.expectedCalls {
t.Fatalf("Expected calls sequence %q, got %q", s.expectedCalls, calls)
}
})
}
}
+39 -28
View File
@@ -1,6 +1,7 @@
package inflector_test
import (
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/inflector"
@@ -18,10 +19,13 @@ func TestUcFirst(t *testing.T) {
{"test test2", "Test test2"},
}
for i, scenario := range scenarios {
if result := inflector.UcFirst(scenario.val); result != scenario.expected {
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.UcFirst(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
@@ -42,10 +46,13 @@ func TestColumnify(t *testing.T) {
{"test1--test2", "test1--test2"},
}
for i, scenario := range scenarios {
if result := inflector.Columnify(scenario.val); result != scenario.expected {
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.Columnify(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
@@ -67,10 +74,13 @@ func TestSentenize(t *testing.T) {
{"hello world?", "Hello world?"},
}
for i, scenario := range scenarios {
if result := inflector.Sentenize(scenario.val); result != scenario.expected {
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.Sentenize(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
@@ -89,21 +99,19 @@ func TestSanitize(t *testing.T) {
{"abcABC", `[A-Z`, "", true}, // invalid pattern
}
for i, scenario := range scenarios {
result, err := inflector.Sanitize(scenario.val, scenario.pattern)
hasErr := err != nil
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result, err := inflector.Sanitize(s.val, s.pattern)
hasErr := err != nil
if scenario.expectErr != hasErr {
if scenario.expectErr {
t.Errorf("(%d) Expected error, got nil", i)
} else {
t.Errorf("(%d) Didn't expect error, got", err)
if s.expectErr != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectErr, hasErr, err)
}
}
if result != scenario.expected {
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
}
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
@@ -126,9 +134,12 @@ func TestSnakecase(t *testing.T) {
{"testABR", "test_abr"},
}
for i, scenario := range scenarios {
if result := inflector.Snakecase(scenario.val); result != scenario.expected {
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.Snakecase(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
+26 -3
View File
@@ -39,7 +39,7 @@ func ExistInSlice[T comparable](item T, list []T) bool {
// ExistInSliceWithRegex checks whether a string exists in a slice
// either by direct match, or by a regular expression (eg. `^\w+$`).
//
// _Note: Only list items starting with '^' and ending with '$' are treated as regular expressions!_
// Note: Only list items starting with '^' and ending with '$' are treated as regular expressions!
func ExistInSliceWithRegex(str string, list []string) bool {
for _, field := range list {
isRegex := strings.HasPrefix(field, "^") && strings.HasSuffix(field, "$")
@@ -64,7 +64,7 @@ func ExistInSliceWithRegex(str string, list []string) bool {
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
//
// @todo consider replacing with TTL or LRU type cache
cachedPatterns.SetIfLessThanLimit(field, pattern, 5000)
cachedPatterns.SetIfLessThanLimit(field, pattern, 500)
}
if pattern != nil && pattern.MatchString(str) {
@@ -129,7 +129,7 @@ func ToUniqueStringSlice(value any) (result []string) {
// just add the string as single array element
result = append(result, val)
}
case json.Marshaler: // eg. JsonArray
case json.Marshaler: // eg. JSONArray
raw, _ := val.MarshalJSON()
_ = json.Unmarshal(raw, &result)
default:
@@ -138,3 +138,26 @@ func ToUniqueStringSlice(value any) (result []string) {
return NonzeroUniques(result)
}
// ToChunks splits list into chunks.
//
// Zero or negative chunkSize argument is normalized to 1.
//
// See https://go.dev/wiki/SliceTricks#batching-with-minimal-allocation.
func ToChunks[T any](list []T, chunkSize int) [][]T {
if chunkSize <= 0 {
chunkSize = 1
}
chunks := make([][]T, 0, (len(list)+chunkSize-1)/chunkSize)
if len(list) == 0 {
return chunks
}
for chunkSize < len(list) {
list, chunks = list[chunkSize:], append(chunks, list[0:chunkSize:chunkSize])
}
return append(chunks, list)
}
+111 -74
View File
@@ -2,6 +2,7 @@ package list_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/list"
@@ -42,18 +43,20 @@ func TestSubtractSliceString(t *testing.T) {
}
for i, s := range scenarios {
result := list.SubtractSlice(s.base, s.subtract)
t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
result := list.SubtractSlice(s.base, s.subtract)
raw, err := json.Marshal(result)
if err != nil {
t.Fatalf("(%d) Failed to serialize: %v", i, err)
}
raw, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to serialize: %v", err)
}
strResult := string(raw)
strResult := string(raw)
if strResult != s.expected {
t.Fatalf("(%d) Expected %v, got %v", i, s.expected, strResult)
}
if strResult != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, strResult)
}
})
}
}
@@ -91,18 +94,20 @@ func TestSubtractSliceInt(t *testing.T) {
}
for i, s := range scenarios {
result := list.SubtractSlice(s.base, s.subtract)
t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
result := list.SubtractSlice(s.base, s.subtract)
raw, err := json.Marshal(result)
if err != nil {
t.Fatalf("(%d) Failed to serialize: %v", i, err)
}
raw, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to serialize: %v", err)
}
strResult := string(raw)
strResult := string(raw)
if strResult != s.expected {
t.Fatalf("(%d) Expected %v, got %v", i, s.expected, strResult)
}
if strResult != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, strResult)
}
})
}
}
@@ -120,15 +125,13 @@ func TestExistInSliceString(t *testing.T) {
{"test", []string{"1", "2", "test"}, true},
}
for i, scenario := range scenarios {
result := list.ExistInSlice(scenario.item, scenario.list)
if result != scenario.expected {
if scenario.expected {
t.Errorf("(%d) Expected to exist in the list", i)
} else {
t.Errorf("(%d) Expected NOT to exist in the list", i)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
result := list.ExistInSlice(s.item, s.list)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
}
})
}
}
@@ -146,15 +149,13 @@ func TestExistInSliceInt(t *testing.T) {
{-1, []int{0, -1, -2, -3, -4}, true},
}
for i, scenario := range scenarios {
result := list.ExistInSlice(scenario.item, scenario.list)
if result != scenario.expected {
if scenario.expected {
t.Errorf("(%d) Expected to exist in the list", i)
} else {
t.Errorf("(%d) Expected NOT to exist in the list", i)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%d", i, s.item), func(t *testing.T) {
result := list.ExistInSlice(s.item, s.list)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
}
})
}
}
@@ -177,15 +178,13 @@ func TestExistInSliceWithRegex(t *testing.T) {
{"!?@test", []string{`^\W+$`, "test"}, false},
}
for i, scenario := range scenarios {
result := list.ExistInSliceWithRegex(scenario.item, scenario.list)
if result != scenario.expected {
if scenario.expected {
t.Errorf("(%d) Expected the string to exist in the list", i)
} else {
t.Errorf("(%d) Expected the string NOT to exist in the list", i)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
result := list.ExistInSliceWithRegex(s.item, s.list)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
}
})
}
}
@@ -196,21 +195,23 @@ func TestToInterfaceSlice(t *testing.T) {
{[]string{}},
{[]string{""}},
{[]string{"1", "test"}},
{[]string{"test1", "test2", "test3"}},
{[]string{"test1", "test1", "test2", "test3"}},
}
for i, scenario := range scenarios {
result := list.ToInterfaceSlice(scenario.items)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
result := list.ToInterfaceSlice(s.items)
if len(result) != len(scenario.items) {
t.Errorf("(%d) Result list length doesn't match with the original list", i)
}
for j, v := range result {
if v != scenario.items[j] {
t.Errorf("(%d:%d) Result list item should match with the original list item", i, j)
if len(result) != len(s.items) {
t.Fatalf("Expected length %d, got %d", len(s.items), len(result))
}
}
for j, v := range result {
if v != s.items[j] {
t.Fatalf("Result list item doesn't match with the original list item, got %v VS %v", v, s.items[j])
}
}
})
}
}
@@ -225,18 +226,20 @@ func TestNonzeroUniquesString(t *testing.T) {
{[]string{"test1", "", "test2", "Test2", "test1", "test3"}, []string{"test1", "test2", "Test2", "test3"}},
}
for i, scenario := range scenarios {
result := list.NonzeroUniques(scenario.items)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
result := list.NonzeroUniques(s.items)
if len(result) != len(scenario.expected) {
t.Errorf("(%d) Result list length doesn't match with the expected list", i)
}
for j, v := range result {
if v != scenario.expected[j] {
t.Errorf("(%d:%d) Result list item should match with the expected list item", i, j)
if len(result) != len(s.expected) {
t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
}
}
for j, v := range result {
if v != s.expected[j] {
t.Fatalf("Result list item doesn't match with the expected list item, got %v VS %v", v, s.expected[j])
}
}
})
}
}
@@ -254,20 +257,54 @@ func TestToUniqueStringSlice(t *testing.T) {
{[]any{0, 1, "test", ""}, []string{"0", "1", "test"}},
{[]string{"test1", "test2", "test1"}, []string{"test1", "test2"}},
{`["test1", "test2", "test2"]`, []string{"test1", "test2"}},
{types.JsonArray[string]{"test1", "test2", "test1"}, []string{"test1", "test2"}},
{types.JSONArray[string]{"test1", "test2", "test1"}, []string{"test1", "test2"}},
}
for i, scenario := range scenarios {
result := list.ToUniqueStringSlice(scenario.value)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.value), func(t *testing.T) {
result := list.ToUniqueStringSlice(s.value)
if len(result) != len(scenario.expected) {
t.Errorf("(%d) Result list length doesn't match with the expected list", i)
}
for j, v := range result {
if v != scenario.expected[j] {
t.Errorf("(%d:%d) Result list item should match with the expected list item", i, j)
if len(result) != len(s.expected) {
t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
}
}
for j, v := range result {
if v != s.expected[j] {
t.Fatalf("Result list item doesn't match with the expected list item, got %v vs %v", v, s.expected[j])
}
}
})
}
}
func TestToChunks(t *testing.T) {
scenarios := []struct {
items []any
chunkSize int
expected string
}{
{nil, 2, "[]"},
{[]any{}, 2, "[]"},
{[]any{1, 2, 3, 4}, -1, "[[1],[2],[3],[4]]"},
{[]any{1, 2, 3, 4}, 0, "[[1],[2],[3],[4]]"},
{[]any{1, 2, 3, 4}, 2, "[[1,2],[3,4]]"},
{[]any{1, 2, 3, 4, 5}, 2, "[[1,2],[3,4],[5]]"},
{[]any{1, 2, 3, 4, 5}, 10, "[[1,2,3,4,5]]"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
result := list.ToChunks(s.items, s.chunkSize)
raw, err := json.Marshal(result)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, rawStr)
}
})
}
}
+20 -7
View File
@@ -2,9 +2,11 @@ package logger
import (
"context"
"errors"
"log/slog"
"sync"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/types"
)
@@ -160,7 +162,6 @@ func (h *BatchHandler) Handle(ctx context.Context, r slog.Record) error {
if err := h.resolveAttr(data, a); err != nil {
return false
}
return true
})
@@ -168,7 +169,7 @@ func (h *BatchHandler) Handle(ctx context.Context, r slog.Record) error {
Time: r.Time,
Level: r.Level,
Message: r.Message,
Data: types.JsonMap(data),
Data: types.JSONMap[any](data),
}
if h.options.BeforeAddFunc != nil && !h.options.BeforeAddFunc(ctx, log) {
@@ -251,11 +252,23 @@ func (h *BatchHandler) resolveAttr(data map[string]any, attr slog.Attr) error {
data[attr.Key] = groupData
}
default:
v := attr.Value.Any()
if err, ok := v.(error); ok {
data[attr.Key] = err.Error()
} else {
switch v := attr.Value.Any().(type) {
case validation.Errors:
data[attr.Key] = map[string]any{
"data": v,
"raw": v.Error(),
}
case error:
var ve validation.Errors
if errors.As(v, &ve) {
data[attr.Key] = map[string]any{
"data": ve,
"raw": v.Error(),
}
} else {
data[attr.Key] = v.Error()
}
default:
data[attr.Key] = v
}
}
+2 -5
View File
@@ -181,11 +181,8 @@ func TestBatchHandlerHandle(t *testing.T) {
BeforeAddFunc: func(_ context.Context, log *Log) bool {
beforeLogs = append(beforeLogs, log)
if log.Message == "test2" {
return false // skip test2 log
}
return true
// skip test2 log
return log.Message != "test2"
},
WriteFunc: func(_ context.Context, logs []*Log) error {
writeLogs = logs
+1 -1
View File
@@ -11,7 +11,7 @@ import (
// preformatted JSON map.
type Log struct {
Time time.Time
Data types.JSONMap[any]
Message string
Level slog.Level
Data types.JsonMap
}
+1 -1
View File
@@ -4,7 +4,7 @@ import (
"testing"
)
func TestHtml2Text(t *testing.T) {
func TestHTML2Text(t *testing.T) {
scenarios := []struct {
html string
expected string
+12
View File
@@ -3,6 +3,8 @@ package mailer
import (
"io"
"net/mail"
"github.com/pocketbase/pocketbase/tools/hook"
)
// Message defines a generic email message struct.
@@ -24,6 +26,16 @@ type Mailer interface {
Send(message *Message) error
}
// SendInterceptor is optional interface for registering mail send hooks.
type SendInterceptor interface {
OnSend() *hook.Hook[*SendEvent]
}
type SendEvent struct {
hook.Event
Message *Message
}
// addressesToStrings converts the provided address to a list of serialized RFC 5322 strings.
//
// To export only the email part of mail.Address, you can set withName to false.
+23 -2
View File
@@ -7,6 +7,8 @@ import (
"net/http"
"os/exec"
"strings"
"github.com/pocketbase/pocketbase/tools/hook"
)
var _ Mailer = (*Sendmail)(nil)
@@ -16,10 +18,29 @@ var _ Mailer = (*Sendmail)(nil)
//
// This client is usually recommended only for development and testing.
type Sendmail struct {
onSend *hook.Hook[*SendEvent]
}
// Send implements `mailer.Mailer` interface.
// OnSend implements [mailer.SendInterceptor] interface.
func (c *Sendmail) OnSend() *hook.Hook[*SendEvent] {
if c.onSend == nil {
c.onSend = &hook.Hook[*SendEvent]{}
}
return c.onSend
}
// Send implements [mailer.Mailer] interface.
func (c *Sendmail) Send(m *Message) error {
if c.onSend != nil {
return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
return c.send(e.Message)
})
}
return c.send(m)
}
func (c *Sendmail) send(m *Message) error {
toAddresses := addressesToStrings(m.To, false)
headers := make(http.Header)
@@ -74,5 +95,5 @@ func findSendmailPath() (string, error) {
}
}
return "", errors.New("failed to locate a sendmail executable path")
return "", errors.New("Failed to locate a sendmail executable path.")
}
+32 -30
View File
@@ -7,43 +7,27 @@ import (
"strings"
"github.com/domodwyer/mailyak/v3"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/security"
)
var _ Mailer = (*SmtpClient)(nil)
var _ Mailer = (*SMTPClient)(nil)
const (
SmtpAuthPlain = "PLAIN"
SmtpAuthLogin = "LOGIN"
SMTPAuthPlain = "PLAIN"
SMTPAuthLogin = "LOGIN"
)
// Deprecated: Use directly the SmtpClient struct literal.
//
// NewSmtpClient creates new SmtpClient with the provided configuration.
func NewSmtpClient(
host string,
port int,
username string,
password string,
tls bool,
) *SmtpClient {
return &SmtpClient{
Host: host,
Port: port,
Username: username,
Password: password,
Tls: tls,
}
}
// SmtpClient defines a SMTP mail client structure that implements
// SMTPClient defines a SMTP mail client structure that implements
// `mailer.Mailer` interface.
type SmtpClient struct {
Host string
type SMTPClient struct {
onSend *hook.Hook[*SendEvent]
TLS bool
Port int
Host string
Username string
Password string
Tls bool
// SMTP auth method to use
// (if not explicitly set, defaults to "PLAIN")
@@ -56,12 +40,30 @@ type SmtpClient struct {
LocalName string
}
// Send implements `mailer.Mailer` interface.
func (c *SmtpClient) Send(m *Message) error {
// OnSend implements [mailer.SendInterceptor] interface.
func (c *SMTPClient) OnSend() *hook.Hook[*SendEvent] {
if c.onSend == nil {
c.onSend = &hook.Hook[*SendEvent]{}
}
return c.onSend
}
// Send implements [mailer.Mailer] interface.
func (c *SMTPClient) Send(m *Message) error {
if c.onSend != nil {
return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
return c.send(e.Message)
})
}
return c.send(m)
}
func (c *SMTPClient) send(m *Message) error {
var smtpAuth smtp.Auth
if c.Username != "" || c.Password != "" {
switch c.AuthMethod {
case SmtpAuthLogin:
case SMTPAuthLogin:
smtpAuth = &smtpLoginAuth{c.Username, c.Password}
default:
smtpAuth = smtp.PlainAuth("", c.Username, c.Password, c.Host)
@@ -70,7 +72,7 @@ func (c *SmtpClient) Send(m *Message) error {
// create mail instance
var yak *mailyak.MailYak
if c.Tls {
if c.TLS {
var tlsErr error
yak, tlsErr = mailyak.NewWithTLS(fmt.Sprintf("%s:%d", c.Host, c.Port), smtpAuth, nil)
if tlsErr != nil {
+16 -14
View File
@@ -56,24 +56,26 @@ func TestLoginAuthStart(t *testing.T) {
}
for _, s := range scenarios {
method, resp, err := auth.Start(s.serverInfo)
t.Run(s.name, func(t *testing.T) {
method, resp, err := auth.Start(s.serverInfo)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("[%s] Expected hasErr %v, got %v", s.name, s.expectError, hasErr)
}
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
if hasErr {
continue
}
if hasErr {
return
}
if len(resp) != 0 {
t.Fatalf("[%s] Expected empty data response, got %v", s.name, resp)
}
if len(resp) != 0 {
t.Fatalf("Expected empty data response, got %v", resp)
}
if method != "LOGIN" {
t.Fatalf("[%s] Expected LOGIN, got %v", s.name, method)
}
if method != "LOGIN" {
t.Fatalf("Expected LOGIN, got %v", method)
}
})
}
}
-59
View File
@@ -1,59 +0,0 @@
package migrate
import (
"path/filepath"
"runtime"
"sort"
"github.com/pocketbase/dbx"
)
type Migration struct {
File string
Up func(db dbx.Builder) error
Down func(db dbx.Builder) error
}
// MigrationsList defines a list with migration definitions
type MigrationsList struct {
list []*Migration
}
// Item returns a single migration from the list by its index.
func (l *MigrationsList) Item(index int) *Migration {
return l.list[index]
}
// Items returns the internal migrations list slice.
func (l *MigrationsList) Items() []*Migration {
return l.list
}
// Register adds new migration definition to the list.
//
// If `optFilename` is not provided, it will try to get the name from its .go file.
//
// The list will be sorted automatically based on the migrations file name.
func (l *MigrationsList) Register(
up func(db dbx.Builder) error,
down func(db dbx.Builder) error,
optFilename ...string,
) {
var file string
if len(optFilename) > 0 {
file = optFilename[0]
} else {
_, path, _, _ := runtime.Caller(1)
file = filepath.Base(path)
}
l.list = append(l.list, &Migration{
File: file,
Up: up,
Down: down,
})
sort.Slice(l.list, func(i int, j int) bool {
return l.list[i].File < l.list[j].File
})
}
-33
View File
@@ -1,33 +0,0 @@
package migrate
import (
"testing"
)
func TestMigrationsList(t *testing.T) {
l := MigrationsList{}
l.Register(nil, nil, "3_test.go")
l.Register(nil, nil, "1_test.go")
l.Register(nil, nil, "2_test.go")
l.Register(nil, nil /* auto detect file name */)
expected := []string{
"1_test.go",
"2_test.go",
"3_test.go",
"list_test.go",
}
items := l.Items()
if len(items) != len(expected) {
t.Fatalf("Expected %d items, got %d: \n%#v", len(expected), len(items), items)
}
for i, name := range expected {
item := l.Item(i)
if item.File != name {
t.Fatalf("Expected name %s for index %d, got %s", name, i, item.File)
}
}
}
-275
View File
@@ -1,275 +0,0 @@
package migrate
import (
"fmt"
"strings"
"time"
"github.com/AlecAivazis/survey/v2"
"github.com/fatih/color"
"github.com/pocketbase/dbx"
"github.com/spf13/cast"
)
const DefaultMigrationsTable = "_migrations"
// Runner defines a simple struct for managing the execution of db migrations.
type Runner struct {
db *dbx.DB
migrationsList MigrationsList
tableName string
}
// NewRunner creates and initializes a new db migrations Runner instance.
func NewRunner(db *dbx.DB, migrationsList MigrationsList) (*Runner, error) {
runner := &Runner{
db: db,
migrationsList: migrationsList,
tableName: DefaultMigrationsTable,
}
if err := runner.createMigrationsTable(); err != nil {
return nil, err
}
return runner, nil
}
// Run interactively executes the current runner with the provided args.
//
// The following commands are supported:
// - up - applies all migrations
// - down [n] - reverts the last n applied migrations
func (r *Runner) Run(args ...string) error {
cmd := "up"
if len(args) > 0 {
cmd = args[0]
}
switch cmd {
case "up":
applied, err := r.Up()
if err != nil {
return err
}
if len(applied) == 0 {
color.Green("No new migrations to apply.")
} else {
for _, file := range applied {
color.Green("Applied %s", file)
}
}
return nil
case "down":
toRevertCount := 1
if len(args) > 1 {
toRevertCount = cast.ToInt(args[1])
if toRevertCount < 0 {
// revert all applied migrations
toRevertCount = len(r.migrationsList.Items())
}
}
names, err := r.lastAppliedMigrations(toRevertCount)
if err != nil {
return err
}
confirm := false
prompt := &survey.Confirm{
Message: fmt.Sprintf(
"\n%v\nDo you really want to revert the last %d applied migration(s)?",
strings.Join(names, "\n"),
toRevertCount,
),
}
survey.AskOne(prompt, &confirm)
if !confirm {
fmt.Println("The command has been cancelled")
return nil
}
reverted, err := r.Down(toRevertCount)
if err != nil {
return err
}
if len(reverted) == 0 {
color.Green("No migrations to revert.")
} else {
for _, file := range reverted {
color.Green("Reverted %s", file)
}
}
return nil
case "history-sync":
if err := r.removeMissingAppliedMigrations(); err != nil {
return err
}
color.Green("The %s table was synced with the available migrations.", r.tableName)
return nil
default:
return fmt.Errorf("Unsupported command: %q\n", cmd)
}
}
// Up executes all unapplied migrations for the provided runner.
//
// On success returns list with the applied migrations file names.
func (r *Runner) Up() ([]string, error) {
applied := []string{}
err := r.db.Transactional(func(tx *dbx.Tx) error {
for _, m := range r.migrationsList.Items() {
// skip applied
if r.isMigrationApplied(tx, m.File) {
continue
}
// ignore empty Up action
if m.Up != nil {
if err := m.Up(tx); err != nil {
return fmt.Errorf("Failed to apply migration %s: %w", m.File, err)
}
}
if err := r.saveAppliedMigration(tx, m.File); err != nil {
return fmt.Errorf("Failed to save applied migration info for %s: %w", m.File, err)
}
applied = append(applied, m.File)
}
return nil
})
if err != nil {
return nil, err
}
return applied, nil
}
// Down reverts the last `toRevertCount` applied migrations
// (in the order they were applied).
//
// On success returns list with the reverted migrations file names.
func (r *Runner) Down(toRevertCount int) ([]string, error) {
reverted := make([]string, 0, toRevertCount)
names, appliedErr := r.lastAppliedMigrations(toRevertCount)
if appliedErr != nil {
return nil, appliedErr
}
err := r.db.Transactional(func(tx *dbx.Tx) error {
for _, name := range names {
for _, m := range r.migrationsList.Items() {
if m.File != name {
continue
}
// revert limit reached
if toRevertCount-len(reverted) <= 0 {
return nil
}
// ignore empty Down action
if m.Down != nil {
if err := m.Down(tx); err != nil {
return fmt.Errorf("Failed to revert migration %s: %w", m.File, err)
}
}
if err := r.saveRevertedMigration(tx, m.File); err != nil {
return fmt.Errorf("Failed to save reverted migration info for %s: %w", m.File, err)
}
reverted = append(reverted, m.File)
}
}
return nil
})
if err != nil {
return nil, err
}
return reverted, nil
}
func (r *Runner) createMigrationsTable() error {
rawQuery := fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %v (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
r.db.QuoteTableName(r.tableName),
)
_, err := r.db.NewQuery(rawQuery).Execute()
return err
}
func (r *Runner) isMigrationApplied(tx dbx.Builder, file string) bool {
var exists bool
err := tx.Select("count(*)").
From(r.tableName).
Where(dbx.HashExp{"file": file}).
Limit(1).
Row(&exists)
return err == nil && exists
}
func (r *Runner) saveAppliedMigration(tx dbx.Builder, file string) error {
_, err := tx.Insert(r.tableName, dbx.Params{
"file": file,
"applied": time.Now().UnixMicro(),
}).Execute()
return err
}
func (r *Runner) saveRevertedMigration(tx dbx.Builder, file string) error {
_, err := tx.Delete(r.tableName, dbx.HashExp{"file": file}).Execute()
return err
}
func (r *Runner) lastAppliedMigrations(limit int) ([]string, error) {
var files = make([]string, 0, limit)
err := r.db.Select("file").
From(r.tableName).
Where(dbx.Not(dbx.HashExp{"applied": nil})).
// unify microseconds and seconds applied time for backward compatibility
OrderBy("substr(applied||'0000000000000000', 0, 17) DESC").
AndOrderBy("file DESC").
Limit(int64(limit)).
Column(&files)
if err != nil {
return nil, err
}
return files, nil
}
func (r *Runner) removeMissingAppliedMigrations() error {
loadedMigrations := r.migrationsList.Items()
names := make([]any, len(loadedMigrations))
for i, migration := range loadedMigrations {
names[i] = migration.File
}
_, err := r.db.Delete(r.tableName, dbx.Not(dbx.HashExp{
"file": names,
})).Execute()
return err
}
-216
View File
@@ -1,216 +0,0 @@
package migrate
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/list"
_ "modernc.org/sqlite"
)
func TestNewRunner(t *testing.T) {
testDB, err := createTestDB()
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
l := MigrationsList{}
l.Register(nil, nil, "1_test.go")
l.Register(nil, nil, "2_test.go")
l.Register(nil, nil, "3_test.go")
r, err := NewRunner(testDB.DB, l)
if err != nil {
t.Fatal(err)
}
if len(r.migrationsList.Items()) != len(l.Items()) {
t.Fatalf("Expected the same migrations list to be assigned, got \n%#v", r.migrationsList)
}
expectedQueries := []string{
"CREATE TABLE IF NOT EXISTS `_migrations` (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
}
if len(expectedQueries) != len(testDB.CalledQueries) {
t.Fatalf("Expected %d queries, got %d: \n%v", len(expectedQueries), len(testDB.CalledQueries), testDB.CalledQueries)
}
for _, q := range expectedQueries {
if !list.ExistInSlice(q, testDB.CalledQueries) {
t.Fatalf("Query %s was not found in \n%v", q, testDB.CalledQueries)
}
}
}
func TestRunnerUpAndDown(t *testing.T) {
testDB, err := createTestDB()
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
callsOrder := []string{}
l := MigrationsList{}
l.Register(func(db dbx.Builder) error {
callsOrder = append(callsOrder, "up2")
return nil
}, func(db dbx.Builder) error {
callsOrder = append(callsOrder, "down2")
return nil
}, "2_test")
l.Register(func(db dbx.Builder) error {
callsOrder = append(callsOrder, "up3")
return nil
}, func(db dbx.Builder) error {
callsOrder = append(callsOrder, "down3")
return nil
}, "3_test")
l.Register(func(db dbx.Builder) error {
callsOrder = append(callsOrder, "up1")
return nil
}, func(db dbx.Builder) error {
callsOrder = append(callsOrder, "down1")
return nil
}, "1_test")
r, err := NewRunner(testDB.DB, l)
if err != nil {
t.Fatal(err)
}
// simulate partially out-of-order run migration
r.saveAppliedMigration(testDB, "2_test")
// ---------------------------------------------------------------
// Up()
// ---------------------------------------------------------------
if _, err := r.Up(); err != nil {
t.Fatal(err)
}
expectedUpCallsOrder := `["up1","up3"]` // skip up2 since it was applied previously
upCallsOrder, err := json.Marshal(callsOrder)
if err != nil {
t.Fatal(err)
}
if v := string(upCallsOrder); v != expectedUpCallsOrder {
t.Fatalf("Expected Up() calls order %s, got %s", expectedUpCallsOrder, upCallsOrder)
}
// ---------------------------------------------------------------
// reset callsOrder
callsOrder = []string{}
// simulate unrun migration
r.migrationsList.Register(nil, func(db dbx.Builder) error {
callsOrder = append(callsOrder, "down4")
return nil
}, "4_test")
// ---------------------------------------------------------------
// ---------------------------------------------------------------
// Down()
// ---------------------------------------------------------------
if _, err := r.Down(2); err != nil {
t.Fatal(err)
}
expectedDownCallsOrder := `["down3","down1"]` // revert in the applied order
downCallsOrder, err := json.Marshal(callsOrder)
if err != nil {
t.Fatal(err)
}
if v := string(downCallsOrder); v != expectedDownCallsOrder {
t.Fatalf("Expected Down() calls order %s, got %s", expectedDownCallsOrder, downCallsOrder)
}
}
func TestHistorySync(t *testing.T) {
testDB, err := createTestDB()
if err != nil {
t.Fatal(err)
}
defer testDB.Close()
// mock migrations history
l := MigrationsList{}
l.Register(func(db dbx.Builder) error {
return nil
}, func(db dbx.Builder) error {
return nil
}, "1_test")
l.Register(func(db dbx.Builder) error {
return nil
}, func(db dbx.Builder) error {
return nil
}, "2_test")
l.Register(func(db dbx.Builder) error {
return nil
}, func(db dbx.Builder) error {
return nil
}, "3_test")
r, err := NewRunner(testDB.DB, l)
if err != nil {
t.Fatalf("Failed to initialize the runner: %v", err)
}
if _, err := r.Up(); err != nil {
t.Fatalf("Failed to apply the mock migrations: %v", err)
}
if !r.isMigrationApplied(testDB.DB, "2_test") {
t.Fatalf("Expected 2_test migration to be applied")
}
// mock deleted migrations
r.migrationsList.list = []*Migration{r.migrationsList.list[0], r.migrationsList.list[2]}
if err := r.removeMissingAppliedMigrations(); err != nil {
t.Fatalf("Failed to remove missing applied migrations: %v", err)
}
if r.isMigrationApplied(testDB.DB, "2_test") {
t.Fatalf("Expected 2_test migration to NOT be applied")
}
}
// -------------------------------------------------------------------
// Helpers
// -------------------------------------------------------------------
type testDB struct {
*dbx.DB
CalledQueries []string
}
// NB! Don't forget to call `db.Close()` at the end of the test.
func createTestDB() (*testDB, error) {
sqlDB, err := sql.Open("sqlite", ":memory:")
if err != nil {
return nil, err
}
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
db.CalledQueries = append(db.CalledQueries, sql)
}
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
db.CalledQueries = append(db.CalledQueries, sql)
}
return &db, nil
}
@@ -1,4 +1,4 @@
package rest
package picker
import (
"errors"
@@ -10,6 +10,12 @@ import (
"golang.org/x/net/html"
)
func init() {
Modifiers["excerpt"] = func(args ...string) (Modifier, error) {
return newExcerptModifier(args...)
}
}
var whitespaceRegex = regexp.MustCompile(`\s+`)
var excludeTags = []string{
@@ -24,7 +30,7 @@ var inlineTags = []string{
"strong", "strike", "sub", "sup", "time",
}
var _ FieldModifier = (*excerptModifier)(nil)
var _ Modifier = (*excerptModifier)(nil)
type excerptModifier struct {
max int // approximate max excerpt length
@@ -59,7 +65,7 @@ func newExcerptModifier(args ...string) (*excerptModifier, error) {
return &excerptModifier{max, withEllipsis}, nil
}
// Modify implements the [FieldModifier.Modify] interface method.
// Modify implements the [Modifier.Modify] interface method.
//
// It returns a plain text excerpt/short-description from a formatted
// html string (non-string values are kept untouched).
@@ -1,4 +1,4 @@
package rest
package picker
import (
"fmt"
+41
View File
@@ -0,0 +1,41 @@
package picker
import (
"fmt"
"github.com/pocketbase/pocketbase/tools/tokenizer"
)
var Modifiers = map[string]ModifierFactoryFunc{}
type ModifierFactoryFunc func(args ...string) (Modifier, error)
type Modifier interface {
// Modify executes the modifier and returns a new modified value.
Modify(value any) (any, error)
}
func initModifer(rawModifier string) (Modifier, error) {
t := tokenizer.NewFromString(rawModifier)
t.Separators('(', ')', ',', ' ')
t.IgnoreParenthesis(true)
parts, err := t.ScanAll()
if err != nil {
return nil, err
}
if len(parts) == 0 {
return nil, fmt.Errorf("invalid or empty modifier expression %q", rawModifier)
}
name := parts[0]
args := parts[1:]
factory, ok := Modifiers[name]
if !ok {
return nil, fmt.Errorf("missing or invalid modifier %q", name)
}
return factory(args...)
}
@@ -1,76 +1,25 @@
package rest
package picker
import (
"encoding/json"
"fmt"
"strings"
// Experimental!
//
// Need more tests before replacing encoding/json entirely.
// Test also encoding/json/v2 once released (see https://github.com/golang/go/discussions/63397)
goccy "github.com/goccy/go-json"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/tokenizer"
)
type FieldModifier interface {
// Modify executes the modifier and returns a new modified value.
Modify(value any) (any, error)
}
// Serializer represents custom REST JSON serializer based on echo.DefaultJSONSerializer,
// with support for additional generic response data transformation (eg. fields picker).
type Serializer struct {
echo.DefaultJSONSerializer
FieldsParam string
}
// Serialize converts an interface into a json and writes it to the response.
// Pick converts data into a []any, map[string]any, etc. (using json marshal->unmarshal)
// containing only the fields from the parsed rawFields expression.
//
// It also provides a generic response data fields picker via the FieldsParam query parameter (default to "fields").
//
// Note: for the places where it is safe, the std encoding/json is replaced
// with goccy due to its slightly better Unmarshal/Marshal performance.
func (s *Serializer) Serialize(c echo.Context, i any, indent string) error {
fieldsParam := s.FieldsParam
if fieldsParam == "" {
fieldsParam = "fields"
}
statusCode := c.Response().Status
rawFields := c.QueryParam(fieldsParam)
if rawFields == "" || statusCode < 200 || statusCode > 299 {
return s.DefaultJSONSerializer.Serialize(c, i, indent)
}
decoded, err := PickFields(i, rawFields)
if err != nil {
return err
}
enc := goccy.NewEncoder(c.Response())
if indent != "" {
enc.SetIndent("", indent)
}
return enc.Encode(decoded)
}
// PickFields parses the provided fields string expression and
// returns a new subset of data with only the requested fields.
//
// Fields transformations with modifiers are also supported (see initModifer()).
// rawFields is a comma separated string of the fields to include.
// Nested fields should be listed with dot-notation.
// Fields value modifiers are also supported using the `:modifier(args)` format (see Modifiers).
//
// Example:
//
// data := map[string]any{"a": 1, "b": 2, "c": map[string]any{"c1": 11, "c2": 22}}
// PickFields(data, "a,c.c1") // map[string]any{"a": 1, "c": map[string]any{"c1": 11}}
func PickFields(data any, rawFields string) (any, error) {
// Pick(data, "a,c.c1") // map[string]any{"a": 1, "c": map[string]any{"c1": 11}}
func Pick(data any, rawFields string) (any, error) {
parsedFields, err := parseFields(rawFields)
if err != nil {
return nil, err
@@ -82,18 +31,18 @@ func PickFields(data any, rawFields string) (any, error) {
//
// @todo research other approaches to avoid the double serialization
// ---
encoded, err := json.Marshal(data) // use the std json since goccy has several bugs reported with struct marshaling and it is not safe
encoded, err := json.Marshal(data)
if err != nil {
return nil, err
}
var decoded any
if err := goccy.Unmarshal(encoded, &decoded); err != nil {
if err := json.Unmarshal(encoded, &decoded); err != nil {
return nil, err
}
// ---
// special cases to preserve the same fields format when used with single item or array data.
// special cases to preserve the same fields format when used with single item or search results data.
var isSearchResult bool
switch data.(type) {
case search.Result, *search.Result:
@@ -111,7 +60,7 @@ func PickFields(data any, rawFields string) (any, error) {
return decoded, nil
}
func parseFields(rawFields string) (map[string]FieldModifier, error) {
func parseFields(rawFields string) (map[string]Modifier, error) {
t := tokenizer.NewFromString(rawFields)
fields, err := t.ScanAll()
@@ -119,7 +68,7 @@ func parseFields(rawFields string) (map[string]FieldModifier, error) {
return nil, err
}
result := make(map[string]FieldModifier, len(fields))
result := make(map[string]Modifier, len(fields))
for _, f := range fields {
parts := strings.SplitN(strings.TrimSpace(f), ":", 2)
@@ -138,36 +87,7 @@ func parseFields(rawFields string) (map[string]FieldModifier, error) {
return result, nil
}
func initModifer(rawModifier string) (FieldModifier, error) {
t := tokenizer.NewFromString(rawModifier)
t.Separators('(', ')', ',', ' ')
t.IgnoreParenthesis(true)
parts, err := t.ScanAll()
if err != nil {
return nil, err
}
if len(parts) == 0 {
return nil, fmt.Errorf("invalid or empty modifier expression %q", rawModifier)
}
name := parts[0]
args := parts[1:]
switch name {
case "excerpt":
m, err := newExcerptModifier(args...)
if err != nil {
return nil, fmt.Errorf("invalid excerpt modifier: %w", err)
}
return m, nil
}
return nil, fmt.Errorf("missing or invalid modifier %q", name)
}
func pickParsedFields(data any, fields map[string]FieldModifier) error {
func pickParsedFields(data any, fields map[string]Modifier) error {
switch v := data.(type) {
case map[string]any:
pickMapFields(v, fields)
@@ -196,7 +116,7 @@ func pickParsedFields(data any, fields map[string]FieldModifier) error {
return nil
}
func pickMapFields(data map[string]any, fields map[string]FieldModifier) error {
func pickMapFields(data map[string]any, fields map[string]Modifier) error {
if len(fields) == 0 {
return nil // nothing to pick
}
@@ -221,7 +141,7 @@ func pickMapFields(data map[string]any, fields map[string]FieldModifier) error {
DataLoop:
for k := range data {
matchingFields := make(map[string]FieldModifier, len(fields))
matchingFields := make(map[string]Modifier, len(fields))
for f, m := range fields {
if strings.HasPrefix(f+".", k+".") {
matchingFields[f] = m
@@ -1,135 +1,13 @@
package rest_test
package picker_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/picker"
"github.com/pocketbase/pocketbase/tools/search"
)
func TestSerialize(t *testing.T) {
scenarios := []struct {
name string
serializer rest.Serializer
statusCode int
data any
query string
expected string
}{
{
"empty query",
rest.Serializer{},
200,
map[string]any{"a": 1, "b": 2, "c": "test"},
"",
`{"a":1,"b":2,"c":"test"}`,
},
{
"empty fields",
rest.Serializer{},
200,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=",
`{"a":1,"b":2,"c":"test"}`,
},
{
"missing fields",
rest.Serializer{},
200,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing",
`{}`,
},
{
">299 response",
rest.Serializer{},
300,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing",
`{"a":1,"b":2,"c":"test"}`,
},
{
"<200 response",
rest.Serializer{},
199,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing",
`{"a":1,"b":2,"c":"test"}`,
},
{
"non map response",
rest.Serializer{},
200,
"test",
"fields=a,b,test",
`"test"`,
},
{
"non slice of map response",
rest.Serializer{},
200,
[]any{"a", "b", "test"},
"fields=a,test",
`["a","b","test"]`,
},
{
"map with no matching field",
rest.Serializer{},
200,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing", // test individual fields trim
`{}`,
},
{
"map with existing and missing fields",
rest.Serializer{},
200,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=a, c ,missing", // test individual fields trim
`{"a":1,"c":"test"}`,
},
{
"custom fields param",
rest.Serializer{FieldsParam: "custom"},
200,
map[string]any{"a": 1, "b": 2, "c": "test"},
"custom=a, c ,missing", // test individual fields trim
`{"a":1,"c":"test"}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.URL.RawQuery = s.query
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
c.Response().Status = s.statusCode
if err := s.serializer.Serialize(c, s.data, ""); err != nil {
t.Fatalf("Serialize failure: %v", err)
}
rawBody, err := io.ReadAll(rec.Result().Body)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
}
if v := strings.TrimSpace(string(rawBody)); v != s.expected {
t.Fatalf("Expected body\n%v \ngot \n%v", s.expected, v)
}
})
}
}
func TestPickFields(t *testing.T) {
scenarios := []struct {
name string
@@ -374,7 +252,7 @@ func TestPickFields(t *testing.T) {
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result, err := rest.PickFields(s.data, s.fields)
result, err := picker.Pick(s.data, s.fields)
hasErr := err != nil
if hasErr != s.expectError {
-170
View File
@@ -1,170 +0,0 @@
package rest
import (
"bytes"
"encoding/json"
"io"
"net/http"
"reflect"
"strings"
"github.com/labstack/echo/v5"
"github.com/spf13/cast"
)
// MultipartJsonKey is the key for the special multipart/form-data
// handling allowing reading serialized json payload without normalization.
const MultipartJsonKey string = "@jsonPayload"
// MultiBinder is similar to [echo.DefaultBinder] but uses slightly different
// application/json and multipart/form-data bind methods to accommodate better
// the PocketBase router needs.
type MultiBinder struct{}
// Bind implements the [Binder.Bind] method.
//
// Bind is almost identical to [echo.DefaultBinder.Bind] but uses the
// [rest.BindBody] function for binding the request body.
func (b *MultiBinder) Bind(c echo.Context, i interface{}) (err error) {
if err := echo.BindPathParams(c, i); err != nil {
return err
}
// Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body.
// For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues.
method := c.Request().Method
if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
if err = echo.BindQueryParams(c, i); err != nil {
return err
}
}
return BindBody(c, i)
}
// BindBody binds request body content to i.
//
// This is similar to `echo.BindBody()`, but for JSON requests uses
// custom json reader that **copies** the request body, allowing multiple reads.
func BindBody(c echo.Context, i any) error {
req := c.Request()
if req.ContentLength == 0 {
return nil
}
ctype := req.Header.Get(echo.HeaderContentType)
switch {
case strings.HasPrefix(ctype, echo.MIMEApplicationJSON):
err := CopyJsonBody(c.Request(), i)
if err != nil {
return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
case strings.HasPrefix(ctype, echo.MIMEApplicationForm), strings.HasPrefix(ctype, echo.MIMEMultipartForm):
return bindFormData(c, i)
}
// fallback to the default binder
return echo.BindBody(c, i)
}
// CopyJsonBody reads the request body into i by
// creating a copy of `r.Body` to allow multiple reads.
func CopyJsonBody(r *http.Request, i any) error {
body := r.Body
// this usually shouldn't be needed because the Server calls close
// for us but we are changing the request body with a new reader
defer body.Close()
limitReader := io.LimitReader(body, DefaultMaxMemory)
bodyBytes, readErr := io.ReadAll(limitReader)
if readErr != nil {
return readErr
}
err := json.NewDecoder(bytes.NewReader(bodyBytes)).Decode(i)
// set new body reader
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return err
}
// Custom multipart/form-data binder that implements an additional handling like
// loading a serialized json payload or properly scan array values when a map destination is used.
func bindFormData(c echo.Context, i any) error {
if i == nil {
return nil
}
values, err := c.FormValues()
if err != nil {
return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
if len(values) == 0 {
return nil
}
// special case to allow submitting json without normalization
// alongside the other multipart/form-data values
jsonPayloadValues := values[MultipartJsonKey]
for _, payload := range jsonPayloadValues {
json.Unmarshal([]byte(payload), i)
}
rt := reflect.TypeOf(i).Elem()
// map
if rt.Kind() == reflect.Map {
rv := reflect.ValueOf(i).Elem()
for k, v := range values {
if k == MultipartJsonKey {
continue
}
if total := len(v); total == 1 {
rv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalizeMultipartValue(v[0])))
} else {
normalized := make([]any, total)
for i, vItem := range v {
normalized[i] = normalizeMultipartValue(vItem)
}
rv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalized))
}
}
return nil
}
// anything else
return echo.BindBody(c, i)
}
// In order to support more seamlessly both json and multipart/form-data requests,
// the following normalization rules are applied for plain multipart string values:
// - "true" is converted to the json `true`
// - "false" is converted to the json `false`
// - numeric (non-scientific) strings are converted to json number
// - any other string (empty string too) is left as it is
func normalizeMultipartValue(raw string) any {
switch raw {
case "":
return raw
case "true":
return true
case "false":
return false
default:
if raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9') {
if v, err := cast.ToFloat64E(raw); err == nil {
return v
}
}
return raw
}
}
-149
View File
@@ -1,149 +0,0 @@
package rest_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/tools/rest"
)
func TestMultiBinderBind(t *testing.T) {
binder := rest.MultiBinder{}
req := httptest.NewRequest(http.MethodGet, "/test?query=123", strings.NewReader(`{"body":"456"}`))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
e := echo.New()
e.Any("/:name", func(c echo.Context) error {
// bind twice to ensure that the json body reader copy is invoked
for i := 0; i < 2; i++ {
data := struct {
Name string `param:"name"`
Query string `query:"query"`
Body string `form:"body"`
}{}
if err := binder.Bind(c, &data); err != nil {
t.Fatal(err)
}
if data.Name != "test" {
t.Fatalf("Expected Name %q, got %q", "test", data.Name)
}
if data.Query != "123" {
t.Fatalf("Expected Query %q, got %q", "123", data.Query)
}
if data.Body != "456" {
t.Fatalf("Expected Body %q, got %q", "456", data.Body)
}
}
return nil
})
e.ServeHTTP(rec, req)
}
func TestBindBody(t *testing.T) {
scenarios := []struct {
body io.Reader
contentType string
expectBody string
expectError bool
}{
{
strings.NewReader(""),
echo.MIMEApplicationJSON,
`{}`,
false,
},
{
strings.NewReader(`{"test":"invalid`),
echo.MIMEApplicationJSON,
`{}`,
true,
},
{
strings.NewReader(`{"test":123}`),
echo.MIMEApplicationJSON,
`{"test":123}`,
false,
},
{
strings.NewReader(
url.Values{
"string": []string{"str"},
"stings": []string{"str1", "str2", ""},
"number": []string{"-123"},
"numbers": []string{"123", "456.789"},
"bool": []string{"true"},
"bools": []string{"true", "false"},
rest.MultipartJsonKey: []string{`invalid`, `{"a":123}`, `{"b":456}`},
}.Encode(),
),
echo.MIMEApplicationForm,
`{"a":123,"b":456,"bool":true,"bools":[true,false],"number":-123,"numbers":[123,456.789],"stings":["str1","str2",""],"string":"str"}`,
false,
},
}
for i, scenario := range scenarios {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", scenario.body)
req.Header.Set(echo.HeaderContentType, scenario.contentType)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
data := map[string]any{}
err := rest.BindBody(c, &data)
hasErr := err != nil
if hasErr != scenario.expectError {
t.Errorf("[%d] Expected hasErr %v, got %v", i, scenario.expectError, hasErr)
}
rawBody, err := json.Marshal(data)
if err != nil {
t.Errorf("[%d] Failed to marshal binded body: %v", i, err)
}
if scenario.expectBody != string(rawBody) {
t.Errorf("[%d] Expected body \n%s, \ngot \n%s", i, scenario.expectBody, rawBody)
}
}
}
func TestCopyJsonBody(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"test":"test123"}`))
// simulate multiple reads from the same request
result1 := map[string]string{}
rest.CopyJsonBody(req, &result1)
result2 := map[string]string{}
rest.CopyJsonBody(req, &result2)
if len(result1) == 0 {
t.Error("Expected result1 to be filled")
}
if len(result2) == 0 {
t.Error("Expected result2 to be filled")
}
if v, ok := result1["test"]; !ok || v != "test123" {
t.Errorf("Expected result1.test to be %q, got %q", "test123", v)
}
if v, ok := result2["test"]; !ok || v != "test123" {
t.Errorf("Expected result2.test to be %q, got %q", "test123", v)
}
}
-39
View File
@@ -1,39 +0,0 @@
package rest
import (
"net/http"
"github.com/pocketbase/pocketbase/tools/filesystem"
)
// DefaultMaxMemory defines the default max memory bytes that
// will be used when parsing a form request body.
const DefaultMaxMemory = 32 << 20 // 32mb
// FindUploadedFiles extracts all form files of "key" from a http request
// and returns a slice with filesystem.File instances (if any).
func FindUploadedFiles(r *http.Request, key string) ([]*filesystem.File, error) {
if r.MultipartForm == nil {
err := r.ParseMultipartForm(DefaultMaxMemory)
if err != nil {
return nil, err
}
}
if r.MultipartForm == nil || r.MultipartForm.File == nil || len(r.MultipartForm.File[key]) == 0 {
return nil, http.ErrMissingFile
}
result := make([]*filesystem.File, 0, len(r.MultipartForm.File[key]))
for _, fh := range r.MultipartForm.File[key] {
file, err := filesystem.NewFileFromMultipart(fh)
if err != nil {
return nil, err
}
result = append(result, file)
}
return result, nil
}
-80
View File
@@ -1,80 +0,0 @@
package rest_test
import (
"bytes"
"mime/multipart"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/rest"
)
func TestFindUploadedFiles(t *testing.T) {
scenarios := []struct {
filename string
expectedPattern string
}{
{"ab.png", `^ab\w{10}_\w{10}\.png$`},
{"test", `^test_\w{10}\.txt$`},
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
}
for i, s := range scenarios {
// create multipart form file body
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
w, err := mp.CreateFormFile("test", s.filename)
if err != nil {
t.Fatal(err)
}
w.Write([]byte("test"))
mp.Close()
// ---
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
result, err := rest.FindUploadedFiles(req, "test")
if err != nil {
t.Fatal(err)
}
if len(result) != 1 {
t.Errorf("[%d] Expected 1 file, got %d", i, len(result))
}
if result[0].Size != 4 {
t.Errorf("[%d] Expected the file size to be 4 bytes, got %d", i, result[0].Size)
}
pattern, err := regexp.Compile(s.expectedPattern)
if err != nil {
t.Errorf("[%d] Invalid filename pattern %q: %v", i, s.expectedPattern, err)
}
if !pattern.MatchString(result[0].Name) {
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name)
}
}
}
func TestFindUploadedFilesMissing(t *testing.T) {
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
mp.Close()
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
result, err := rest.FindUploadedFiles(req, "test")
if err == nil {
t.Error("Expected error, got nil")
}
if result != nil {
t.Errorf("Expected result to be nil, got %v", result)
}
}
-29
View File
@@ -1,29 +0,0 @@
package rest
import (
"net/url"
"path"
"strings"
)
// NormalizeUrl removes duplicated slashes from a url path.
func NormalizeUrl(originalUrl string) (string, error) {
u, err := url.Parse(originalUrl)
if err != nil {
return "", err
}
hasSlash := strings.HasSuffix(u.Path, "/")
// clean up path by removing duplicated /
u.Path = path.Clean(u.Path)
u.RawPath = path.Clean(u.RawPath)
// restore original trailing slash
if hasSlash && !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
u.RawPath += "/"
}
return u.String(), nil
}
-40
View File
@@ -1,40 +0,0 @@
package rest_test
import (
"testing"
"github.com/pocketbase/pocketbase/tools/rest"
)
func TestNormalizeUrl(t *testing.T) {
scenarios := []struct {
url string
expectError bool
expectUrl string
}{
{":/", true, ""},
{"./", false, "./"},
{"../../test////", false, "../../test/"},
{"/a/b/c", false, "/a/b/c"},
{"a/////b//c/", false, "a/b/c/"},
{"/a/////b//c", false, "/a/b/c"},
{"///a/b/c", false, "/a/b/c"},
{"//a/b/c", false, "//a/b/c"}, // preserve "auto-schema"
{"http://a/b/c", false, "http://a/b/c"},
{"a//bc?test=1//dd", false, "a/bc?test=1//dd"}, // only the path is normalized
{"a//bc?test=1#12///3", false, "a/bc?test=1#12///3"}, // only the path is normalized
}
for i, s := range scenarios {
result, err := rest.NormalizeUrl(s.url)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("(%d) Expected hasErr %v, got %v", i, s.expectError, hasErr)
}
if result != s.expectUrl {
t.Errorf("(%d) Expected url %q, got %q", i, s.expectUrl, result)
}
}
}
+231
View File
@@ -0,0 +1,231 @@
package router
import (
"database/sql"
"errors"
"io/fs"
"net/http"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/inflector"
)
// SafeErrorItem defines a common error interface for a printable public safe error.
type SafeErrorItem interface {
// Code represents a fixed unique identifier of the error (usually used as translation key).
Code() string
// Error is the default English human readable error message that will be returned.
Error() string
}
// SafeErrorParamsResolver defines an optional interface for specifying dynamic error parameters.
type SafeErrorParamsResolver interface {
// Params defines a map with dynamic parameters to return as part of the public safe error view.
Params() map[string]any
}
// SafeErrorResolver defines an error interface for resolving the public safe error fields.
type SafeErrorResolver interface {
// Resolve allows modifying and returning a new public safe error data map.
Resolve(errData map[string]any) any
}
// ApiError defines the struct for a basic api error response.
type ApiError struct {
rawData any
Data map[string]any `json:"data"`
Message string `json:"message"`
Status int `json:"status"`
}
// Error makes it compatible with the `error` interface.
func (e *ApiError) Error() string {
return e.Message
}
// RawData returns the unformatted error data (could be an internal error, text, etc.)
func (e *ApiError) RawData() any {
return e.rawData
}
// Is reports whether the current ApiError wraps the target.
func (e *ApiError) Is(target error) bool {
err, ok := e.rawData.(error)
if ok {
return errors.Is(err, target)
}
apiErr, ok := target.(*ApiError)
return ok && e == apiErr
}
// NewNotFoundError creates and returns 404 ApiError.
func NewNotFoundError(message string, rawErrData any) *ApiError {
if message == "" {
message = "The requested resource wasn't found."
}
return NewApiError(http.StatusNotFound, message, rawErrData)
}
// NewBadRequestError creates and returns 400 ApiError.
func NewBadRequestError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Something went wrong while processing your request."
}
return NewApiError(http.StatusBadRequest, message, rawErrData)
}
// NewForbiddenError creates and returns 403 ApiError.
func NewForbiddenError(message string, rawErrData any) *ApiError {
if message == "" {
message = "You are not allowed to perform this request."
}
return NewApiError(http.StatusForbidden, message, rawErrData)
}
// NewUnauthorizedError creates and returns 401 ApiError.
func NewUnauthorizedError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Missing or invalid authentication."
}
return NewApiError(http.StatusUnauthorized, message, rawErrData)
}
// NewInternalServerError creates and returns 500 ApiError.
func NewInternalServerError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Something went wrong while processing your request."
}
return NewApiError(http.StatusInternalServerError, message, rawErrData)
}
func NewTooManyRequestsError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Too Many Requests."
}
return NewApiError(http.StatusTooManyRequests, message, rawErrData)
}
// NewApiError creates and returns new normalized ApiError instance.
func NewApiError(status int, message string, rawErrData any) *ApiError {
if message == "" {
message = http.StatusText(status)
}
return &ApiError{
rawData: rawErrData,
Data: safeErrorsData(rawErrData),
Status: status,
Message: strings.TrimSpace(inflector.Sentenize(message)),
}
}
// ToApiError wraps err into ApiError instance (if not already).
func ToApiError(err error) *ApiError {
var apiErr *ApiError
if !errors.As(err, &apiErr) {
// no ApiError found -> assign a generic one
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, fs.ErrNotExist) {
apiErr = NewNotFoundError("", err)
} else {
apiErr = NewBadRequestError("", err)
}
}
return apiErr
}
// -------------------------------------------------------------------
func safeErrorsData(data any) map[string]any {
switch v := data.(type) {
case validation.Errors:
return resolveSafeErrorsData(v)
case error:
validationErrors := validation.Errors{}
if errors.As(v, &validationErrors) {
return resolveSafeErrorsData(validationErrors)
}
return map[string]any{} // not nil to ensure that is json serialized as object
case map[string]validation.Error:
return resolveSafeErrorsData(v)
case map[string]SafeErrorItem:
return resolveSafeErrorsData(v)
case map[string]error:
return resolveSafeErrorsData(v)
case map[string]string:
return resolveSafeErrorsData(v)
case map[string]any:
return resolveSafeErrorsData(v)
default:
return map[string]any{} // not nil to ensure that is json serialized as object
}
}
func resolveSafeErrorsData[T any](data map[string]T) map[string]any {
result := map[string]any{}
for name, err := range data {
if isNestedError(err) {
result[name] = safeErrorsData(err)
} else {
result[name] = resolveSafeErrorItem(err)
}
}
return result
}
func isNestedError(err any) bool {
switch err.(type) {
case validation.Errors,
map[string]validation.Error,
map[string]SafeErrorItem,
map[string]error,
map[string]string,
map[string]any:
return true
}
return false
}
// resolveSafeErrorItem extracts from each validation error its
// public safe error code and message.
func resolveSafeErrorItem(err any) any {
data := map[string]any{}
if obj, ok := err.(SafeErrorItem); ok {
// extract the specific error code and message
data["code"] = obj.Code()
data["message"] = inflector.Sentenize(obj.Error())
} else {
// fallback to the default public safe values
data["code"] = "validation_invalid_value"
data["message"] = "Invalid value."
}
if s, ok := err.(SafeErrorParamsResolver); ok {
params := s.Params()
if len(params) > 0 {
data["params"] = params
}
}
if s, ok := err.(SafeErrorResolver); ok {
return s.Resolve(data)
}
return data
}
+358
View File
@@ -0,0 +1,358 @@
package router_test
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/fs"
"strconv"
"testing"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestNewApiErrorWithRawData(t *testing.T) {
t.Parallel()
e := router.NewApiError(
300,
"message_test",
"rawData_test",
)
result, _ := json.Marshal(e)
expected := `{"data":{},"message":"Message_test.","status":300}`
if string(result) != expected {
t.Errorf("Expected\n%v\ngot\n%v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() != "rawData_test" {
t.Errorf("Expected rawData\n%v\ngot\n%v", "rawData_test", e.RawData())
}
}
func TestNewApiErrorWithValidationData(t *testing.T) {
t.Parallel()
e := router.NewApiError(
300,
"message_test",
map[string]any{
"err1": errors.New("test error"), // should be normalized
"err2": validation.ErrRequired,
"err3": validation.Errors{
"err3.1": errors.New("test error"), // should be normalized
"err3.2": validation.ErrRequired,
"err3.3": validation.Errors{
"err3.3.1": validation.ErrRequired,
},
},
"err4": &mockSafeErrorItem{},
"err5": map[string]error{
"err5.1": validation.ErrRequired,
},
},
)
result, _ := json.Marshal(e)
expected := `{"data":{"err1":{"code":"validation_invalid_value","message":"Invalid value."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"err3.1":{"code":"validation_invalid_value","message":"Invalid value."},"err3.2":{"code":"validation_required","message":"Cannot be blank."},"err3.3":{"err3.3.1":{"code":"validation_required","message":"Cannot be blank."}}},"err4":{"code":"mock_code","message":"Mock_error.","mock_resolve":123},"err5":{"err5.1":{"code":"validation_required","message":"Cannot be blank."}}},"message":"Message_test.","status":300}`
if string(result) != expected {
t.Errorf("Expected \n%v, \ngot \n%v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() == nil {
t.Error("Expected non-nil rawData")
}
}
func TestNewNotFoundError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"The requested resource wasn't found.","status":404}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":404}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":404}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewNotFoundError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewBadRequestError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":400}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":400}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":400}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewBadRequestError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewForbiddenError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"You are not allowed to perform this request.","status":403}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":403}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":403}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewForbiddenError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewUnauthorizedError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Missing or invalid authentication.","status":401}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":401}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":401}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewUnauthorizedError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewInternalServerError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":500}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":500}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":500}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewInternalServerError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewTooManyRequestsError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Too Many Requests.","status":429}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":429}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message").SetParams(map[string]any{"test": 123})}, `{"data":{"err1":{"code":"test_code","message":"Test_message.","params":{"test":123}}},"message":"Demo.","status":429}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewTooManyRequestsError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestApiErrorIs(t *testing.T) {
t.Parallel()
err0 := router.NewInternalServerError("", nil)
err1 := router.NewInternalServerError("", nil)
err2 := errors.New("test")
err3 := fmt.Errorf("wrapped: %w", err0)
scenarios := []struct {
name string
err error
target error
expected bool
}{
{
"nil error",
err0,
nil,
false,
},
{
"non ApiError",
err0,
err1,
false,
},
{
"different ApiError",
err0,
err2,
false,
},
{
"same ApiError",
err0,
err0,
true,
},
{
"wrapped ApiError",
err3,
err0,
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
is := errors.Is(s.err, s.target)
if is != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, is)
}
})
}
}
func TestToApiError(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
err error
expected string
}{
{
"regular error",
errors.New("test"),
`{"data":{},"message":"Something went wrong while processing your request.","status":400}`,
},
{
"fs.ErrNotExist",
fs.ErrNotExist,
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
},
{
"sql.ErrNoRows",
sql.ErrNoRows,
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
},
{
"ApiError",
router.NewForbiddenError("test", nil),
`{"data":{},"message":"Test.","status":403}`,
},
{
"wrapped ApiError",
fmt.Errorf("wrapped: %w", router.NewForbiddenError("test", nil)),
`{"data":{},"message":"Test.","status":403}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
raw, err := json.Marshal(router.ToApiError(s.err))
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected error\n%vgot\n%v", s.expected, rawStr)
}
})
}
}
// -------------------------------------------------------------------
var (
_ router.SafeErrorItem = (*mockSafeErrorItem)(nil)
_ router.SafeErrorResolver = (*mockSafeErrorItem)(nil)
)
type mockSafeErrorItem struct {
}
func (m *mockSafeErrorItem) Code() string {
return "mock_code"
}
func (m *mockSafeErrorItem) Error() string {
return "mock_error"
}
func (m *mockSafeErrorItem) Resolve(errData map[string]any) any {
errData["mock_resolve"] = 123
return errData
}
+369
View File
@@ -0,0 +1,369 @@
package router
import (
"encoding/json"
"encoding/xml"
"errors"
"io"
"io/fs"
"net"
"net/http"
"net/netip"
"path/filepath"
"strings"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/picker"
"github.com/pocketbase/pocketbase/tools/store"
)
var ErrUnsupportedContentType = NewBadRequestError("Unsupported Content-Type", nil)
var ErrInvalidRedirectStatusCode = NewInternalServerError("Invalid redirect status code", nil)
var ErrFileNotFound = NewNotFoundError("File not found", nil)
const IndexPage = "index.html"
// Event specifies based Route handler event that is usually intended
// to be embedded as part of a custom event struct.
//
// NB! It is expected that the Response and Request fields are always set.
type Event struct {
Response http.ResponseWriter
Request *http.Request
hook.Event
data store.Store[any]
}
// RWUnwrapper specifies that an http.ResponseWriter could be "unwrapped"
// (usually used with [http.ResponseController]).
type RWUnwrapper interface {
Unwrap() http.ResponseWriter
}
// Written reports whether the current response has already been written.
//
// This method always returns false if e.ResponseWritter doesn't implement the WriteTracker interface
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
func (e *Event) Written() bool {
written, _ := getWritten(e.Response)
return written
}
// Status reports the status code of the current response.
//
// This method always returns 0 if e.Response doesn't implement the StatusTracker interface
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
func (e *Event) Status() int {
status, _ := getStatus(e.Response)
return status
}
// Flush flushes buffered data to the current response.
//
// Returns [http.ErrNotSupported] if e.Response doesn't implement the [http.Flusher] interface
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
func (e *Event) Flush() error {
return http.NewResponseController(e.Response).Flush()
}
// IsTLS reports whether the connection on which the request was received is TLS.
func (e *Event) IsTLS() bool {
return e.Request.TLS != nil
}
// SetCookie is an alias for [http.SetCookie].
//
// SetCookie adds a Set-Cookie header to the current response's headers.
// The provided cookie must have a valid Name.
// Invalid cookies may be silently dropped.
func (e *Event) SetCookie(cookie *http.Cookie) {
http.SetCookie(e.Response, cookie)
}
// RemoteIP returns the IP address of the client that sent the request.
//
// IPv6 addresses are returned expanded.
// For example, "2001:db8::1" becomes "2001:0db8:0000:0000:0000:0000:0000:0001".
//
// Note that if you are behind reverse proxy(ies), this method returns
// the IP of the last connecting proxy.
func (e *Event) RemoteIP() string {
ip, _, _ := net.SplitHostPort(e.Request.RemoteAddr)
parsed, _ := netip.ParseAddr(ip)
return parsed.StringExpanded()
}
// UnsafeRealIP returns the "real" client IP from common proxy headers
// OR fallbacks to the RemoteIP if none is found.
//
// NB! The returned IP value could be anything and it shouldn't be trusted if not behind a trusted reverse proxy!
func (e *Event) UnsafeRealIP() string {
if ip := e.Request.Header.Get("CF-Connecting-IP"); ip != "" {
return ip
}
if ip := e.Request.Header.Get("Fly-Client-IP"); ip != "" {
return ip
}
if ip := e.Request.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ipsList := e.Request.Header.Get("X-Forwarded-For"); ipsList != "" {
// extract the first non-empty leftmost-ish ip
ips := strings.Split(ipsList, ",")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if ip != "" {
return ip
}
}
}
return e.RemoteIP()
}
// Store
// -------------------------------------------------------------------
// Get retrieves single value from the current event data store.
func (e *Event) Get(key string) any {
return e.data.Get(key)
}
// GetAll returns a copy of the current event data store.
func (e *Event) GetAll() map[string]any {
return e.data.GetAll()
}
// Set saves single value into the current event data store.
func (e *Event) Set(key string, value any) {
e.data.Set(key, value)
}
// SetAll saves all items from m into the current event data store.
func (e *Event) SetAll(m map[string]any) {
for k, v := range m {
e.Set(k, v)
}
}
// Response writers
// -------------------------------------------------------------------
const headerContentType = "Content-Type"
func (e *Event) setResponseHeaderIfEmpty(key, value string) {
header := e.Response.Header()
if header.Get(key) == "" {
header.Set(key, value)
}
}
// String writes a plain string response.
func (e *Event) String(status int, data string) error {
e.setResponseHeaderIfEmpty(headerContentType, "text/plain; charset=utf-8")
e.Response.WriteHeader(status)
_, err := e.Response.Write([]byte(data))
return err
}
// HTML writes an HTML response.
func (e *Event) HTML(status int, data string) error {
e.setResponseHeaderIfEmpty(headerContentType, "text/html; charset=utf-8")
e.Response.WriteHeader(status)
_, err := e.Response.Write([]byte(data))
return err
}
const jsonFieldsParam = "fields"
// JSON writes a JSON response.
//
// It also provides a generic response data fields picker if the "fields" query parameter is set.
func (e *Event) JSON(status int, data any) error {
e.setResponseHeaderIfEmpty(headerContentType, "application/json")
e.Response.WriteHeader(status)
rawFields := e.Request.URL.Query().Get(jsonFieldsParam)
// error response or no fields to pick
if rawFields == "" || status < 200 || status > 299 {
return json.NewEncoder(e.Response).Encode(data)
}
// pick only the requested fields
modified, err := picker.Pick(data, rawFields)
if err != nil {
return err
}
return json.NewEncoder(e.Response).Encode(modified)
}
// XML writes an XML response.
// It automatically prepends the generic [xml.Header] string to the response.
func (e *Event) XML(status int, data any) error {
e.setResponseHeaderIfEmpty(headerContentType, "application/xml; charset=utf-8")
e.Response.WriteHeader(status)
if _, err := e.Response.Write([]byte(xml.Header)); err != nil {
return err
}
return xml.NewEncoder(e.Response).Encode(data)
}
// Stream streams the specified reader into the response.
func (e *Event) Stream(status int, contentType string, reader io.Reader) error {
e.Response.Header().Set(headerContentType, contentType)
e.Response.WriteHeader(status)
_, err := io.Copy(e.Response, reader)
return err
}
// FileFS serves the specified filename from fsys.
//
// It is similar to [echo.FileFS] for consistency with earlier versions.
func (e *Event) FileFS(fsys fs.FS, filename string) error {
f, err := fsys.Open(filename)
if err != nil {
return ErrFileNotFound
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return err
}
// if it is a directory try to open its index.html file
if fi.IsDir() {
filename = filepath.ToSlash(filepath.Join(filename, IndexPage))
f, err = fsys.Open(filename)
if err != nil {
return ErrFileNotFound
}
defer f.Close()
fi, err = f.Stat()
if err != nil {
return err
}
}
ff, ok := f.(io.ReadSeeker)
if !ok {
return errors.New("[FileFS] file does not implement io.ReadSeeker")
}
http.ServeContent(e.Response, e.Request, fi.Name(), fi.ModTime(), ff)
return nil
}
// NoContent writes a response with no body (ex. 204).
func (e *Event) NoContent(status int) error {
e.Response.WriteHeader(status)
return nil
}
// Redirect writes a redirect response to the specified url.
// The status code must be in between 300 399 range.
func (e *Event) Redirect(status int, url string) error {
if status < 300 || status > 399 {
return ErrInvalidRedirectStatusCode
}
e.Response.Header().Set("Location", url)
e.Response.WriteHeader(status)
return nil
}
// ApiError helpers
// -------------------------------------------------------------------
func (e *Event) Error(status int, message string, errData any) *ApiError {
return NewApiError(status, message, errData)
}
func (e *Event) BadRequestError(message string, errData any) *ApiError {
return NewBadRequestError(message, errData)
}
func (e *Event) NotFoundError(message string, errData any) *ApiError {
return NewNotFoundError(message, errData)
}
func (e *Event) ForbiddenError(message string, errData any) *ApiError {
return NewForbiddenError(message, errData)
}
func (e *Event) UnauthorizedError(message string, errData any) *ApiError {
return NewUnauthorizedError(message, errData)
}
func (e *Event) TooManyRequestsError(message string, errData any) *ApiError {
return NewTooManyRequestsError(message, errData)
}
func (e *Event) InternalServerError(message string, errData any) *ApiError {
return NewInternalServerError(message, errData)
}
// Binders
// -------------------------------------------------------------------
const DefaultMaxMemory = 32 << 20 // 32mb
// Supports the following content-types:
//
// - application/json
// - multipart/form-data
// - application/x-www-form-urlencoded
// - text/xml, application/xml
func (e *Event) BindBody(dst any) error {
if e.Request.ContentLength == 0 {
return nil
}
contentType := e.Request.Header.Get(headerContentType)
if strings.HasPrefix(contentType, "application/json") {
dec := json.NewDecoder(e.Request.Body)
err := dec.Decode(dst)
if err == nil {
// manually call Reread because single call of json.Decoder.Decode()
// doesn't ensure that the entire body is a valid json string
// and it is not guaranteed that it will reach EOF to trigger the reread reset
// (ex. in case of trailing spaces or invalid trailing parts like: `{"test":1},something`)
if body, ok := e.Request.Body.(Rereader); ok {
body.Reread()
}
}
return err
}
if strings.HasPrefix(contentType, "multipart/form-data") {
if err := e.Request.ParseMultipartForm(DefaultMaxMemory); err != nil {
return err
}
return UnmarshalRequestData(e.Request.Form, dst, "", "")
}
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
if err := e.Request.ParseForm(); err != nil {
return err
}
return UnmarshalRequestData(e.Request.Form, dst, "", "")
}
if strings.HasPrefix(contentType, "text/xml") ||
strings.HasPrefix(contentType, "application/xml") {
return xml.NewDecoder(e.Request.Body).Decode(dst)
}
return ErrUnsupportedContentType
}
+924
View File
@@ -0,0 +1,924 @@
package router_test
import (
"bytes"
"crypto/tls"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/router"
)
type unwrapTester struct {
http.ResponseWriter
}
func (ut unwrapTester) Unwrap() http.ResponseWriter {
return ut.ResponseWriter
}
func TestEventWritten(t *testing.T) {
t.Parallel()
res1 := httptest.NewRecorder()
res2 := httptest.NewRecorder()
res2.Write([]byte("test"))
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4.Write([]byte("test"))
scenarios := []struct {
name string
response http.ResponseWriter
expected bool
}{
{
name: "non-written non-WriteTracker",
response: res1,
expected: false,
},
{
name: "written non-WriteTracker",
response: res2,
expected: false,
},
{
name: "non-written WriteTracker",
response: res3,
expected: false,
},
{
name: "written WriteTracker",
response: res4,
expected: true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
event := router.Event{
Response: s.response,
}
result := event.Written()
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestEventStatus(t *testing.T) {
t.Parallel()
res1 := httptest.NewRecorder()
res2 := httptest.NewRecorder()
res2.WriteHeader(123)
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4.WriteHeader(123)
scenarios := []struct {
name string
response http.ResponseWriter
expected int
}{
{
name: "non-written non-StatusTracker",
response: res1,
expected: 0,
},
{
name: "written non-StatusTracker",
response: res2,
expected: 0,
},
{
name: "non-written StatusTracker",
response: res3,
expected: 0,
},
{
name: "written StatusTracker",
response: res4,
expected: 123,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
event := router.Event{
Response: s.response,
}
result := event.Status()
if result != s.expected {
t.Fatalf("Expected %d, got %d", s.expected, result)
}
})
}
}
func TestEventIsTLS(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
event := router.Event{Request: req}
// without TLS
if event.IsTLS() {
t.Fatalf("Expected IsTLS false")
}
// dummy TLS state
req.TLS = new(tls.ConnectionState)
// with TLS
if !event.IsTLS() {
t.Fatalf("Expected IsTLS true")
}
}
func TestEventSetCookie(t *testing.T) {
t.Parallel()
event := router.Event{
Response: httptest.NewRecorder(),
}
cookie := event.Response.Header().Get("set-cookie")
if cookie != "" {
t.Fatalf("Expected empty cookie string, got %q", cookie)
}
event.SetCookie(&http.Cookie{Name: "test", Value: "a"})
expected := "test=a"
cookie = event.Response.Header().Get("set-cookie")
if cookie != expected {
t.Fatalf("Expected cookie %q, got %q", expected, cookie)
}
}
func TestEventRemoteIP(t *testing.T) {
t.Parallel()
scenarios := []struct {
remoteAddr string
expected string
}{
{"", "invalid IP"},
{"1.2.3.4", "invalid IP"},
{"1.2.3.4:8090", "1.2.3.4"},
{"[0000:0000:0000:0000:0000:0000:0000:0002]:80", "0000:0000:0000:0000:0000:0000:0000:0002"},
{"[::2]:80", "0000:0000:0000:0000:0000:0000:0000:0002"}, // should always return the expanded version
}
for _, s := range scenarios {
t.Run(s.remoteAddr, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
req.RemoteAddr = s.remoteAddr
event := router.Event{Request: req}
ip := event.RemoteIP()
if ip != s.expected {
t.Fatalf("Expected IP %q, got %q", s.expected, ip)
}
})
}
}
func TestEventUnsafeRealIP(t *testing.T) {
t.Parallel()
scenarios := []struct {
headers map[string]string
expected string
}{
{nil, "1.2.3.4"},
{
map[string]string{"CF-Connecting-IP": "test"},
"test",
},
{
map[string]string{"Fly-Client-IP": "test"},
"test",
},
{
map[string]string{"X-Real-IP": "test"},
"test",
},
{
map[string]string{"X-Forwarded-For": "test1,test2,test3"},
"test1",
},
}
for i, s := range scenarios {
keys := make([]string, 0, len(s.headers))
for h := range s.headers {
keys = append(keys, h)
}
testName := strings.Join(keys, "_")
if testName == "" {
testName = "no_headers" + strconv.Itoa(i)
}
t.Run(testName, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
req.RemoteAddr = "1.2.3.4:80" // fallback
for k, v := range s.headers {
req.Header.Set(k, v)
}
event := router.Event{Request: req}
ip := event.UnsafeRealIP()
if ip != s.expected {
t.Fatalf("Expected IP %q, got %q", s.expected, ip)
}
})
}
}
func TestEventSetGet(t *testing.T) {
event := router.Event{}
// get before any set (ensures that doesn't panic)
if v := event.Get("test"); v != nil {
t.Fatalf("Expected nil value, got %v", v)
}
event.Set("a", 123)
event.Set("b", 456)
scenarios := []struct {
key string
expected any
}{
{"", nil},
{"missing", nil},
{"a", 123},
{"b", 456},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
result := event.Get(s.key)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestEventSetAllGetAll(t *testing.T) {
data := map[string]any{
"a": 123,
"b": 456,
}
rawData, err := json.Marshal(data)
if err != nil {
t.Fatal(err)
}
event := router.Event{}
event.SetAll(data)
// modify the data to ensure that the map was shallow coppied
data["c"] = 789
result := event.GetAll()
rawResult, err := json.Marshal(result)
if err != nil {
t.Fatal(err)
}
if len(rawResult) == 0 || !bytes.Equal(rawData, rawResult) {
t.Fatalf("Expected\n%v\ngot\n%v", rawData, rawResult)
}
}
func TestEventString(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "no explicit content-type",
status: 123,
headers: nil,
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/plain; charset=utf-8"},
expectedBody: "test",
},
{
name: "with explicit content-type",
status: 123,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "test",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.String(s.status, s.body)
})
}
}
func TestEventHTML(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "no explicit content-type",
status: 123,
headers: nil,
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/html; charset=utf-8"},
expectedBody: "test",
},
{
name: "with explicit content-type",
status: 123,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "test",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.HTML(s.status, s.body)
})
}
}
func TestEventJSON(t *testing.T) {
body := map[string]any{"a": 123, "b": 456, "c": "test"}
expectedPickedBody := `{"a":123,"c":"test"}` + "\n"
expectedFullBody := `{"a":123,"b":456,"c":"test"}` + "\n"
scenarios := []testResponseWriteScenario[any]{
{
name: "no explicit content-type",
status: 200,
headers: nil,
body: body,
expectedStatus: 200,
expectedHeaders: map[string]string{"content-type": "application/json"},
expectedBody: expectedPickedBody,
},
{
name: "with explicit content-type (200)",
status: 200,
headers: map[string]string{"content-type": "application/test"},
body: body,
expectedStatus: 200,
expectedHeaders: map[string]string{"content-type": "application/test"},
expectedBody: expectedPickedBody,
},
{
name: "with explicit content-type (400)", // no fields picker
status: 400,
headers: map[string]string{"content-type": "application/test"},
body: body,
expectedStatus: 400,
expectedHeaders: map[string]string{"content-type": "application/test"},
expectedBody: expectedFullBody,
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
e.Request.URL.RawQuery = "fields=a,c" // ensures that the picker is invoked
return e.JSON(s.status, s.body)
})
}
}
func TestEventXML(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "no explicit content-type",
status: 234,
headers: nil,
body: "test",
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "application/xml; charset=utf-8"},
expectedBody: xml.Header + "<string>test</string>",
},
{
name: "with explicit content-type",
status: 234,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: xml.Header + "<string>test</string>",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.XML(s.status, s.body)
})
}
}
func TestEventStream(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "stream",
status: 234,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "test",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.Stream(s.status, s.headers["content-type"], strings.NewReader(s.body))
})
}
}
func TestEventNoContent(t *testing.T) {
s := testResponseWriteScenario[any]{
name: "no content",
status: 234,
headers: map[string]string{"content-type": "text/test"},
body: nil,
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "",
}
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.NoContent(s.status)
})
}
func TestEventFlush(t *testing.T) {
rec := httptest.NewRecorder()
event := &router.Event{
Response: unwrapTester{&router.ResponseWriter{ResponseWriter: rec}},
}
event.Response.Write([]byte("test"))
event.Flush()
if !rec.Flushed {
t.Fatal("Expected response to be flushed")
}
}
func TestEventRedirect(t *testing.T) {
scenarios := []testResponseWriteScenario[any]{
{
name: "non-30x status",
status: 200,
expectedStatus: 200,
expectedError: router.ErrInvalidRedirectStatusCode,
},
{
name: "30x status",
status: 302,
headers: map[string]string{"location": "test"}, // should be overwritten with the argument
expectedStatus: 302,
expectedHeaders: map[string]string{"location": "example"},
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.Redirect(s.status, "example")
})
}
}
func TestEventFileFS(t *testing.T) {
// stub test files
// ---
dir, err := os.MkdirTemp("", "EventFileFS")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
err = os.WriteFile(filepath.Join(dir, "index.html"), []byte("index"), 0644)
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(filepath.Join(dir, "test.txt"), []byte("test"), 0644)
if err != nil {
t.Fatal(err)
}
// create sub directory with an index.html file inside it
err = os.MkdirAll(filepath.Join(dir, "sub1"), os.ModePerm)
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(filepath.Join(dir, "sub1", "index.html"), []byte("sub1 index"), 0644)
if err != nil {
t.Fatal(err)
}
err = os.MkdirAll(filepath.Join(dir, "sub2"), os.ModePerm)
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(filepath.Join(dir, "sub2", "test.txt"), []byte("sub2 test"), 0644)
if err != nil {
t.Fatal(err)
}
// ---
scenarios := []struct {
name string
path string
expected string
}{
{"missing file", "", ""},
{"root with no explicit file", "", ""},
{"root with explicit file", "test.txt", "test"},
{"sub dir with no explicit file", "sub1", "sub1 index"},
{"sub dir with no explicit file (no index.html)", "sub2", ""},
{"sub dir explicit file", "sub2/test.txt", "sub2 test"},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
event := &router.Event{
Request: req,
Response: rec,
}
err = event.FileFS(os.DirFS(dir), s.path)
hasErr := err != nil
expectErr := s.expected == ""
if hasErr != expectErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", expectErr, hasErr, err)
}
result := rec.Result()
raw, err := io.ReadAll(result.Body)
result.Body.Close()
if err != nil {
t.Fatal(err)
}
if string(raw) != s.expected {
t.Fatalf("Expected body\n%s\ngot\n%s", s.expected, raw)
}
// ensure that the proper file headers are added
// (aka. http.ServeContent is invoked)
length, _ := strconv.Atoi(result.Header.Get("content-length"))
if length != len(s.expected) {
t.Fatalf("Expected Content-Length %d, got %d", len(s.expected), length)
}
})
}
}
func TestEventError(t *testing.T) {
err := new(router.Event).Error(123, "message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":123}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventBadRequestError(t *testing.T) {
err := new(router.Event).BadRequestError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":400}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventNotFoundError(t *testing.T) {
err := new(router.Event).NotFoundError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":404}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventForbiddenError(t *testing.T) {
err := new(router.Event).ForbiddenError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":403}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventUnauthorizedError(t *testing.T) {
err := new(router.Event).UnauthorizedError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":401}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventTooManyRequestsError(t *testing.T) {
err := new(router.Event).TooManyRequestsError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":429}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventInternalServerError(t *testing.T) {
err := new(router.Event).InternalServerError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":500}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventBindBody(t *testing.T) {
type testDstStruct struct {
A int `json:"a" xml:"a" form:"a"`
B int `json:"b" xml:"b" form:"b"`
C string `json:"c" xml:"c" form:"c"`
}
emptyDst := `{"a":0,"b":0,"c":""}`
queryDst := `a=123&b=-456&c=test`
xmlDst := `
<?xml version="1.0" encoding="UTF-8" ?>
<root>
<a>123</a>
<b>-456</b>
<c>test</c>
</root>
`
jsonDst := `{"a":123,"b":-456,"c":"test"}`
// multipart
mpBody := &bytes.Buffer{}
mpWriter := multipart.NewWriter(mpBody)
mpWriter.WriteField("@jsonPayload", `{"a":123}`)
mpWriter.WriteField("b", "-456")
mpWriter.WriteField("c", "test")
if err := mpWriter.Close(); err != nil {
t.Fatal(err)
}
scenarios := []struct {
contentType string
body io.Reader
expectDst string
expectError bool
}{
{
contentType: "",
body: strings.NewReader(jsonDst),
expectDst: emptyDst,
expectError: true,
},
{
contentType: "application/rtf", // unsupported
body: strings.NewReader(jsonDst),
expectDst: emptyDst,
expectError: true,
},
// empty body
{
contentType: "application/json;charset=emptybody",
body: strings.NewReader(""),
expectDst: emptyDst,
},
// json
{
contentType: "application/json",
body: strings.NewReader(jsonDst),
expectDst: jsonDst,
},
{
contentType: "application/json;charset=abc",
body: strings.NewReader(jsonDst),
expectDst: jsonDst,
},
// xml
{
contentType: "text/xml",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
{
contentType: "text/xml;charset=abc",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
{
contentType: "application/xml",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
{
contentType: "application/xml;charset=abc",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
// x-www-form-urlencoded
{
contentType: "application/x-www-form-urlencoded",
body: strings.NewReader(queryDst),
expectDst: jsonDst,
},
{
contentType: "application/x-www-form-urlencoded;charset=abc",
body: strings.NewReader(queryDst),
expectDst: jsonDst,
},
// multipart
{
contentType: mpWriter.FormDataContentType(),
body: mpBody,
expectDst: jsonDst,
},
}
for _, s := range scenarios {
t.Run(s.contentType, func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/", s.body)
if err != nil {
t.Fatal(err)
}
req.Header.Add("content-type", s.contentType)
event := &router.Event{Request: req}
dst := testDstStruct{}
err = event.BindBody(&dst)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
dstRaw, err := json.Marshal(dst)
if err != nil {
t.Fatal(err)
}
if string(dstRaw) != s.expectDst {
t.Fatalf("Expected dst\n%s\ngot\n%s", s.expectDst, dstRaw)
}
})
}
}
// -------------------------------------------------------------------
type testResponseWriteScenario[T any] struct {
name string
status int
headers map[string]string
body T
expectedStatus int
expectedHeaders map[string]string
expectedBody string
expectedError error
}
func testEventResponseWrite[T any](
t *testing.T,
scenario testResponseWriteScenario[T],
writeFunc func(e *router.Event) error,
) {
t.Run(scenario.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
event := &router.Event{
Request: req,
Response: &router.ResponseWriter{ResponseWriter: rec},
}
for k, v := range scenario.headers {
event.Response.Header().Add(k, v)
}
err = writeFunc(event)
if (scenario.expectedError != nil || err != nil) && !errors.Is(err, scenario.expectedError) {
t.Fatalf("Expected error %v, got %v", scenario.expectedError, err)
}
result := rec.Result()
if result.StatusCode != scenario.expectedStatus {
t.Fatalf("Expected status code %d, got %d", scenario.expectedStatus, result.StatusCode)
}
resultBody, err := io.ReadAll(result.Body)
result.Body.Close()
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
resultBody, err = json.Marshal(string(resultBody))
if err != nil {
t.Fatal(err)
}
expectedBody, err := json.Marshal(scenario.expectedBody)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(resultBody, expectedBody) {
t.Fatalf("Expected body\n%s\ngot\n%s", expectedBody, resultBody)
}
for k, ev := range scenario.expectedHeaders {
if v := result.Header.Get(k); v != ev {
t.Fatalf("Expected %q header to be %q, got %q", k, ev, v)
}
}
})
}
+226
View File
@@ -0,0 +1,226 @@
package router
import (
"net/http"
"regexp"
"strings"
"github.com/pocketbase/pocketbase/tools/hook"
)
// (note: the struct is named RouterGroup instead of Group so that it can
// be embedded in the Router without conflicting with the Group method)
// RouterGroup represents a collection of routes and other sub groups
// that share common pattern prefix and middlewares.
type RouterGroup[T hook.Resolver] struct {
excludedMiddlewares map[string]struct{}
children []any // Route or RouterGroup
Prefix string
Middlewares []*hook.Handler[T]
}
// Group creates and register a new child Group into the current one
// with the specified prefix.
//
// The prefix follows the standard Go net/http ServeMux pattern format ("[HOST]/[PATH]")
// and will be concatenated recursively into the final route path, meaning that
// only the root level group could have HOST as part of the prefix.
//
// Returns the newly created group to allow chaining and registering
// sub-routes and group specific middlewares.
func (group *RouterGroup[T]) Group(prefix string) *RouterGroup[T] {
newGroup := &RouterGroup[T]{}
newGroup.Prefix = prefix
group.children = append(group.children, newGroup)
return newGroup
}
// BindFunc registers one or multiple middleware functions to the current group.
//
// The registered middleware functions are "anonymous" and with default priority,
// aka. executes in the order they were registered.
//
// If you need to specify a named middleware (ex. so that it can be removed)
// or middleware with custom exec prirority, use [Group.Bind] method.
func (group *RouterGroup[T]) BindFunc(middlewareFuncs ...hook.HandlerFunc[T]) *RouterGroup[T] {
for _, m := range middlewareFuncs {
group.Middlewares = append(group.Middlewares, &hook.Handler[T]{Func: m})
}
return group
}
// Bind registers one or multiple middleware handlers to the current group.
func (group *RouterGroup[T]) Bind(middlewares ...*hook.Handler[T]) *RouterGroup[T] {
group.Middlewares = append(group.Middlewares, middlewares...)
// unmark the newly added middlewares in case they were previously "excluded"
if group.excludedMiddlewares != nil {
for _, m := range middlewares {
if m.Id != "" {
delete(group.excludedMiddlewares, m.Id)
}
}
}
return group
}
// Unbind removes one or more middlewares with the specified id(s)
// from the current group and its children (if any).
//
// Anonymous middlewares are not removable, aka. this method does nothing
// if the middleware id is an empty string.
func (group *RouterGroup[T]) Unbind(middlewareIds ...string) *RouterGroup[T] {
for _, middlewareId := range middlewareIds {
if middlewareId == "" {
continue
}
// remove from the group middlwares
for i := len(group.Middlewares) - 1; i >= 0; i-- {
if group.Middlewares[i].Id == middlewareId {
group.Middlewares = append(group.Middlewares[:i], group.Middlewares[i+1:]...)
}
}
// remove from the group children
for i := len(group.children) - 1; i >= 0; i-- {
switch v := group.children[i].(type) {
case *RouterGroup[T]:
v.Unbind(middlewareId)
case *Route[T]:
v.Unbind(middlewareId)
}
}
// add to the exclude list
if group.excludedMiddlewares == nil {
group.excludedMiddlewares = map[string]struct{}{}
}
group.excludedMiddlewares[middlewareId] = struct{}{}
}
return group
}
// Route registers a single route into the current group.
//
// Note that the final route path will be the concatenation of all parent groups prefixes + the route path.
// The path follows the standard Go net/http ServeMux format ("[HOST]/[PATH]"),
// meaning that only a top level group route could have HOST as part of the prefix.
//
// Returns the newly created route to allow attaching route-only middlewares.
func (group *RouterGroup[T]) Route(method string, path string, action hook.HandlerFunc[T]) *Route[T] {
route := &Route[T]{
Method: method,
Path: path,
Action: action,
}
group.children = append(group.children, route)
return route
}
// Any is a shorthand for [Group.AddRoute] with "" as route method (aka. matches any method).
func (group *RouterGroup[T]) Any(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route("", path, action)
}
// GET is a shorthand for [Group.AddRoute] with GET as route method.
func (group *RouterGroup[T]) GET(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route(http.MethodGet, path, action)
}
// POST is a shorthand for [Group.AddRoute] with POST as route method.
func (group *RouterGroup[T]) POST(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route(http.MethodPost, path, action)
}
// DELETE is a shorthand for [Group.AddRoute] with DELETE as route method.
func (group *RouterGroup[T]) DELETE(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route(http.MethodDelete, path, action)
}
// PATCH is a shorthand for [Group.AddRoute] with PATCH as route method.
func (group *RouterGroup[T]) PATCH(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route(http.MethodPatch, path, action)
}
// PUT is a shorthand for [Group.AddRoute] with PUT as route method.
func (group *RouterGroup[T]) PUT(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route(http.MethodPut, path, action)
}
// HEAD is a shorthand for [Group.AddRoute] with HEAD as route method.
func (group *RouterGroup[T]) HEAD(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route(http.MethodHead, path, action)
}
// OPTIONS is a shorthand for [Group.AddRoute] with OPTIONS as route method.
func (group *RouterGroup[T]) OPTIONS(path string, action hook.HandlerFunc[T]) *Route[T] {
return group.Route(http.MethodOptions, path, action)
}
// HasRoute checks whether the specified route pattern (method + path)
// is registered in the current group or its children.
//
// This could be useful to conditionally register and checks for routes
// in order prevent panic on duplicated routes.
//
// Note that routes with anonymous and named wildcard placeholder are treated as equal,
// aka. "GET /abc/" is considered the same as "GET /abc/{something...}".
func (group *RouterGroup[T]) HasRoute(method string, path string) bool {
pattern := path
if method != "" {
pattern = strings.ToUpper(method) + " " + pattern
}
return group.hasRoute(pattern, nil)
}
func (group *RouterGroup[T]) hasRoute(pattern string, parents []*RouterGroup[T]) bool {
for _, child := range group.children {
switch v := child.(type) {
case *RouterGroup[T]:
if v.hasRoute(pattern, append(parents, group)) {
return true
}
case *Route[T]:
var result string
if v.Method != "" {
result += v.Method + " "
}
// add parent groups prefixes
for _, p := range parents {
result += p.Prefix
}
// add current group prefix
result += group.Prefix
// add current route path
result += v.Path
if result == pattern || // direct match
// compares without the named wildcard, aka. /abc/{test...} is equal to /abc/
stripWildcard(result) == stripWildcard(pattern) {
return true
}
}
}
return false
}
var wildcardPlaceholderRegex = regexp.MustCompile(`/{.+\.\.\.}$`)
func stripWildcard(pattern string) string {
return wildcardPlaceholderRegex.ReplaceAllString(pattern, "/")
}
+425
View File
@@ -0,0 +1,425 @@
package router
import (
"errors"
"fmt"
"net/http"
"slices"
"testing"
"github.com/pocketbase/pocketbase/tools/hook"
)
func TestRouterGroupGroup(t *testing.T) {
t.Parallel()
g0 := RouterGroup[*Event]{}
g1 := g0.Group("test1")
g2 := g0.Group("test2")
if total := len(g0.children); total != 2 {
t.Fatalf("Expected %d child groups, got %d", 2, total)
}
if g1.Prefix != "test1" {
t.Fatalf("Expected g1 with prefix %q, got %q", "test1", g1.Prefix)
}
if g2.Prefix != "test2" {
t.Fatalf("Expected g2 with prefix %q, got %q", "test2", g2.Prefix)
}
}
func TestRouterGroupBindFunc(t *testing.T) {
t.Parallel()
g := RouterGroup[*Event]{}
calls := ""
// append one function
g.BindFunc(func(e *Event) error {
calls += "a"
return nil
})
// append multiple functions
g.BindFunc(
func(e *Event) error {
calls += "b"
return nil
},
func(e *Event) error {
calls += "c"
return nil
},
)
if total := len(g.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range g.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
}
}
func TestRouterGroupBind(t *testing.T) {
t.Parallel()
g := RouterGroup[*Event]{
// mock excluded middlewares to check whether the entry will be deleted
excludedMiddlewares: map[string]struct{}{"test2": {}},
}
calls := ""
// append one handler
g.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil
},
})
// append multiple handlers
g.Bind(
&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil
},
},
&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil
},
},
)
if total := len(g.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range g.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls %q, got %q", "abc", calls)
}
// ensures that the previously excluded middleware was removed
if len(g.excludedMiddlewares) != 0 {
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", g.excludedMiddlewares)
}
}
func TestRouterGroupUnbind(t *testing.T) {
t.Parallel()
g := RouterGroup[*Event]{}
calls := ""
// anonymous middlewares
g.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil // unused value
},
})
// middlewares with id
g.Bind(&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil // unused value
},
})
g.Bind(&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil // unused value
},
})
g.Bind(&hook.Handler[*Event]{
Id: "test3",
Func: func(e *Event) error {
calls += "d"
return nil // unused value
},
})
// remove
g.Unbind("") // should be no-op
g.Unbind("test1", "test3")
if total := len(g.Middlewares); total != 2 {
t.Fatalf("Expected %d middlewares, got %v", 2, total)
}
for _, h := range g.Middlewares {
if err := h.Func(nil); err != nil {
continue
}
}
if calls != "ac" {
t.Fatalf("Expected calls %q, got %q", "ac", calls)
}
// ensure that the ids were added in the exclude list
excluded := []string{"test1", "test3"}
if len(g.excludedMiddlewares) != len(excluded) {
t.Fatalf("Expected excludes %v, got %v", excluded, g.excludedMiddlewares)
}
for id := range g.excludedMiddlewares {
if !slices.Contains(excluded, id) {
t.Fatalf("Expected %q to be marked as excluded", id)
}
}
}
func TestRouterGroupRoute(t *testing.T) {
t.Parallel()
group := RouterGroup[*Event]{}
sub := group.Group("sub")
var called bool
route := group.Route(http.MethodPost, "/test", func(e *Event) error {
called = true
return nil
})
// ensure that the route was registered only to the main one
// ---
if len(sub.children) != 0 {
t.Fatalf("Expected no sub children, got %d", len(sub.children))
}
if len(group.children) != 2 {
t.Fatalf("Expected %d group children, got %d", 2, len(group.children))
}
// ---
// check the registered route
// ---
if route != group.children[1] {
t.Fatalf("Expected group children %v, got %v", route, group.children[1])
}
if route.Method != http.MethodPost {
t.Fatalf("Expected route method %q, got %q", http.MethodPost, route.Method)
}
if route.Path != "/test" {
t.Fatalf("Expected route path %q, got %q", "/test", route.Path)
}
route.Action(nil)
if !called {
t.Fatal("Expected route action to be called")
}
}
func TestRouterGroupRouteAliases(t *testing.T) {
t.Parallel()
group := RouterGroup[*Event]{}
testErr := errors.New("test")
testAction := func(e *Event) error {
return testErr
}
scenarios := []struct {
route *Route[*Event]
expectMethod string
expectPath string
}{
{
group.Any("/test", testAction),
"",
"/test",
},
{
group.GET("/test", testAction),
http.MethodGet,
"/test",
},
{
group.POST("/test", testAction),
http.MethodPost,
"/test",
},
{
group.DELETE("/test", testAction),
http.MethodDelete,
"/test",
},
{
group.PATCH("/test", testAction),
http.MethodPatch,
"/test",
},
{
group.PUT("/test", testAction),
http.MethodPut,
"/test",
},
{
group.HEAD("/test", testAction),
http.MethodHead,
"/test",
},
{
group.OPTIONS("/test", testAction),
http.MethodOptions,
"/test",
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.expectMethod, s.expectPath), func(t *testing.T) {
if s.route.Method != s.expectMethod {
t.Fatalf("Expected method %q, got %q", s.expectMethod, s.route.Method)
}
if s.route.Path != s.expectPath {
t.Fatalf("Expected path %q, got %q", s.expectPath, s.route.Path)
}
if err := s.route.Action(nil); !errors.Is(err, testErr) {
t.Fatal("Expected test action")
}
})
}
}
func TestRouterGroupHasRoute(t *testing.T) {
t.Parallel()
group := RouterGroup[*Event]{}
group.Any("/any", nil)
group.GET("/base", nil)
group.DELETE("/base", nil)
sub := group.Group("/sub1")
sub.GET("/a", nil)
sub.POST("/a", nil)
sub2 := sub.Group("/sub2")
sub2.GET("/b", nil)
sub2.GET("/b/{test}", nil)
// special cases to test the normalizations
group.GET("/c/", nil) // the same as /c/{test...}
group.GET("/d/{test...}", nil) // the same as /d/
scenarios := []struct {
method string
path string
expected bool
}{
{
http.MethodGet,
"",
false,
},
{
"",
"/any",
true,
},
{
http.MethodPost,
"/base",
false,
},
{
http.MethodGet,
"/base",
true,
},
{
http.MethodDelete,
"/base",
true,
},
{
http.MethodGet,
"/sub1",
false,
},
{
http.MethodGet,
"/sub1/a",
true,
},
{
http.MethodPost,
"/sub1/a",
true,
},
{
http.MethodDelete,
"/sub1/a",
false,
},
{
http.MethodGet,
"/sub2/b",
false,
},
{
http.MethodGet,
"/sub1/sub2/b",
true,
},
{
http.MethodGet,
"/sub1/sub2/b/{test}",
true,
},
{
http.MethodGet,
"/sub1/sub2/b/{test2}",
false,
},
{
http.MethodGet,
"/c/{test...}",
true,
},
{
http.MethodGet,
"/d/",
true,
},
}
for _, s := range scenarios {
t.Run(s.method+"_"+s.path, func(t *testing.T) {
has := group.HasRoute(s.method, s.path)
if has != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, has)
}
})
}
}
+60
View File
@@ -0,0 +1,60 @@
package router
import (
"bytes"
"io"
)
var (
_ io.ReadCloser = (*RereadableReadCloser)(nil)
_ Rereader = (*RereadableReadCloser)(nil)
)
// Rereader defines an interface for rewindable readers.
type Rereader interface {
Reread()
}
// RereadableReadCloser defines a wrapper around a io.ReadCloser reader
// allowing to read the original reader multiple times.
type RereadableReadCloser struct {
io.ReadCloser
copy *bytes.Buffer
active io.Reader
}
// Read implements the standard io.Reader interface.
//
// It reads up to len(b) bytes into b and at at the same time writes
// the read data into an internal bytes buffer.
//
// On EOF the r is "rewinded" to allow reading from r multiple times.
func (r *RereadableReadCloser) Read(b []byte) (int, error) {
if r.active == nil {
if r.copy == nil {
r.copy = &bytes.Buffer{}
}
r.active = io.TeeReader(r.ReadCloser, r.copy)
}
n, err := r.active.Read(b)
if err == io.EOF {
r.Reread()
}
return n, err
}
// Reread satisfies the [Rereader] interface and resets the r internal state to allow rereads.
//
// note: not named Reset to avoid conflicts with other reader interfaces.
func (r *RereadableReadCloser) Reread() {
if r.copy == nil || r.copy.Len() == 0 {
return // nothing to reset or it has been already reset
}
oldCopy := r.copy
r.copy = &bytes.Buffer{}
r.active = io.TeeReader(oldCopy, r.copy)
}
@@ -0,0 +1,28 @@
package router_test
import (
"io"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestRereadableReadCloser(t *testing.T) {
content := "test"
rereadable := &router.RereadableReadCloser{
ReadCloser: io.NopCloser(strings.NewReader(content)),
}
// read multiple times
for i := 0; i < 3; i++ {
result, err := io.ReadAll(rereadable)
if err != nil {
t.Fatalf("[read:%d] %v", i, err)
}
if str := string(result); str != content {
t.Fatalf("[read:%d] Expected %q, got %q", i, content, result)
}
}
}
+73
View File
@@ -0,0 +1,73 @@
package router
import "github.com/pocketbase/pocketbase/tools/hook"
type Route[T hook.Resolver] struct {
excludedMiddlewares map[string]struct{}
Action hook.HandlerFunc[T]
Method string
Path string
Middlewares []*hook.Handler[T]
}
// BindFunc registers one or multiple middleware functions to the current route.
//
// The registered middleware functions are "anonymous" and with default priority,
// aka. executes in the order they were registered.
//
// If you need to specify a named middleware (ex. so that it can be removed)
// or middleware with custom exec prirority, use the [Bind] method.
func (route *Route[T]) BindFunc(middlewareFuncs ...hook.HandlerFunc[T]) *Route[T] {
for _, m := range middlewareFuncs {
route.Middlewares = append(route.Middlewares, &hook.Handler[T]{Func: m})
}
return route
}
// Bind registers one or multiple middleware handlers to the current route.
func (route *Route[T]) Bind(middlewares ...*hook.Handler[T]) *Route[T] {
route.Middlewares = append(route.Middlewares, middlewares...)
// unmark the newly added middlewares in case they were previously "excluded"
if route.excludedMiddlewares != nil {
for _, m := range middlewares {
if m.Id != "" {
delete(route.excludedMiddlewares, m.Id)
}
}
}
return route
}
// Unbind removes one or more middlewares with the specified id(s) from the current route.
//
// It also adds the removed middleware ids to an exclude list so that they could be skipped from
// the execution chain in case the middleware is registered in a parent group.
//
// Anonymous middlewares are considered non-removable, aka. this method
// does nothing if the middleware id is an empty string.
func (route *Route[T]) Unbind(middlewareIds ...string) *Route[T] {
for _, middlewareId := range middlewareIds {
if middlewareId == "" {
continue
}
// remove from the route's middlewares
for i := len(route.Middlewares) - 1; i >= 0; i-- {
if route.Middlewares[i].Id == middlewareId {
route.Middlewares = append(route.Middlewares[:i], route.Middlewares[i+1:]...)
}
}
// add to the exclude list
if route.excludedMiddlewares == nil {
route.excludedMiddlewares = map[string]struct{}{}
}
route.excludedMiddlewares[middlewareId] = struct{}{}
}
return route
}
+168
View File
@@ -0,0 +1,168 @@
package router
import (
"slices"
"testing"
"github.com/pocketbase/pocketbase/tools/hook"
)
func TestRouteBindFunc(t *testing.T) {
t.Parallel()
r := Route[*Event]{}
calls := ""
// append one function
r.BindFunc(func(e *Event) error {
calls += "a"
return nil
})
// append multiple functions
r.BindFunc(
func(e *Event) error {
calls += "b"
return nil
},
func(e *Event) error {
calls += "c"
return nil
},
)
if total := len(r.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range r.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
}
}
func TestRouteBind(t *testing.T) {
t.Parallel()
r := Route[*Event]{
// mock excluded middlewares to check whether the entry will be deleted
excludedMiddlewares: map[string]struct{}{"test2": {}},
}
calls := ""
// append one handler
r.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil
},
})
// append multiple handlers
r.Bind(
&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil
},
},
&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil
},
},
)
if total := len(r.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range r.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls %q, got %q", "abc", calls)
}
// ensures that the previously excluded middleware was removed
if len(r.excludedMiddlewares) != 0 {
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", r.excludedMiddlewares)
}
}
func TestRouteUnbind(t *testing.T) {
t.Parallel()
r := Route[*Event]{}
calls := ""
// anonymous middlewares
r.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil // unused value
},
})
// middlewares with id
r.Bind(&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil // unused value
},
})
r.Bind(&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil // unused value
},
})
r.Bind(&hook.Handler[*Event]{
Id: "test3",
Func: func(e *Event) error {
calls += "d"
return nil // unused value
},
})
// remove
r.Unbind("") // should be no-op
r.Unbind("test1", "test3")
if total := len(r.Middlewares); total != 2 {
t.Fatalf("Expected %d middlewares, got %v", 2, total)
}
for _, h := range r.Middlewares {
if err := h.Func(nil); err != nil {
continue
}
}
if calls != "ac" {
t.Fatalf("Expected calls %q, got %q", "ac", calls)
}
// ensure that the id was added in the exclude list
excluded := []string{"test1", "test3"}
if len(r.excludedMiddlewares) != len(excluded) {
t.Fatalf("Expected excludes %v, got %v", excluded, r.excludedMiddlewares)
}
for id := range r.excludedMiddlewares {
if !slices.Contains(excluded, id) {
t.Fatalf("Expected %q to be marked as excluded", id)
}
}
}
+362
View File
@@ -0,0 +1,362 @@
package router
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"runtime"
"github.com/pocketbase/pocketbase/tools/hook"
)
type EventCleanupFunc func()
// EventFactoryFunc defines the function responsible for creating a Route specific event
// based on the provided request handler ServeHTTP data.
//
// Optionally return a clean up function that will be invoked right after the route execution.
type EventFactoryFunc[T hook.Resolver] func(w http.ResponseWriter, r *http.Request) (T, EventCleanupFunc)
// Router defines a thin wrapper around the standard Go [http.ServeMux] by
// adding support for routing sub-groups, middlewares and other common utils.
//
// Example:
//
// r := NewRouter[*MyEvent](eventFactory)
//
// // middlewares
// r.BindFunc(m1, m2)
//
// // routes
// r.GET("/test", handler1)
//
// // sub-routers/groups
// api := r.Group("/api")
// api.GET("/admins", handler2)
//
// // generate a http.ServeMux instance based on the router configurations
// mux, _ := r.BuildMux()
//
// http.ListenAndServe("localhost:8090", mux)
type Router[T hook.Resolver] struct {
*RouterGroup[T]
eventFactory EventFactoryFunc[T]
}
// NewRouter creates a new Router instance with the provided event factory function.
func NewRouter[T hook.Resolver](eventFactory EventFactoryFunc[T]) *Router[T] {
return &Router[T]{
RouterGroup: &RouterGroup[T]{},
eventFactory: eventFactory,
}
}
// BuildMux constructs a new mux [http.Handler] instance from the current router configurations.
func (r *Router[T]) BuildMux() (http.Handler, error) {
// Note that some of the default std Go handlers like the [http.NotFoundHandler]
// cannot be currently extended and requires defining a custom "catch-all" route
// so that the group middlewares could be executed.
//
// https://github.com/golang/go/issues/65648
if !r.HasRoute("", "/") {
r.Route("", "/", func(e T) error {
return NewNotFoundError("", nil)
})
}
mux := http.NewServeMux()
if err := r.loadMux(mux, r.RouterGroup, nil); err != nil {
return nil, err
}
return mux, nil
}
func (r *Router[T]) loadMux(mux *http.ServeMux, group *RouterGroup[T], parents []*RouterGroup[T]) error {
for _, child := range group.children {
switch v := child.(type) {
case *RouterGroup[T]:
if err := r.loadMux(mux, v, append(parents, group)); err != nil {
return err
}
case *Route[T]:
routeHook := &hook.Hook[T]{}
var pattern string
if v.Method != "" {
pattern = v.Method + " "
}
// add parent groups middlewares
for _, p := range parents {
pattern += p.Prefix
for _, h := range p.Middlewares {
if _, ok := p.excludedMiddlewares[h.Id]; !ok {
if _, ok = group.excludedMiddlewares[h.Id]; !ok {
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
routeHook.Bind(h)
}
}
}
}
}
// add current groups middlewares
pattern += group.Prefix
for _, h := range group.Middlewares {
if _, ok := group.excludedMiddlewares[h.Id]; !ok {
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
routeHook.Bind(h)
}
}
}
// add current route middlewares
pattern += v.Path
for _, h := range v.Middlewares {
if _, ok := v.excludedMiddlewares[h.Id]; !ok {
routeHook.Bind(h)
}
}
// add global panic-recover middleware
routeHook.Bind(&hook.Handler[T]{
Func: r.panicHandler,
Priority: -9999999, // before everything else
})
mux.HandleFunc(pattern, func(resp http.ResponseWriter, req *http.Request) {
// wrap the response to add write and status tracking
resp = &ResponseWriter{ResponseWriter: resp}
// wrap the request body to allow multiple reads
req.Body = &RereadableReadCloser{ReadCloser: req.Body}
event, cleanupFunc := r.eventFactory(resp, req)
// trigger the handler hook chain
err := routeHook.Trigger(event, v.Action)
if err != nil {
ErrorHandler(resp, req, err)
}
if cleanupFunc != nil {
cleanupFunc()
}
})
default:
return errors.New("invalid Group item type")
}
}
return nil
}
// panicHandler registers a default panic-recover handling.
func (r *Router[T]) panicHandler(event T) (err error) {
// panic-recover
defer func() {
recoverResult := recover()
if recoverResult == nil {
return
}
recoverErr, ok := recoverResult.(error)
if !ok {
recoverErr = fmt.Errorf("%v", recoverResult)
} else if errors.Is(recoverErr, http.ErrAbortHandler) {
// don't recover ErrAbortHandler so the response to the client can be aborted
panic(recoverResult)
}
stack := make([]byte, 2<<10) // 2 KB
length := runtime.Stack(stack, true)
err = NewInternalServerError("", fmt.Errorf("[PANIC RECOVER] %w %s", recoverErr, stack[:length]))
}()
err = event.Next()
return err
}
func ErrorHandler(resp http.ResponseWriter, req *http.Request, err error) {
if err == nil {
return
}
if ok, _ := getWritten(resp); ok {
return // a response was already written (aka. already handled)
}
header := resp.Header()
if header.Get("Content-Type") == "" {
header.Set("Content-Type", "application/json")
}
apiErr := ToApiError(err)
resp.WriteHeader(apiErr.Status)
if req.Method != http.MethodHead {
if jsonErr := json.NewEncoder(resp).Encode(apiErr); jsonErr != nil {
log.Println(jsonErr) // truly rare case, log to stderr only for dev purposes
}
}
}
// -------------------------------------------------------------------
type WriteTracker interface {
// Written reports whether a write operation has occurred.
Written() bool
}
type StatusTracker interface {
// Status reports the written response status code.
Status() int
}
type flushErrorer interface {
FlushError() error
}
var (
_ WriteTracker = (*ResponseWriter)(nil)
_ StatusTracker = (*ResponseWriter)(nil)
_ http.Flusher = (*ResponseWriter)(nil)
_ http.Hijacker = (*ResponseWriter)(nil)
_ http.Pusher = (*ResponseWriter)(nil)
_ io.ReaderFrom = (*ResponseWriter)(nil)
_ flushErrorer = (*ResponseWriter)(nil)
)
// ResponseWriter wraps a http.ResponseWriter to track its write state.
type ResponseWriter struct {
http.ResponseWriter
written bool
status int
}
func (rw *ResponseWriter) WriteHeader(status int) {
if rw.written {
return
}
rw.written = true
rw.status = status
rw.ResponseWriter.WriteHeader(status)
}
func (rw *ResponseWriter) Write(b []byte) (int, error) {
if !rw.written {
rw.WriteHeader(http.StatusOK)
}
return rw.ResponseWriter.Write(b)
}
// Written implements [WriteTracker] and returns whether the current response body has been already written.
func (rw *ResponseWriter) Written() bool {
return rw.written
}
// Written implements [StatusTracker] and returns the written status code of the current response.
func (rw *ResponseWriter) Status() int {
return rw.status
}
// Flush implements [http.Flusher] and allows an HTTP handler to flush buffered data to the client.
// This method is no-op if the wrapped writer doesn't support it.
func (rw *ResponseWriter) Flush() {
_ = rw.FlushError()
}
// FlushError is similar to [Flush] but returns [http.ErrNotSupported]
// if the wrapped writer doesn't support it.
func (rw *ResponseWriter) FlushError() error {
err := http.NewResponseController(rw.ResponseWriter).Flush()
if err == nil || !errors.Is(err, http.ErrNotSupported) {
rw.written = true
}
return err
}
// Hijack implements [http.Hijacker] and allows an HTTP handler to take over the current connection.
func (rw *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw.ResponseWriter).Hijack()
}
// Pusher implements [http.Pusher] to indicate HTTP/2 server push support.
func (rw *ResponseWriter) Push(target string, opts *http.PushOptions) error {
w := rw.ResponseWriter
for {
switch p := w.(type) {
case http.Pusher:
return p.Push(target, opts)
case RWUnwrapper:
w = p.Unwrap()
default:
return http.ErrNotSupported
}
}
}
// ReaderFrom implements [io.ReaderFrom] by checking if the underlying writer supports it.
// Otherwise calls [io.Copy].
func (rw *ResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
if !rw.written {
rw.WriteHeader(http.StatusOK)
}
w := rw.ResponseWriter
for {
switch rf := w.(type) {
case io.ReaderFrom:
return rf.ReadFrom(r)
case RWUnwrapper:
w = rf.Unwrap()
default:
return io.Copy(rw.ResponseWriter, r)
}
}
}
// Unwrap returns the underlying ResponseWritter instance (usually used by [http.ResponseController]).
func (rw *ResponseWriter) Unwrap() http.ResponseWriter {
return rw.ResponseWriter
}
func getWritten(rw http.ResponseWriter) (bool, error) {
for {
switch w := rw.(type) {
case WriteTracker:
return w.Written(), nil
case RWUnwrapper:
rw = w.Unwrap()
default:
return false, http.ErrNotSupported
}
}
}
func getStatus(rw http.ResponseWriter) (int, error) {
for {
switch w := rw.(type) {
case StatusTracker:
return w.Status(), nil
case RWUnwrapper:
rw = w.Unwrap()
default:
return 0, http.ErrNotSupported
}
}
}
+258
View File
@@ -0,0 +1,258 @@
package router_test
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestRouter(t *testing.T) {
calls := ""
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
return &router.Event{
Response: w,
Request: r,
},
func() {
calls += ":cleanup"
}
})
r.BindFunc(func(e *router.Event) error {
calls += "root_m:"
err := e.Next()
if err != nil {
calls += "/error"
}
return err
})
r.Any("/any", func(e *router.Event) error {
calls += "/any"
return nil
})
r.GET("/a", func(e *router.Event) error {
calls += "/a"
return nil
})
g1 := r.Group("/a/b").BindFunc(func(e *router.Event) error {
calls += "a_b_group_m:"
return e.Next()
})
g1.GET("/1", func(e *router.Event) error {
calls += "/1_get"
return nil
}).BindFunc(func(e *router.Event) error {
calls += "1_get_m:"
return e.Next()
})
g1.POST("/1", func(e *router.Event) error {
calls += "/1_post"
return nil
})
g1.GET("/{param}", func(e *router.Event) error {
calls += "/" + e.Request.PathValue("param")
return errors.New("test") // should be normalized to an ApiError
})
g1.GET("/panic", func(e *router.Event) error {
calls += "/panic"
panic("test")
})
mux, err := r.BuildMux()
if err != nil {
t.Fatal(err)
}
ts := httptest.NewServer(mux)
defer ts.Close()
client := ts.Client()
scenarios := []struct {
method string
path string
calls string
}{
{http.MethodGet, "/any", "root_m:/any:cleanup"},
{http.MethodOptions, "/any", "root_m:/any:cleanup"},
{http.MethodPatch, "/any", "root_m:/any:cleanup"},
{http.MethodPut, "/any", "root_m:/any:cleanup"},
{http.MethodPost, "/any", "root_m:/any:cleanup"},
{http.MethodDelete, "/any", "root_m:/any:cleanup"},
// ---
{http.MethodPost, "/a", "root_m:/error:cleanup"}, // missing
{http.MethodGet, "/a", "root_m:/a:cleanup"},
{http.MethodHead, "/a", "root_m:/a:cleanup"}, // auto registered with the GET
{http.MethodGet, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
{http.MethodHead, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
{http.MethodPost, "/a/b/1", "root_m:a_b_group_m:/1_post:cleanup"},
{http.MethodGet, "/a/b/456", "root_m:a_b_group_m:/456/error:cleanup"},
{http.MethodGet, "/a/b/panic", "root_m:a_b_group_m:/panic:cleanup"},
}
for _, s := range scenarios {
t.Run(s.method+"_"+s.path, func(t *testing.T) {
calls = "" // reset
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
if err != nil {
t.Fatal(err)
}
_, err = client.Do(req)
if err != nil {
t.Fatal(err)
}
if calls != s.calls {
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
}
})
}
}
func TestRouterUnbind(t *testing.T) {
calls := ""
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
return &router.Event{
Response: w,
Request: r,
},
func() {
calls += ":cleanup"
}
})
r.Bind(&hook.Handler[*router.Event]{
Id: "root_1",
Func: func(e *router.Event) error {
calls += "root_1:"
return e.Next()
},
})
r.Bind(&hook.Handler[*router.Event]{
Id: "root_2",
Func: func(e *router.Event) error {
calls += "root_2:"
return e.Next()
},
})
r.Bind(&hook.Handler[*router.Event]{
Id: "root_3",
Func: func(e *router.Event) error {
calls += "root_3:"
return e.Next()
},
})
r.GET("/action", func(e *router.Event) error {
calls += "root_action"
return nil
}).Unbind("root_1")
ga := r.Group("/group_a")
ga.Unbind("root_1")
ga.Bind(&hook.Handler[*router.Event]{
Id: "group_a_1",
Func: func(e *router.Event) error {
calls += "group_a_1:"
return e.Next()
},
})
ga.Bind(&hook.Handler[*router.Event]{
Id: "group_a_2",
Func: func(e *router.Event) error {
calls += "group_a_2:"
return e.Next()
},
})
ga.Bind(&hook.Handler[*router.Event]{
Id: "group_a_3",
Func: func(e *router.Event) error {
calls += "group_a_3:"
return e.Next()
},
})
ga.GET("/action", func(e *router.Event) error {
calls += "group_a_action"
return nil
}).Unbind("root_2", "group_b_1", "group_a_1")
gb := r.Group("/group_b")
gb.Unbind("root_2")
gb.Bind(&hook.Handler[*router.Event]{
Id: "group_b_1",
Func: func(e *router.Event) error {
calls += "group_b_1:"
return e.Next()
},
})
gb.Bind(&hook.Handler[*router.Event]{
Id: "group_b_2",
Func: func(e *router.Event) error {
calls += "group_b_2:"
return e.Next()
},
})
gb.Bind(&hook.Handler[*router.Event]{
Id: "group_b_3",
Func: func(e *router.Event) error {
calls += "group_b_3:"
return e.Next()
},
})
gb.GET("/action", func(e *router.Event) error {
calls += "group_b_action"
return nil
}).Unbind("group_b_3", "group_a_3", "root_3")
mux, err := r.BuildMux()
if err != nil {
t.Fatal(err)
}
ts := httptest.NewServer(mux)
defer ts.Close()
client := ts.Client()
scenarios := []struct {
method string
path string
calls string
}{
{http.MethodGet, "/action", "root_2:root_3:root_action:cleanup"},
{http.MethodGet, "/group_a/action", "root_3:group_a_2:group_a_3:group_a_action:cleanup"},
{http.MethodGet, "/group_b/action", "root_1:group_b_1:group_b_2:group_b_action:cleanup"},
}
for _, s := range scenarios {
t.Run(s.method+"_"+s.path, func(t *testing.T) {
calls = "" // reset
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
if err != nil {
t.Fatal(err)
}
_, err = client.Do(req)
if err != nil {
t.Fatal(err)
}
if calls != s.calls {
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
}
})
}
}
+330
View File
@@ -0,0 +1,330 @@
package router
import (
"encoding"
"encoding/json"
"errors"
"reflect"
"strconv"
)
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
// JSONPayloadKey is the key for the special UnmarshalRequestData case
// used for reading serialized json payload without normalization.
const JSONPayloadKey string = "@jsonPayload"
// UnmarshalRequestData unmarshals url.Values type of data (query, multipart/form-data, etc.) into dst.
//
// dst must be a pointer to a map[string]any or struct.
//
// If dst is a map[string]any, each data value will be inferred and
// converted to its bool, numeric, or string equivalent value
// (refer to inferValue() for the exact rules).
//
// If dst is a struct, the following field types are supported:
// - bool
// - string
// - int, int8, int16, int32, int64
// - uint, uint8, uint16, uint32, uint64
// - float32, float64
// - serialized json string if submitted under the special "@jsonPayload" key
// - encoding.TextUnmarshaler
// - pointer and slice variations of the above primitives (ex. *string, []string, *[]string []*string, etc.)
// - named/anonymous struct fields
// Dot-notation is used to target nested fields, ex. "nestedStructField.title".
// - embedded struct fields
// The embedded struct fields are treated by default as if they were defined in their parent struct.
// If the embedded struct has a tag matching structTagKey then to set its fields the data keys must be prefixed with that tag
// similar to the regular nested struct fields.
//
// structTagKey and structPrefix are used only when dst is a struct.
//
// structTagKey represents the tag to use to match a data entry with a struct field (defaults to "form").
// If the struct field doesn't have the structTagKey tag, then the exported struct field name will be used as it is.
//
// structPrefix could be provided if all of the data keys are prefixed with a common string
// and you want the struct field to match only the value without the structPrefix
// (ex. for "user.name", "user.email" data keys and structPrefix "user", it will match "name" and "email" struct fields).
//
// Note that while the method was inspired by binders from echo, gorrila/schema, ozzo-routing
// and other similar common routing packages, it is not intended to be a drop-in replacement.
//
// @todo Consider adding support for dot-notation keys, in addition to the prefix, (ex. parent.child.title) to express nested object keys.
func UnmarshalRequestData(data map[string][]string, dst any, structTagKey string, structPrefix string) error {
if len(data) == 0 {
return nil // nothing to unmarshal
}
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() != reflect.Pointer {
return errors.New("dst must be a pointer")
}
dstValue = dereference(dstValue)
dstType := dstValue.Type()
switch dstType.Kind() {
case reflect.Map: // map[string]any
if dstType.Elem().Kind() != reflect.Interface {
return errors.New("dst map value type must be any/interface{}")
}
for k, v := range data {
if k == JSONPayloadKey {
continue // unmarshalled separately
}
total := len(v)
if total == 1 {
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(inferValue(v[0])))
} else {
normalized := make([]any, total)
for i, vItem := range v {
normalized[i] = inferValue(vItem)
}
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalized))
}
}
case reflect.Struct:
// set a default tag key
if structTagKey == "" {
structTagKey = "form"
}
err := unmarshalInStructValue(data, dstValue, structTagKey, structPrefix)
if err != nil {
return err
}
default:
return errors.New("dst must be a map[string]any or struct")
}
// @jsonPayload
//
// Special case to scan serialized json string without
// normalization alongside the other data values
// ---------------------------------------------------------------
jsonPayloadValues := data[JSONPayloadKey]
for _, payload := range jsonPayloadValues {
if err := json.Unmarshal([]byte(payload), dst); err != nil {
return err
}
}
return nil
}
// unmarshalInStructValue unmarshals data into the provided struct reflect.Value fields.
func unmarshalInStructValue(
data map[string][]string,
dstStructValue reflect.Value,
structTagKey string,
structPrefix string,
) error {
dstStructType := dstStructValue.Type()
for i := 0; i < dstStructValue.NumField(); i++ {
fieldType := dstStructType.Field(i)
tag := fieldType.Tag.Get(structTagKey)
if tag == "-" || (!fieldType.Anonymous && !fieldType.IsExported()) {
continue // disabled or unexported non-anonymous struct field
}
fieldValue := dereference(dstStructValue.Field(i))
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isSlice := ft.Kind() == reflect.Slice
if isSlice {
ft = ft.Elem()
}
name := tag
if name == "" && !fieldType.Anonymous {
name = fieldType.Name
}
if name != "" && structPrefix != "" {
name = structPrefix + "." + name
}
// (*)encoding.TextUnmarshaler field
// ---
if ft.Implements(textUnmarshalerType) || reflect.PointerTo(ft).Implements(textUnmarshalerType) {
values, ok := data[name]
if !ok || len(values) == 0 || !fieldValue.CanSet() {
continue // no value to load or the field cannot be set
}
if isSlice {
n := len(values)
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
for i, v := range values {
unmarshaler, ok := dereference(slice.Index(i)).Addr().Interface().(encoding.TextUnmarshaler)
if ok {
if err := unmarshaler.UnmarshalText([]byte(v)); err != nil {
return err
}
}
}
fieldValue.Set(slice)
} else {
unmarshaler, ok := fieldValue.Addr().Interface().(encoding.TextUnmarshaler)
if ok {
if err := unmarshaler.UnmarshalText([]byte(values[0])); err != nil {
return err
}
}
}
continue
}
// "regular" field
// ---
if ft.Kind() != reflect.Struct {
values, ok := data[name]
if !ok || len(values) == 0 || !fieldValue.CanSet() {
continue // no value to load
}
if isSlice {
n := len(values)
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
for i, v := range values {
if err := setRegularReflectedValue(dereference(slice.Index(i)), v); err != nil {
return err
}
}
fieldValue.Set(slice)
} else {
if err := setRegularReflectedValue(fieldValue, values[0]); err != nil {
return err
}
}
continue
}
// structs (embedded or nested)
// ---
// slice of structs
if isSlice {
// populating slice of structs is not supported at the moment
// because the filling rules are ambiguous
continue
}
if tag != "" {
structPrefix = tag
} else {
structPrefix = name // name is empty for anonymous structs -> no prefix
}
if err := unmarshalInStructValue(data, fieldValue, structTagKey, structPrefix); err != nil {
return err
}
}
return nil
}
// dereference returns the underlying value v points to.
func dereference(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
// initialize with a new value and continue searching
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
// setRegularReflectedValue sets and casts value into rv.
func setRegularReflectedValue(rv reflect.Value, value string) error {
switch rv.Kind() {
case reflect.String:
rv.SetString(value)
case reflect.Bool:
if value == "" {
value = "f"
}
v, err := strconv.ParseBool(value)
if err != nil {
return err
}
rv.SetBool(v)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
if value == "" {
value = "0"
}
v, err := strconv.ParseInt(value, 0, 64)
if err != nil {
return err
}
rv.SetInt(v)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
if value == "" {
value = "0"
}
v, err := strconv.ParseUint(value, 0, 64)
if err != nil {
return err
}
rv.SetUint(v)
case reflect.Float32, reflect.Float64:
if value == "" {
value = "0"
}
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
rv.SetFloat(v)
default:
return errors.New("unknown value type " + rv.Kind().String())
}
return nil
}
// In order to support more seamlessly both json and multipart/form-data requests,
// the following normalization rules are applied for plain multipart string values:
// - "true" is converted to the json `true`
// - "false" is converted to the json `false`
// - numeric (non-scientific) strings are converted to json number
// - any other string (empty string too) is left as it is
func inferValue(raw string) any {
switch raw {
case "":
return raw
case "true":
return true
case "false":
return false
default:
// try to convert to number
if raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9') {
v, err := strconv.ParseFloat(raw, 64)
if err == nil {
return v
}
}
return raw
}
}
+450
View File
@@ -0,0 +1,450 @@
package router_test
import (
"bytes"
"encoding/json"
"testing"
"time"
"github.com/pocketbase/pocketbase/tools/router"
)
func pointer[T any](val T) *T {
return &val
}
func TestUnmarshalRequestData(t *testing.T) {
t.Parallel()
mapData := map[string][]string{
"number1": {"1"},
"number2": {"2", "3"},
"number3": {"2.1", "-3.4"},
"string0": {""},
"string1": {"a"},
"string2": {"b", "c"},
"bool1": {"true"},
"bool2": {"true", "false"},
"mixed": {"true", "123", "test"},
"@jsonPayload": {`{"json_a":null,"json_b":123}`, `{"json_c":[1,2,3]}`},
}
structData := map[string][]string{
"stringTag": {"a", "b"},
"StringPtr": {"b"},
"StringSlice": {"a", "b", "c", ""},
"stringSlicePtrTag": {"d", "e"},
"StringSliceOfPtr": {"f", "g"},
"boolTag": {"true"},
"BoolPtr": {"true"},
"BoolSlice": {"true", "false", ""},
"boolSlicePtrTag": {"false", "false", "true"},
"BoolSliceOfPtr": {"false", "true", "false"},
"int8Tag": {"-1", "2"},
"Int8Ptr": {"3"},
"Int8Slice": {"4", "5", ""},
"int8SlicePtrTag": {"5", "6"},
"Int8SliceOfPtr": {"7", "8"},
"int16Tag": {"-1", "2"},
"Int16Ptr": {"3"},
"Int16Slice": {"4", "5", ""},
"int16SlicePtrTag": {"5", "6"},
"Int16SliceOfPtr": {"7", "8"},
"int32Tag": {"-1", "2"},
"Int32Ptr": {"3"},
"Int32Slice": {"4", "5", ""},
"int32SlicePtrTag": {"5", "6"},
"Int32SliceOfPtr": {"7", "8"},
"int64Tag": {"-1", "2"},
"Int64Ptr": {"3"},
"Int64Slice": {"4", "5", ""},
"int64SlicePtrTag": {"5", "6"},
"Int64SliceOfPtr": {"7", "8"},
"intTag": {"-1", "2"},
"IntPtr": {"3"},
"IntSlice": {"4", "5", ""},
"intSlicePtrTag": {"5", "6"},
"IntSliceOfPtr": {"7", "8"},
"uint8Tag": {"1", "2"},
"Uint8Ptr": {"3"},
"Uint8Slice": {"4", "5", ""},
"uint8SlicePtrTag": {"5", "6"},
"Uint8SliceOfPtr": {"7", "8"},
"uint16Tag": {"1", "2"},
"Uint16Ptr": {"3"},
"Uint16Slice": {"4", "5", ""},
"uint16SlicePtrTag": {"5", "6"},
"Uint16SliceOfPtr": {"7", "8"},
"uint32Tag": {"1", "2"},
"Uint32Ptr": {"3"},
"Uint32Slice": {"4", "5", ""},
"uint32SlicePtrTag": {"5", "6"},
"Uint32SliceOfPtr": {"7", "8"},
"uint64Tag": {"1", "2"},
"Uint64Ptr": {"3"},
"Uint64Slice": {"4", "5", ""},
"uint64SlicePtrTag": {"5", "6"},
"Uint64SliceOfPtr": {"7", "8"},
"uintTag": {"1", "2"},
"UintPtr": {"3"},
"UintSlice": {"4", "5", ""},
"uintSlicePtrTag": {"5", "6"},
"UintSliceOfPtr": {"7", "8"},
"float32Tag": {"-1.2"},
"Float32Ptr": {"1.5", "2.0"},
"Float32Slice": {"1", "2.3", "-0.3", ""},
"float32SlicePtrTag": {"-1.3", "3"},
"Float32SliceOfPtr": {"0", "1.2"},
"float64Tag": {"-1.2"},
"Float64Ptr": {"1.5", "2.0"},
"Float64Slice": {"1", "2.3", "-0.3", ""},
"float64SlicePtrTag": {"-1.3", "3"},
"Float64SliceOfPtr": {"0", "1.2"},
"timeTag": {"2009-11-10T15:00:00Z"},
"TimePtr": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
"TimeSlice": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
"timeSlicePtrTag": {"2009-11-10T15:00:00Z", "2009-11-10T16:00:00Z"},
"TimeSliceOfPtr": {"2009-11-10T17:00:00Z", "2009-11-10T18:00:00Z"},
// @jsonPayload fields
"@jsonPayload": {
`{"payloadA":"test", "shouldBeIgnored": "abc"}`,
`{"payloadB":[1,2,3], "payloadC":true}`,
},
// unexported fields or `-` tags
"unexperted": {"test"},
"SkipExported": {"test"},
"unexportedStructFieldWithoutTag.Name": {"test"},
"unexportedStruct.Name": {"test"},
// structs
"StructWithoutTag.Name": {"test1"},
"exportedStruct.Name": {"test2"},
// embedded
"embed_name": {"test3"},
"embed2.embed_name2": {"test4"},
}
type embed1 struct {
Name string `form:"embed_name" json:"embed_name"`
}
type embed2 struct {
Name string `form:"embed_name2" json:"embed_name2"`
}
//nolint
type TestStruct struct {
String string `form:"stringTag" query:"stringTag2"`
StringPtr *string
StringSlice []string
StringSlicePtr *[]string `form:"stringSlicePtrTag"`
StringSliceOfPtr []*string
Bool bool `form:"boolTag" query:"boolTag2"`
BoolPtr *bool
BoolSlice []bool
BoolSlicePtr *[]bool `form:"boolSlicePtrTag"`
BoolSliceOfPtr []*bool
Int8 int8 `form:"int8Tag" query:"int8Tag2"`
Int8Ptr *int8
Int8Slice []int8
Int8SlicePtr *[]int8 `form:"int8SlicePtrTag"`
Int8SliceOfPtr []*int8
Int16 int16 `form:"int16Tag" query:"int16Tag2"`
Int16Ptr *int16
Int16Slice []int16
Int16SlicePtr *[]int16 `form:"int16SlicePtrTag"`
Int16SliceOfPtr []*int16
Int32 int32 `form:"int32Tag" query:"int32Tag2"`
Int32Ptr *int32
Int32Slice []int32
Int32SlicePtr *[]int32 `form:"int32SlicePtrTag"`
Int32SliceOfPtr []*int32
Int64 int64 `form:"int64Tag" query:"int64Tag2"`
Int64Ptr *int64
Int64Slice []int64
Int64SlicePtr *[]int64 `form:"int64SlicePtrTag"`
Int64SliceOfPtr []*int64
Int int `form:"intTag" query:"intTag2"`
IntPtr *int
IntSlice []int
IntSlicePtr *[]int `form:"intSlicePtrTag"`
IntSliceOfPtr []*int
Uint8 uint8 `form:"uint8Tag" query:"uint8Tag2"`
Uint8Ptr *uint8
Uint8Slice []uint8
Uint8SlicePtr *[]uint8 `form:"uint8SlicePtrTag"`
Uint8SliceOfPtr []*uint8
Uint16 uint16 `form:"uint16Tag" query:"uint16Tag2"`
Uint16Ptr *uint16
Uint16Slice []uint16
Uint16SlicePtr *[]uint16 `form:"uint16SlicePtrTag"`
Uint16SliceOfPtr []*uint16
Uint32 uint32 `form:"uint32Tag" query:"uint32Tag2"`
Uint32Ptr *uint32
Uint32Slice []uint32
Uint32SlicePtr *[]uint32 `form:"uint32SlicePtrTag"`
Uint32SliceOfPtr []*uint32
Uint64 uint64 `form:"uint64Tag" query:"uint64Tag2"`
Uint64Ptr *uint64
Uint64Slice []uint64
Uint64SlicePtr *[]uint64 `form:"uint64SlicePtrTag"`
Uint64SliceOfPtr []*uint64
Uint uint `form:"uintTag" query:"uintTag2"`
UintPtr *uint
UintSlice []uint
UintSlicePtr *[]uint `form:"uintSlicePtrTag"`
UintSliceOfPtr []*uint
Float32 float32 `form:"float32Tag" query:"float32Tag2"`
Float32Ptr *float32
Float32Slice []float32
Float32SlicePtr *[]float32 `form:"float32SlicePtrTag"`
Float32SliceOfPtr []*float32
Float64 float64 `form:"float64Tag" query:"float64Tag2"`
Float64Ptr *float64
Float64Slice []float64
Float64SlicePtr *[]float64 `form:"float64SlicePtrTag"`
Float64SliceOfPtr []*float64
// encoding.TextUnmarshaler
Time time.Time `form:"timeTag" query:"timeTag2"`
TimePtr *time.Time
TimeSlice []time.Time
TimeSlicePtr *[]time.Time `form:"timeSlicePtrTag"`
TimeSliceOfPtr []*time.Time
// @jsonPayload fields
JSONPayloadA string `form:"shouldBeIgnored" json:"payloadA"`
JSONPayloadB []int `json:"payloadB"`
JSONPayloadC bool `json:"-"`
// unexported fields or `-` tags
unexported string
SkipExported string `form:"-"`
unexportedStructFieldWithoutTag struct {
Name string `json:"unexportedStructFieldWithoutTag_name"`
}
unexportedStructFieldWithTag struct {
Name string `json:"unexportedStructFieldWithTag_name"`
} `form:"unexportedStruct"`
// structs
StructWithoutTag struct {
Name string `json:"StructWithoutTag_name"`
}
StructWithTag struct {
Name string `json:"StructWithTag_name"`
} `form:"exportedStruct"`
// embedded
embed1
embed2 `form:"embed2"`
}
scenarios := []struct {
name string
data map[string][]string
dst any
tag string
prefix string
error bool
result string
}{
{
name: "nil data",
data: nil,
dst: pointer(map[string]any{}),
error: false,
result: `{}`,
},
{
name: "non-pointer map[string]any",
data: mapData,
dst: map[string]any{},
error: true,
},
{
name: "unsupported *map[string]string",
data: mapData,
dst: pointer(map[string]string{}),
error: true,
},
{
name: "unsupported *map[string][]string",
data: mapData,
dst: pointer(map[string][]string{}),
error: true,
},
{
name: "*map[string]any",
data: mapData,
dst: pointer(map[string]any{}),
result: `{"bool1":true,"bool2":[true,false],"json_a":null,"json_b":123,"json_c":[1,2,3],"mixed":[true,123,"test"],"number1":1,"number2":[2,3],"number3":[2.1,-3.4],"string0":"","string1":"a","string2":["b","c"]}`,
},
{
name: "valid pointer struct (all fields)",
data: structData,
dst: &TestStruct{},
result: `{"String":"a","StringPtr":"b","StringSlice":["a","b","c",""],"StringSlicePtr":["d","e"],"StringSliceOfPtr":["f","g"],"Bool":true,"BoolPtr":true,"BoolSlice":[true,false,false],"BoolSlicePtr":[false,false,true],"BoolSliceOfPtr":[false,true,false],"Int8":-1,"Int8Ptr":3,"Int8Slice":[4,5,0],"Int8SlicePtr":[5,6],"Int8SliceOfPtr":[7,8],"Int16":-1,"Int16Ptr":3,"Int16Slice":[4,5,0],"Int16SlicePtr":[5,6],"Int16SliceOfPtr":[7,8],"Int32":-1,"Int32Ptr":3,"Int32Slice":[4,5,0],"Int32SlicePtr":[5,6],"Int32SliceOfPtr":[7,8],"Int64":-1,"Int64Ptr":3,"Int64Slice":[4,5,0],"Int64SlicePtr":[5,6],"Int64SliceOfPtr":[7,8],"Int":-1,"IntPtr":3,"IntSlice":[4,5,0],"IntSlicePtr":[5,6],"IntSliceOfPtr":[7,8],"Uint8":1,"Uint8Ptr":3,"Uint8Slice":"BAUA","Uint8SlicePtr":"BQY=","Uint8SliceOfPtr":[7,8],"Uint16":1,"Uint16Ptr":3,"Uint16Slice":[4,5,0],"Uint16SlicePtr":[5,6],"Uint16SliceOfPtr":[7,8],"Uint32":1,"Uint32Ptr":3,"Uint32Slice":[4,5,0],"Uint32SlicePtr":[5,6],"Uint32SliceOfPtr":[7,8],"Uint64":1,"Uint64Ptr":3,"Uint64Slice":[4,5,0],"Uint64SlicePtr":[5,6],"Uint64SliceOfPtr":[7,8],"Uint":1,"UintPtr":3,"UintSlice":[4,5,0],"UintSlicePtr":[5,6],"UintSliceOfPtr":[7,8],"Float32":-1.2,"Float32Ptr":1.5,"Float32Slice":[1,2.3,-0.3,0],"Float32SlicePtr":[-1.3,3],"Float32SliceOfPtr":[0,1.2],"Float64":-1.2,"Float64Ptr":1.5,"Float64Slice":[1,2.3,-0.3,0],"Float64SlicePtr":[-1.3,3],"Float64SliceOfPtr":[0,1.2],"Time":"2009-11-10T15:00:00Z","TimePtr":"2009-11-10T14:00:00Z","TimeSlice":["2009-11-10T14:00:00Z","2009-11-10T15:00:00Z"],"TimeSlicePtr":["2009-11-10T15:00:00Z","2009-11-10T16:00:00Z"],"TimeSliceOfPtr":["2009-11-10T17:00:00Z","2009-11-10T18:00:00Z"],"payloadA":"test","payloadB":[1,2,3],"SkipExported":"","StructWithoutTag":{"StructWithoutTag_name":"test1"},"StructWithTag":{"StructWithTag_name":"test2"},"embed_name":"test3","embed_name2":"test4"}`,
},
{
name: "non-pointer struct",
data: structData,
dst: TestStruct{},
error: true,
},
{
name: "invalid struct uint value",
data: map[string][]string{"uintTag": {"-1"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct int value",
data: map[string][]string{"intTag": {"abc"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct bool value",
data: map[string][]string{"boolTag": {"abc"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct float value",
data: map[string][]string{"float64Tag": {"abc"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct TextUnmarshaler value",
data: map[string][]string{"timeTag": {"123"}},
dst: &TestStruct{},
error: true,
},
{
name: "custom tagKey",
data: map[string][]string{
"tag1": {"a"},
"tag2": {"b"},
"tag3": {"c"},
"Item": {"d"},
},
dst: &struct {
Item string `form:"tag1" query:"tag2" json:"tag2"`
}{},
tag: "query",
result: `{"tag2":"b"}`,
},
{
name: "custom prefix",
data: map[string][]string{
"test.A": {"1"},
"A": {"2"},
"test.alias": {"3"},
},
dst: &struct {
A string
B string `form:"alias"`
}{},
prefix: "test",
result: `{"A":"1","B":"3"}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := router.UnmarshalRequestData(s.data, s.dst, s.tag, s.prefix)
hasErr := err != nil
if hasErr != s.error {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.error, hasErr, err)
}
if hasErr {
return
}
raw, err := json.Marshal(s.dst)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(raw, []byte(s.result)) {
t.Fatalf("Expected dst \n%s\ngot\n%s", s.result, raw)
}
})
}
}
// note: extra unexported checks in addition to the above test as there
// is no easy way to print nested structs with all their fields.
func TestUnmarshalRequestDataUnexportedFields(t *testing.T) {
t.Parallel()
//nolint:all
type TestStruct struct {
Exported string
unexported string
// to ensure that the reflection doesn't take tags with higher priority than the exported state
unexportedWithTag string `form:"unexportedWithTag" json:"unexportedWithTag"`
}
dst := &TestStruct{}
err := router.UnmarshalRequestData(map[string][]string{
"Exported": {"test"}, // just for reference
"Unexported": {"test"},
"unexported": {"test"},
"UnexportedWithTag": {"test"},
"unexportedWithTag": {"test"},
}, dst, "", "")
if err != nil {
t.Fatal(err)
}
if dst.Exported != "test" {
t.Fatalf("Expected the Exported field to be %q, got %q", "test", dst.Exported)
}
if dst.unexported != "" {
t.Fatalf("Expected the unexported field to remain empty, got %q", dst.unexported)
}
if dst.unexportedWithTag != "" {
t.Fatalf("Expected the unexportedWithTag field to remain empty, got %q", dst.unexportedWithTag)
}
}
+2 -2
View File
@@ -6,9 +6,9 @@ import (
"sync"
)
// FireAndForget executes `f()` in a new go routine and auto recovers if panic.
// FireAndForget executes f() in a new go routine and auto recovers if panic.
//
// **Note:** Use this only if you are not interested in the result of `f()`
// **Note:** Use this only if you are not interested in the result of f()
// and don't want to block the parent go routine.
func FireAndForget(f func(), wg ...*sync.WaitGroup) {
if len(wg) > 0 && wg[0] != nil {
+7 -2
View File
@@ -64,9 +64,10 @@ func (f FilterData) BuildExpr(
}
}
if parsedFilterData.Has(raw) {
return buildParsedFilterExpr(parsedFilterData.Get(raw), fieldResolver)
if data, ok := parsedFilterData.GetOk(raw); ok {
return buildParsedFilterExpr(data, fieldResolver)
}
data, err := fexpr.Parse(raw)
if err != nil {
// depending on the users demand we may allow empty expressions
@@ -78,9 +79,11 @@ func (f FilterData) BuildExpr(
return nil, err
}
// store in cache
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
parsedFilterData.SetIfLessThanLimit(raw, data, 500)
return buildParsedFilterExpr(data, fieldResolver)
}
@@ -431,6 +434,8 @@ func mergeParams(params ...dbx.Params) dbx.Params {
return result
}
// @todo consider adding support for custom single character wildcard
//
// wrapLikeParams wraps each provided param value string with `%`
// if the param doesn't contain an explicit wildcard (`%`) character already.
func wrapLikeParams(params dbx.Params) dbx.Params {
+17 -5
View File
@@ -5,16 +5,20 @@ import (
"math"
"net/url"
"strconv"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/inflector"
"golang.org/x/sync/errgroup"
)
// DefaultPerPage specifies the default returned search result items.
const DefaultPerPage int = 30
// @todo consider making it configurable
//
// MaxPerPage specifies the maximum allowed search result items returned in a single page.
const MaxPerPage int = 500
const MaxPerPage int = 1000
// url search query params
const (
@@ -27,23 +31,23 @@ const (
// Result defines the returned search result structure.
type Result struct {
Items any `json:"items"`
Page int `json:"page"`
PerPage int `json:"perPage"`
TotalItems int `json:"totalItems"`
TotalPages int `json:"totalPages"`
Items any `json:"items"`
}
// Provider represents a single configured search provider instance.
type Provider struct {
fieldResolver FieldResolver
query *dbx.SelectQuery
skipTotal bool
countCol string
page int
perPage int
sort []SortField
filter []FilterData
page int
perPage int
skipTotal bool
}
// NewProvider creates and returns a new search provider.
@@ -208,6 +212,14 @@ func (s *Provider) Exec(items any) (*Result, error) {
return nil, err
}
if expr != "" {
// ensure that _rowid_ expressions are always prefixed with the first FROM table
if sortField.Name == rowidSortKey && !strings.Contains(expr, ".") {
queryInfo := modelsQuery.Info()
if len(queryInfo.From) > 0 {
expr = "[[" + inflector.Columnify(queryInfo.From[0]) + "]]." + expr
}
}
modelsQuery.AndOrderBy(expr)
}
}
+71 -69
View File
@@ -180,38 +180,39 @@ func TestProviderParse(t *testing.T) {
}
for i, s := range scenarios {
r := &testFieldResolver{}
p := NewProvider(r).
Page(initialPage).
PerPage(initialPerPage).
Sort(initialSort).
Filter(initialFilter)
t.Run(fmt.Sprintf("%d_%s", i, s.query), func(t *testing.T) {
r := &testFieldResolver{}
p := NewProvider(r).
Page(initialPage).
PerPage(initialPerPage).
Sort(initialSort).
Filter(initialFilter)
err := p.Parse(s.query)
err := p.Parse(s.query)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
continue
}
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if p.page != s.expectPage {
t.Errorf("(%d) Expected page %v, got %v", i, s.expectPage, p.page)
}
if p.page != s.expectPage {
t.Fatalf("Expected page %v, got %v", s.expectPage, p.page)
}
if p.perPage != s.expectPerPage {
t.Errorf("(%d) Expected perPage %v, got %v", i, s.expectPerPage, p.perPage)
}
if p.perPage != s.expectPerPage {
t.Fatalf("Expected perPage %v, got %v", s.expectPerPage, p.perPage)
}
encodedSort, _ := json.Marshal(p.sort)
if string(encodedSort) != s.expectSort {
t.Errorf("(%d) Expected sort %v, got \n%v", i, s.expectSort, string(encodedSort))
}
encodedSort, _ := json.Marshal(p.sort)
if string(encodedSort) != s.expectSort {
t.Fatalf("Expected sort %v, got \n%v", s.expectSort, string(encodedSort))
}
encodedFilter, _ := json.Marshal(p.filter)
if string(encodedFilter) != s.expectFilter {
t.Errorf("(%d) Expected filter %v, got \n%v", i, s.expectFilter, string(encodedFilter))
}
encodedFilter, _ := json.Marshal(p.filter)
if string(encodedFilter) != s.expectFilter {
t.Fatalf("Expected filter %v, got \n%v", s.expectFilter, string(encodedFilter))
}
})
}
}
@@ -256,7 +257,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{},
false,
false,
`{"page":1,"perPage":10,"totalItems":2,"totalPages":1,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
`{"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":10,"totalItems":2,"totalPages":1}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10",
@@ -270,7 +271,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{},
false,
false,
`{"page":10,"perPage":30,"totalItems":2,"totalPages":1,"items":[]}`,
`{"items":[],"page":10,"perPage":30,"totalItems":2,"totalPages":1}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30 OFFSET 270",
@@ -306,7 +307,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{"test2 != null", "test1 >= 2"},
false,
false,
`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"totalPages":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"totalPages":1}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2)",
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
@@ -320,7 +321,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{"test2 != null", "test1 >= 2"},
true,
false,
`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":-1,"totalPages":-1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":-1,"totalPages":-1}`,
[]string{
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (((test2 IS NOT '' AND test2 IS NOT NULL)))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT " + fmt.Sprint(MaxPerPage),
},
@@ -333,7 +334,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{"test3 != ''"},
false,
false,
`{"page":1,"perPage":10,"totalItems":0,"totalPages":0,"items":[]}`,
`{"items":[],"page":1,"perPage":10,"totalItems":0,"totalPages":0}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL)))",
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
@@ -347,7 +348,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{"test3 != ''"},
true,
false,
`{"page":1,"perPage":10,"totalItems":-1,"totalPages":-1,"items":[]}`,
`{"items":[],"page":1,"perPage":10,"totalItems":-1,"totalPages":-1}`,
[]string{
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (((test3 IS NOT '' AND test3 IS NOT NULL))) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
},
@@ -360,7 +361,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{},
false,
false,
`{"page":2,"perPage":1,"totalItems":2,"totalPages":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":2,"totalPages":2}`,
[]string{
"SELECT COUNT(DISTINCT [[test.id]]) FROM `test` WHERE NOT (`test1` IS NULL)",
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
@@ -374,7 +375,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
[]FilterData{},
true,
false,
`{"page":2,"perPage":1,"totalItems":-1,"totalPages":-1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":2,"perPage":1,"totalItems":-1,"totalPages":-1}`,
[]string{
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
},
@@ -449,7 +450,7 @@ func TestProviderParseAndExec(t *testing.T) {
"no extra query params (aka. use the provider presets)",
"",
false,
`{"page":2,"perPage":123,"totalItems":2,"totalPages":1,"items":[]}`,
`{"items":[],"page":2,"perPage":123,"totalItems":2,"totalPages":1}`,
},
{
"invalid query",
@@ -491,62 +492,63 @@ func TestProviderParseAndExec(t *testing.T) {
"page > existing",
"page=3&perPage=9999",
false,
`{"page":3,"perPage":500,"totalItems":2,"totalPages":1,"items":[]}`,
`{"items":[],"page":3,"perPage":1000,"totalItems":2,"totalPages":1}`,
},
{
"valid query params",
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3",
false,
`{"page":1,"perPage":500,"totalItems":1,"totalPages":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":1,"totalPages":1}`,
},
{
"valid query params with skipTotal=1",
"page=1&perPage=9999&filter=test1>1&sort=-test2,test3&skipTotal=1",
false,
`{"page":1,"perPage":500,"totalItems":-1,"totalPages":-1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
`{"items":[{"test1":2,"test2":"test2.2","test3":""}],"page":1,"perPage":1000,"totalItems":-1,"totalPages":-1}`,
},
}
for _, s := range scenarios {
testDB.CalledQueries = []string{} // reset
t.Run(s.name, func(t *testing.T) {
testDB.CalledQueries = []string{} // reset
testResolver := &testFieldResolver{}
provider := NewProvider(testResolver).
Query(query).
Page(2).
PerPage(123).
Sort([]SortField{{"test2", SortAsc}}).
Filter([]FilterData{"test1 > 0"})
testResolver := &testFieldResolver{}
provider := NewProvider(testResolver).
Query(query).
Page(2).
PerPage(123).
Sort([]SortField{{"test2", SortAsc}}).
Filter([]FilterData{"test1 > 0"})
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
result, err := provider.ParseAndExec(s.queryString, &[]testTableStruct{})
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
continue
}
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
continue
}
if hasErr {
return
}
if testResolver.UpdateQueryCalls != 1 {
t.Errorf("[%s] Expected resolver.Update to be called %d, got %d", s.name, 1, testResolver.UpdateQueryCalls)
}
if testResolver.UpdateQueryCalls != 1 {
t.Fatalf("Expected resolver.Update to be called %d, got %d", 1, testResolver.UpdateQueryCalls)
}
expectedQueries := 2
if provider.skipTotal {
expectedQueries = 1
}
expectedQueries := 2
if provider.skipTotal {
expectedQueries = 1
}
if len(testDB.CalledQueries) != expectedQueries {
t.Errorf("[%s] Expected %d db queries, got %d: \n%v", s.name, expectedQueries, len(testDB.CalledQueries), testDB.CalledQueries)
}
if len(testDB.CalledQueries) != expectedQueries {
t.Fatalf("Expected %d db queries, got %d: \n%v", expectedQueries, len(testDB.CalledQueries), testDB.CalledQueries)
}
encoded, _ := json.Marshal(result)
if string(encoded) != s.expectResult {
t.Errorf("[%s] Expected result %v, got \n%v", s.name, s.expectResult, string(encoded))
}
encoded, _ := json.Marshal(result)
if string(encoded) != s.expectResult {
t.Fatalf("Expected result \n%v\ngot\n%v", s.expectResult, string(encoded))
}
})
}
}
+1 -1
View File
@@ -76,7 +76,7 @@ func (r *SimpleFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
// Returns error if `field` is not in `r.allowedFields`.
func (r *SimpleFieldResolver) Resolve(field string) (*ResolverResult, error) {
if !list.ExistInSliceWithRegex(field, r.allowedFields) {
return nil, fmt.Errorf("failed to resolve field %q", field)
return nil, fmt.Errorf("Failed to resolve field %q.", field)
}
parts := strings.Split(field, ".")
+29 -28
View File
@@ -1,6 +1,7 @@
package search_test
import (
"fmt"
"testing"
"github.com/pocketbase/dbx"
@@ -23,22 +24,22 @@ func TestSimpleFieldResolverUpdateQuery(t *testing.T) {
}
for i, s := range scenarios {
db := dbx.NewFromDB(nil, "")
query := db.Select("id").From("test")
t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
db := dbx.NewFromDB(nil, "")
query := db.Select("id").From("test")
r.Resolve(s.fieldName)
r.Resolve(s.fieldName)
if err := r.UpdateQuery(nil); err != nil {
t.Errorf("(%d) UpdateQuery failed with error %v", i, err)
continue
}
if err := r.UpdateQuery(nil); err != nil {
t.Fatalf("UpdateQuery failed with error %v", err)
}
rawQuery := query.Build().SQL()
// rawQuery := s.expectQuery
rawQuery := query.Build().SQL()
if rawQuery != s.expectQuery {
t.Errorf("(%d) Expected query %v, got \n%v", i, s.expectQuery, rawQuery)
}
if rawQuery != s.expectQuery {
t.Fatalf("Expected query %v, got \n%v", s.expectQuery, rawQuery)
}
})
}
}
@@ -62,25 +63,25 @@ func TestSimpleFieldResolverResolve(t *testing.T) {
}
for i, s := range scenarios {
r, err := r.Resolve(s.fieldName)
t.Run(fmt.Sprintf("%d_%s", i, s.fieldName), func(t *testing.T) {
r, err := r.Resolve(s.fieldName)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
continue
}
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
continue
}
if hasErr {
return
}
if r.Identifier != s.expectName {
t.Errorf("(%d) Expected r.Identifier %q, got %q", i, s.expectName, r.Identifier)
}
if r.Identifier != s.expectName {
t.Fatalf("Expected r.Identifier %q, got %q", s.expectName, r.Identifier)
}
// params should be empty
if len(r.Params) != 0 {
t.Errorf("(%d) Expected 0 r.Params, got %v", i, r.Params)
}
if len(r.Params) != 0 {
t.Fatalf("r.Params should be empty, got %v", r.Params)
}
})
}
}
+9 -1
View File
@@ -5,7 +5,10 @@ import (
"strings"
)
const randomSortKey string = "@random"
const (
randomSortKey string = "@random"
rowidSortKey string = "@rowid"
)
// sort field directions
const (
@@ -26,6 +29,11 @@ func (s *SortField) BuildExpr(fieldResolver FieldResolver) (string, error) {
return "RANDOM()", nil
}
// special case for the builtin SQLite rowid column
if s.Name == rowidSortKey {
return fmt.Sprintf("[[_rowid_]] %s", s.Direction), nil
}
result, err := fieldResolver.Resolve(s.Name)
// invalidate empty fields and non-column identifiers
+26 -18
View File
@@ -2,6 +2,7 @@ package search_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/search"
@@ -29,27 +30,30 @@ func TestSortFieldBuildExpr(t *testing.T) {
{search.SortField{"test1", search.SortDesc}, false, "[[test1]] DESC"},
// special @random field (ignore direction)
{search.SortField{"@random", search.SortDesc}, false, "RANDOM()"},
// special _rowid_ field
{search.SortField{"@rowid", search.SortDesc}, false, "[[_rowid_]] DESC"},
}
for i, s := range scenarios {
result, err := s.sortField.BuildExpr(resolver)
for _, s := range scenarios {
t.Run(fmt.Sprintf("%s_%s", s.sortField.Name, s.sortField.Name), func(t *testing.T) {
result, err := s.sortField.BuildExpr(resolver)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("(%d) Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
continue
}
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if result != s.expectExpression {
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectExpression, result)
}
if result != s.expectExpression {
t.Fatalf("Expected expression %v, got %v", s.expectExpression, result)
}
})
}
}
func TestParseSortFromString(t *testing.T) {
scenarios := []struct {
value string
expectedJson string
value string
expected string
}{
{"", `[{"name":"","direction":"ASC"}]`},
{"test", `[{"name":"test","direction":"ASC"}]`},
@@ -57,14 +61,18 @@ func TestParseSortFromString(t *testing.T) {
{"-test", `[{"name":"test","direction":"DESC"}]`},
{"test1,-test2,+test3", `[{"name":"test1","direction":"ASC"},{"name":"test2","direction":"DESC"},{"name":"test3","direction":"ASC"}]`},
{"@random,-test", `[{"name":"@random","direction":"ASC"},{"name":"test","direction":"DESC"}]`},
{"-@rowid,-test", `[{"name":"@rowid","direction":"DESC"},{"name":"test","direction":"DESC"}]`},
}
for i, s := range scenarios {
result := search.ParseSortFromString(s.value)
encoded, _ := json.Marshal(result)
for _, s := range scenarios {
t.Run(s.value, func(t *testing.T) {
result := search.ParseSortFromString(s.value)
encoded, _ := json.Marshal(result)
encodedStr := string(encoded)
if string(encoded) != s.expectedJson {
t.Errorf("(%d) Expected expression %v, got %v", i, s.expectedJson, string(encoded))
}
if encodedStr != s.expected {
t.Fatalf("Expected expression %s, got %s", s.expected, encodedStr)
}
})
}
}
+38 -33
View File
@@ -1,6 +1,7 @@
package security_test
import (
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/security"
@@ -17,30 +18,30 @@ func TestEncrypt(t *testing.T) {
{"123", "abcdabcdabcdabcdabcdabcdabcdabcd", false},
}
for i, scenario := range scenarios {
result, err := security.Encrypt([]byte(scenario.data), scenario.key)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.data), func(t *testing.T) {
result, err := security.Encrypt([]byte(s.data), s.key)
if scenario.expectError && err == nil {
t.Errorf("(%d) Expected error got nil", i)
}
if !scenario.expectError && err != nil {
t.Errorf("(%d) Expected nil got error %v", i, err)
}
hasErr := err != nil
if scenario.expectError && result != "" {
t.Errorf("(%d) Expected empty string, got %q", i, result)
}
if !scenario.expectError && result == "" {
t.Errorf("(%d) Expected non empty encrypted result string", i)
}
// try to decrypt
if result != "" {
decrypted, _ := security.Decrypt(result, scenario.key)
if string(decrypted) != scenario.data {
t.Errorf("(%d) Expected decrypted value to match with the data input, got %q", i, decrypted)
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
}
if hasErr {
if result != "" {
t.Fatalf("Expected empty Encrypt result on error, got %q", result)
}
return
}
// try to decrypt
decrypted, err := security.Decrypt(result, s.key)
if err != nil || string(decrypted) != s.data {
t.Fatalf("Expected decrypted value to match with the data input, got %q (%v)", decrypted, err)
}
})
}
}
@@ -57,19 +58,23 @@ func TestDecrypt(t *testing.T) {
{"8kcEqilvv+YKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", false, "123"},
}
for i, scenario := range scenarios {
result, err := security.Decrypt(scenario.cipher, scenario.key)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
result, err := security.Decrypt(s.cipher, s.key)
if scenario.expectError && err == nil {
t.Errorf("(%d) Expected error got nil", i)
}
if !scenario.expectError && err != nil {
t.Errorf("(%d) Expected nil got error %v", i, err)
}
hasErr := err != nil
resultStr := string(result)
if resultStr != scenario.expectedData {
t.Errorf("(%d) Expected %q, got %q", i, scenario.expectedData, resultStr)
}
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if str := string(result); str != s.expectedData {
t.Fatalf("Expected %q, got %q", s.expectedData, str)
}
})
}
}
+3 -12
View File
@@ -4,6 +4,7 @@ import (
"errors"
"time"
// @todo update to v5
"github.com/golang-jwt/jwt/v4"
)
@@ -43,11 +44,9 @@ func ParseJWT(token string, verificationKey string) (jwt.MapClaims, error) {
}
// NewJWT generates and returns new HS256 signed JWT.
func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
seconds := time.Duration(secondsDuration) * time.Second
func NewJWT(payload jwt.MapClaims, signingKey string, duration time.Duration) (string, error) {
claims := jwt.MapClaims{
"exp": time.Now().Add(seconds).Unix(),
"exp": time.Now().Add(duration).Unix(),
}
for k, v := range payload {
@@ -56,11 +55,3 @@ func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (st
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(signingKey))
}
// Deprecated:
// Consider replacing with NewJWT().
//
// NewToken is a legacy alias for NewJWT that generates a HS256 signed JWT.
func NewToken(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
return NewJWT(payload, signingKey, secondsDuration)
}
+61 -54
View File
@@ -1,7 +1,10 @@
package security_test
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/pocketbase/pocketbase/tools/security"
@@ -102,26 +105,30 @@ func TestParseJWT(t *testing.T) {
},
}
for i, scenario := range scenarios {
result, err := security.ParseJWT(scenario.token, scenario.secret)
if scenario.expectError && err == nil {
t.Errorf("(%d) Expected error got nil", i)
}
if !scenario.expectError && err != nil {
t.Errorf("(%d) Expected nil got error %v", i, err)
}
if len(result) != len(scenario.expectClaims) {
t.Errorf("(%d) Expected %v got %v", i, scenario.expectClaims, result)
}
for k, v := range scenario.expectClaims {
v2, ok := result[k]
if !ok {
t.Errorf("(%d) Missing expected claim %q", i, k)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.token), func(t *testing.T) {
result, err := security.ParseJWT(s.token, s.secret)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if v != v2 {
t.Errorf("(%d) Expected %v for %q claim, got %v", i, v, k, v2)
if len(result) != len(s.expectClaims) {
t.Fatalf("Expected %v claims got %v", s.expectClaims, result)
}
}
for k, v := range s.expectClaims {
v2, ok := result[k]
if !ok {
t.Fatalf("Missing expected claim %q", k)
}
if v != v2 {
t.Fatalf("Expected %v for %q claim, got %v", v, k, v2)
}
}
})
}
}
@@ -129,51 +136,51 @@ func TestNewJWT(t *testing.T) {
scenarios := []struct {
claims jwt.MapClaims
key string
duration int64
duration time.Duration
expectError bool
}{
// empty, zero duration
{jwt.MapClaims{}, "", 0, true},
// empty, 10 seconds duration
{jwt.MapClaims{}, "", 10, false},
{jwt.MapClaims{}, "", 10 * time.Second, false},
// non-empty, 10 seconds duration
{jwt.MapClaims{"name": "test"}, "test", 10, false},
{jwt.MapClaims{"name": "test"}, "test", 10 * time.Second, false},
}
for i, scenario := range scenarios {
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
if tokenErr != nil {
t.Errorf("(%d) Expected NewJWT to succeed, got error %v", i, tokenErr)
continue
}
claims, parseErr := security.ParseJWT(token, scenario.key)
hasParseErr := parseErr != nil
if hasParseErr != scenario.expectError {
t.Errorf("(%d) Expected hasParseErr to be %v, got %v (%v)", i, scenario.expectError, hasParseErr, parseErr)
continue
}
if scenario.expectError {
continue
}
if _, ok := claims["exp"]; !ok {
t.Errorf("(%d) Missing required claim exp, got %v", i, claims)
}
// clear exp claim to match with the scenario ones
delete(claims, "exp")
if len(claims) != len(scenario.claims) {
t.Errorf("(%d) Expected %v claims, got %v", i, scenario.claims, claims)
}
for j, k := range claims {
if claims[j] != scenario.claims[j] {
t.Errorf("(%d) Expected %v for %q claim, got %v", i, claims[j], k, scenario.claims[j])
t.Run(strconv.Itoa(i), func(t *testing.T) {
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
if tokenErr != nil {
t.Fatalf("Expected NewJWT to succeed, got error %v", tokenErr)
}
}
claims, parseErr := security.ParseJWT(token, scenario.key)
hasParseErr := parseErr != nil
if hasParseErr != scenario.expectError {
t.Fatalf("Expected hasParseErr to be %v, got %v (%v)", scenario.expectError, hasParseErr, parseErr)
}
if scenario.expectError {
return
}
if _, ok := claims["exp"]; !ok {
t.Fatalf("Missing required claim exp, got %v", claims)
}
// clear exp claim to match with the scenario ones
delete(claims, "exp")
if len(claims) != len(scenario.claims) {
t.Fatalf("Expected %v claims, got %v", scenario.claims, claims)
}
for j, k := range claims {
if claims[j] != scenario.claims[j] {
t.Fatalf("Expected %v for %q claim, got %v", claims[j], k, scenario.claims[j])
}
}
})
}
}
+1 -6
View File
@@ -3,16 +3,11 @@ package security
import (
cryptoRand "crypto/rand"
"math/big"
mathRand "math/rand"
"time"
mathRand "math/rand" // @todo replace with rand/v2?
)
const defaultRandomAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
func init() {
mathRand.Seed(time.Now().UnixNano())
}
// RandomString generates a cryptographically random string with the specified length.
//
// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding.
+152
View File
@@ -0,0 +1,152 @@
package security
import (
cryptoRand "crypto/rand"
"fmt"
"math/big"
"regexp/syntax"
"strings"
)
const defaultMaxRepeat = 6
var anyCharNotNLPairs = []rune{'A', 'Z', 'a', 'z', '0', '9'}
// RandomStringByRegex generates a random string matching the regex pattern.
// If optFlags is not set, fallbacks to [syntax.Perl].
//
// NB! While the source of the randomness comes from [crypto/rand] this method
// is not recommended to be used on its own in critical secure contexts because
// the generated length could vary too much on the used pattern and may not be
// as secure as simply calling [security.RandomString].
// If you still insist on using it for such purposes, consider at least
// a large enough minimum length for the generated string, e.g. `[a-z0-9]{30}`.
//
// This function is inspired by github.com/pipe01/revregexp, github.com/lucasjones/reggen and other similar packages.
func RandomStringByRegex(pattern string, optFlags ...syntax.Flags) (string, error) {
var flags syntax.Flags
if len(optFlags) == 0 {
flags = syntax.Perl
} else {
for _, f := range optFlags {
flags |= f
}
}
r, err := syntax.Parse(pattern, flags)
if err != nil {
return "", err
}
var sb = new(strings.Builder)
err = writeRandomStringByRegex(r, sb)
if err != nil {
return "", err
}
return sb.String(), nil
}
func writeRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder) error {
// https://pkg.go.dev/regexp/syntax#Op
switch r.Op {
case syntax.OpCharClass:
c, err := randomRuneFromPairs(r.Rune)
if err != nil {
return err
}
_, err = sb.WriteRune(c)
return err
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
c, err := randomRuneFromPairs(anyCharNotNLPairs)
if err != nil {
return err
}
_, err = sb.WriteRune(c)
return err
case syntax.OpAlternate:
idx, err := randomNumber(len(r.Sub))
if err != nil {
return err
}
return writeRandomStringByRegex(r.Sub[idx], sb)
case syntax.OpConcat:
var err error
for _, sub := range r.Sub {
err = writeRandomStringByRegex(sub, sb)
if err != nil {
break
}
}
return err
case syntax.OpRepeat:
return repeatRandomStringByRegex(r.Sub[0], sb, r.Min, r.Max)
case syntax.OpQuest:
return repeatRandomStringByRegex(r.Sub[0], sb, 0, 1)
case syntax.OpPlus:
return repeatRandomStringByRegex(r.Sub[0], sb, 1, -1)
case syntax.OpStar:
return repeatRandomStringByRegex(r.Sub[0], sb, 0, -1)
case syntax.OpCapture:
return writeRandomStringByRegex(r.Sub[0], sb)
case syntax.OpLiteral:
_, err := sb.WriteString(string(r.Rune))
return err
default:
return fmt.Errorf("unsupported pattern operator %d", r.Op)
}
}
func repeatRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder, min int, max int) error {
if max < 0 {
max = defaultMaxRepeat
}
if max < min {
max = min
}
n := min
if max != min {
randRange, err := randomNumber(max - min)
if err != nil {
return err
}
n += randRange
}
var err error
for i := 0; i < n; i++ {
err = writeRandomStringByRegex(r, sb)
if err != nil {
return err
}
}
return nil
}
func randomRuneFromPairs(pairs []rune) (rune, error) {
idx, err := randomNumber(len(pairs) / 2)
if err != nil {
return 0, err
}
return randomRuneFromRange(pairs[idx*2], pairs[idx*2+1])
}
func randomRuneFromRange(min rune, max rune) (rune, error) {
offset, err := randomNumber(int(max - min + 1))
if err != nil {
return min, err
}
return min + rune(offset), nil
}
func randomNumber(maxSoft int) (int, error) {
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(maxSoft)))
return int(randRange.Int64()), err
}
+66
View File
@@ -0,0 +1,66 @@
package security_test
import (
"fmt"
"regexp"
"regexp/syntax"
"slices"
"testing"
"github.com/pocketbase/pocketbase/tools/security"
)
func TestRandomStringByRegex(t *testing.T) {
generated := []string{}
scenarios := []struct {
pattern string
flags []syntax.Flags
expectError bool
}{
{``, nil, true},
{`test`, nil, false},
{`\d+`, []syntax.Flags{syntax.POSIX}, true},
{`\d+`, nil, false},
{`\d*`, nil, false},
{`\d{1,10}`, nil, false},
{`\d{3}`, nil, false},
{`\d{0,}-abc`, nil, false},
{`[a-zA-Z]*`, nil, false},
{`[^a-zA-Z]{5,30}`, nil, false},
{`\w+_abc`, nil, false},
{`[a-zA-Z_]*`, nil, false},
{`[2-9]{5}-\w+`, nil, false},
{`(a|b|c)`, nil, false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%q", i, s.pattern), func(t *testing.T) {
str, err := security.RandomStringByRegex(s.pattern, s.flags...)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
r, err := regexp.Compile(s.pattern)
if err != nil {
t.Fatal(err)
}
if !r.Match([]byte(str)) {
t.Fatalf("Expected %q to match pattern %v", str, s.pattern)
}
if slices.Contains(generated, str) {
t.Fatalf("The generated string %q already exists in\n%v", str, generated)
}
generated = append(generated, str)
})
}
}
+115 -24
View File
@@ -1,11 +1,18 @@
package store
import "sync"
import (
"encoding/json"
"sync"
)
// @todo remove after https://github.com/golang/go/issues/20135
const ShrinkThreshold = 200 // the number is arbitrary chosen
// Store defines a concurrent safe in memory key-value data store.
type Store[T any] struct {
data map[string]T
mux sync.RWMutex
data map[string]T
mu sync.RWMutex
deleted int64
}
// New creates a new Store[T] instance with a shallow copy of the provided data (if any).
@@ -20,8 +27,8 @@ func New[T any](data map[string]T) *Store[T] {
// Reset clears the store and replaces the store data with a
// shallow copy of the provided newData.
func (s *Store[T]) Reset(newData map[string]T) {
s.mux.Lock()
defer s.mux.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
if len(newData) > 0 {
s.data = make(map[string]T, len(newData))
@@ -31,38 +38,50 @@ func (s *Store[T]) Reset(newData map[string]T) {
} else {
s.data = make(map[string]T)
}
s.deleted = 0
}
// Length returns the current number of elements in the store.
func (s *Store[T]) Length() int {
s.mux.RLock()
defer s.mux.RUnlock()
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.data)
}
// RemoveAll removes all the existing store entries.
func (s *Store[T]) RemoveAll() {
s.mux.Lock()
defer s.mux.Unlock()
s.data = make(map[string]T)
s.Reset(nil)
}
// Remove removes a single entry from the store.
//
// Remove does nothing if key doesn't exist in the store.
func (s *Store[T]) Remove(key string) {
s.mux.Lock()
defer s.mux.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
delete(s.data, key)
s.deleted++
// reassign to a new map so that the old one can be gc-ed because it doesn't shrink
//
// @todo remove after https://github.com/golang/go/issues/20135
if s.deleted >= ShrinkThreshold {
newData := make(map[string]T, len(s.data))
for k, v := range s.data {
newData[k] = v
}
s.data = newData
s.deleted = 0
}
}
// Has checks if element with the specified key exist or not.
func (s *Store[T]) Has(key string) bool {
s.mux.RLock()
defer s.mux.RUnlock()
s.mu.RLock()
defer s.mu.RUnlock()
_, ok := s.data[key]
@@ -73,16 +92,26 @@ func (s *Store[T]) Has(key string) bool {
//
// If key is not set, the zero T value is returned.
func (s *Store[T]) Get(key string) T {
s.mux.RLock()
defer s.mux.RUnlock()
s.mu.RLock()
defer s.mu.RUnlock()
return s.data[key]
}
// GetOk is similar to Get but returns also a boolean indicating whether the key exists or not.
func (s *Store[T]) GetOk(key string) (T, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.data[key]
return v, ok
}
// GetAll returns a shallow copy of the current store data.
func (s *Store[T]) GetAll() map[string]T {
s.mux.RLock()
defer s.mux.RUnlock()
s.mu.RLock()
defer s.mu.RUnlock()
var clone = make(map[string]T, len(s.data))
@@ -93,10 +122,24 @@ func (s *Store[T]) GetAll() map[string]T {
return clone
}
// Values returns a slice with all of the current store values.
func (s *Store[T]) Values() []T {
s.mu.RLock()
defer s.mu.RUnlock()
var values = make([]T, 0, len(s.data))
for _, v := range s.data {
values = append(values, v)
}
return values
}
// Set sets (or overwrite if already exist) a new value for key.
func (s *Store[T]) Set(key string, value T) {
s.mux.Lock()
defer s.mux.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
if s.data == nil {
s.data = make(map[string]T)
@@ -105,16 +148,34 @@ func (s *Store[T]) Set(key string, value T) {
s.data[key] = value
}
// GetOrSet retrieves a single existing value for the provided key
// or stores a new one if it doesn't exist.
func (s *Store[T]) GetOrSet(key string, setFunc func() T) T {
s.mu.Lock()
defer s.mu.Unlock()
if s.data == nil {
s.data = make(map[string]T)
}
v, ok := s.data[key]
if !ok {
v = setFunc()
s.data[key] = v
}
return v
}
// SetIfLessThanLimit sets (or overwrite if already exist) a new value for key.
//
// This method is similar to Set() but **it will skip adding new elements**
// to the store if the store length has reached the specified limit.
// false is returned if maxAllowedElements limit is reached.
func (s *Store[T]) SetIfLessThanLimit(key string, value T, maxAllowedElements int) bool {
s.mux.Lock()
defer s.mux.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
// init map if not already
if s.data == nil {
s.data = make(map[string]T)
}
@@ -132,3 +193,33 @@ func (s *Store[T]) SetIfLessThanLimit(key string, value T, maxAllowedElements in
return true
}
// UnmarshalJSON implements [json.Unmarshaler] and imports the
// provided JSON data into the store.
//
// The store entries that match with the ones from the data will be overwritten with the new value.
func (s *Store[T]) UnmarshalJSON(data []byte) error {
raw := map[string]T{}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.data == nil {
s.data = make(map[string]T)
}
for k, v := range raw {
s.data[k] = v
}
return nil
}
// MarshalJSON implements [json.Marshaler] and export the current
// store data into valid JSON.
func (s *Store[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(s.GetAll())
}
+163 -5
View File
@@ -3,6 +3,8 @@ package store_test
import (
"bytes"
"encoding/json"
"slices"
"strconv"
"testing"
"github.com/pocketbase/pocketbase/tools/store"
@@ -137,11 +139,41 @@ func TestGet(t *testing.T) {
{"missing", 0}, // should auto fallback to the zero value
}
for i, scenario := range scenarios {
val := s.Get(scenario.key)
if val != scenario.expect {
t.Errorf("(%d) Expected %v, got %v", i, scenario.expect, val)
}
for _, scenario := range scenarios {
t.Run(scenario.key, func(t *testing.T) {
val := s.Get(scenario.key)
if val != scenario.expect {
t.Fatalf("Expected %v, got %v", scenario.expect, val)
}
})
}
}
func TestGetOk(t *testing.T) {
s := store.New(map[string]int{"test1": 0, "test2": 1})
scenarios := []struct {
key string
expectValue int
expectOk bool
}{
{"test1", 0, true},
{"test2", 1, true},
{"missing", 0, false}, // should auto fallback to the zero value
}
for _, scenario := range scenarios {
t.Run(scenario.key, func(t *testing.T) {
val, ok := s.GetOk(scenario.key)
if ok != scenario.expectOk {
t.Fatalf("Expected ok %v, got %v", scenario.expectOk, ok)
}
if val != scenario.expectValue {
t.Fatalf("Expected %v, got %v", scenario.expectValue, val)
}
})
}
}
@@ -173,6 +205,27 @@ func TestGetAll(t *testing.T) {
}
}
func TestValues(t *testing.T) {
data := map[string]int{
"a": 1,
"b": 2,
}
values := store.New(data).Values()
expected := []int{1, 2}
if len(values) != len(expected) {
t.Fatalf("Expected %d values, got %d", len(expected), len(values))
}
for _, v := range expected {
if !slices.Contains(values, v) {
t.Fatalf("Missing value %v in\n%v", v, values)
}
}
}
func TestSet(t *testing.T) {
s := store.Store[int]{}
@@ -196,6 +249,37 @@ func TestSet(t *testing.T) {
}
}
func TestGetOrSet(t *testing.T) {
s := store.New(map[string]int{
"test1": 0,
"test2": 1,
"test3": 3,
})
scenarios := []struct {
key string
value int
expected int
}{
{"test2", 20, 1},
{"test3", 2, 3},
{"test_new", 20, 20},
{"test_new", 50, 20}, // should return the previously inserted value
}
for _, scenario := range scenarios {
t.Run(scenario.key, func(t *testing.T) {
result := s.GetOrSet(scenario.key, func() int {
return scenario.value
})
if result != scenario.expected {
t.Fatalf("Expected %v, got %v", scenario.expected, result)
}
})
}
}
func TestSetIfLessThanLimit(t *testing.T) {
s := store.Store[int]{}
@@ -230,3 +314,77 @@ func TestSetIfLessThanLimit(t *testing.T) {
}
}
}
func TestUnmarshalJSON(t *testing.T) {
s := store.Store[string]{}
s.Set("b", "old") // should be overwritten
s.Set("c", "test3") // ensures that the old values are not removed
raw := []byte(`{"a":"test1", "b":"test2"}`)
if err := json.Unmarshal(raw, &s); err != nil {
t.Fatal(err)
}
if v := s.Get("a"); v != "test1" {
t.Fatalf("Expected store.a to be %q, got %q", "test1", v)
}
if v := s.Get("b"); v != "test2" {
t.Fatalf("Expected store.b to be %q, got %q", "test2", v)
}
if v := s.Get("c"); v != "test3" {
t.Fatalf("Expected store.c to be %q, got %q", "test3", v)
}
}
func TestMarshalJSON(t *testing.T) {
s := &store.Store[string]{}
s.Set("a", "test1")
s.Set("b", "test2")
expected := []byte(`{"a":"test1", "b":"test2"}`)
result, err := json.Marshal(s)
if err != nil {
t.Fatal(err)
}
if bytes.Equal(result, expected) {
t.Fatalf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestShrink(t *testing.T) {
s := &store.Store[int]{}
total := 1000
for i := 0; i < total; i++ {
s.Set(strconv.Itoa(i), i)
}
if s.Length() != total {
t.Fatalf("Expected %d items, got %d", total, s.Length())
}
// trigger map "shrink"
for i := 0; i < store.ShrinkThreshold; i++ {
s.Remove(strconv.Itoa(i))
}
// ensure that after the deletion, the new map was copied properly
if s.Length() != total-store.ShrinkThreshold {
t.Fatalf("Expected %d items, got %d", total-store.ShrinkThreshold, s.Length())
}
for k := range s.GetAll() {
kInt, err := strconv.Atoi(k)
if err != nil {
t.Fatalf("failed to convert %s into int: %v", k, err)
}
if kInt < store.ShrinkThreshold {
t.Fatalf("Key %q should have been deleted", k)
}
}
}
+18 -28
View File
@@ -2,47 +2,41 @@ package subscriptions
import (
"fmt"
"sync"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/store"
)
// Broker defines a struct for managing subscriptions clients.
type Broker struct {
clients map[string]Client
mux sync.RWMutex
store *store.Store[Client]
}
// NewBroker initializes and returns a new Broker instance.
func NewBroker() *Broker {
return &Broker{
clients: make(map[string]Client),
store: store.New[Client](nil),
}
}
// Clients returns a shallow copy of all registered clients indexed
// with their connection id.
func (b *Broker) Clients() map[string]Client {
b.mux.RLock()
defer b.mux.RUnlock()
return b.store.GetAll()
}
copy := make(map[string]Client, len(b.clients))
for id, c := range b.clients {
copy[id] = c
}
return copy
// ChunkedClients splits the current clients into a chunked slice.
func (b *Broker) ChunkedClients(chunkSize int) [][]Client {
return list.ToChunks(b.store.Values(), chunkSize)
}
// ClientById finds a registered client by its id.
//
// Returns non-nil error when client with clientId is not registered.
func (b *Broker) ClientById(clientId string) (Client, error) {
b.mux.RLock()
defer b.mux.RUnlock()
client, ok := b.clients[clientId]
client, ok := b.store.GetOk(clientId)
if !ok {
return nil, fmt.Errorf("No client associated with connection ID %q", clientId)
return nil, fmt.Errorf("no client associated with connection ID %q", clientId)
}
return client, nil
@@ -50,21 +44,17 @@ func (b *Broker) ClientById(clientId string) (Client, error) {
// Register adds a new client to the broker instance.
func (b *Broker) Register(client Client) {
b.mux.Lock()
defer b.mux.Unlock()
b.clients[client.Id()] = client
b.store.Set(client.Id(), client)
}
// Unregister removes a single client by its id.
//
// If client with clientId doesn't exist, this method does nothing.
func (b *Broker) Unregister(clientId string) {
b.mux.Lock()
defer b.mux.Unlock()
if client, ok := b.clients[clientId]; ok {
client.Discard()
delete(b.clients, clientId)
client := b.store.Get(clientId)
if client == nil {
return
}
client.Discard()
b.store.Remove(clientId)
}

Some files were not shown because too many files have changed in this diff Show More