=added experimental expand, filter, fields, custom query and headers parameters support for the realtime subscriptions
This commit is contained in:
@@ -1,15 +1,29 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
const optionsParam = "options"
|
||||
|
||||
// Message defines a client's channel data.
|
||||
type Message struct {
|
||||
Name string
|
||||
Data []byte
|
||||
Name string `json:"name"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
|
||||
type SubscriptionOptions struct {
|
||||
// @todo after the requests handling refactoring consider
|
||||
// changing to map[string]string or map[string][]string
|
||||
Query map[string]any `json:"query"`
|
||||
Headers map[string]any `json:"headers"`
|
||||
}
|
||||
|
||||
// Client is an interface for a generic subscription client.
|
||||
@@ -20,10 +34,20 @@ type Client interface {
|
||||
// Channel returns the client's communication channel.
|
||||
Channel() chan Message
|
||||
|
||||
// Subscriptions returns all subscriptions to which the client has subscribed to.
|
||||
Subscriptions() map[string]struct{}
|
||||
// Subscriptions returns a shallow copy of the the client subscriptions matching the prefixes.
|
||||
// If no prefix is specified, returns all subscriptions.
|
||||
Subscriptions(prefixes ...string) map[string]SubscriptionOptions
|
||||
|
||||
// Subscribe subscribes the client to the provided subscriptions list.
|
||||
//
|
||||
// Each subscription can also have "options" (json serialized SubscriptionOptions) as query parameter.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// Subscribe(
|
||||
// "subscriptionA",
|
||||
// `subscriptionB?options={"query":{"a":1},"headers":{"x_token":"abc"}}`,
|
||||
// )
|
||||
Subscribe(subs ...string)
|
||||
|
||||
// Unsubscribe unsubscribes the client from the provided subscriptions list.
|
||||
@@ -61,8 +85,8 @@ var _ Client = (*DefaultClient)(nil)
|
||||
// DefaultClient defines a generic subscription client.
|
||||
type DefaultClient struct {
|
||||
store map[string]any
|
||||
subscriptions map[string]SubscriptionOptions
|
||||
channel chan Message
|
||||
subscriptions map[string]struct{}
|
||||
id string
|
||||
mux sync.RWMutex
|
||||
isDiscarded bool
|
||||
@@ -74,7 +98,7 @@ func NewDefaultClient() *DefaultClient {
|
||||
id: security.RandomString(40),
|
||||
store: map[string]any{},
|
||||
channel: make(chan Message),
|
||||
subscriptions: make(map[string]struct{}),
|
||||
subscriptions: map[string]SubscriptionOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,11 +119,37 @@ func (c *DefaultClient) Channel() chan Message {
|
||||
}
|
||||
|
||||
// Subscriptions implements the [Client.Subscriptions] interface method.
|
||||
func (c *DefaultClient) Subscriptions() map[string]struct{} {
|
||||
//
|
||||
// It returns a shallow copy of the the client subscriptions matching the prefixes.
|
||||
// If no prefix is specified, returns all subscriptions.
|
||||
func (c *DefaultClient) Subscriptions(prefixes ...string) map[string]SubscriptionOptions {
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
return c.subscriptions
|
||||
// no prefix -> return copy of all subscriptions
|
||||
if len(prefixes) == 0 {
|
||||
result := make(map[string]SubscriptionOptions, len(c.subscriptions))
|
||||
|
||||
for s, options := range c.subscriptions {
|
||||
result[s] = options
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
result := make(map[string]SubscriptionOptions)
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
for s, options := range c.subscriptions {
|
||||
// "?" ensures that the options query start character is always there
|
||||
// so that it can be used as an end separator when looking only for the main subscription topic
|
||||
if strings.HasPrefix(s+"?", prefix) {
|
||||
result[s] = options
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Subscribe implements the [Client.Subscribe] interface method.
|
||||
@@ -114,7 +164,30 @@ func (c *DefaultClient) Subscribe(subs ...string) {
|
||||
continue // skip empty
|
||||
}
|
||||
|
||||
c.subscriptions[s] = struct{}{}
|
||||
// extract subscription options (if any)
|
||||
options := SubscriptionOptions{}
|
||||
u, err := url.Parse(s)
|
||||
if err == nil {
|
||||
rawOptions := u.Query().Get(optionsParam)
|
||||
if rawOptions != "" {
|
||||
json.Unmarshal([]byte(rawOptions), &options)
|
||||
}
|
||||
}
|
||||
|
||||
// normalize query
|
||||
// (currently only single string values are supported for consistency with the default routes handling)
|
||||
for k, v := range options.Query {
|
||||
options.Query[k] = cast.ToString(v)
|
||||
}
|
||||
|
||||
// normalize headers name and values, eg. "X-Token" is converted to "x_token"
|
||||
// (currently only single string values are supported for consistency with the default routes handling)
|
||||
for k, v := range options.Headers {
|
||||
delete(options.Headers, k)
|
||||
options.Headers[inflector.Snakecase(k)] = cast.ToString(v)
|
||||
}
|
||||
|
||||
c.subscriptions[s] = options
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package subscriptions_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -51,7 +53,7 @@ func TestChannel(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.Channel() == nil {
|
||||
t.Errorf("Expected channel to be initialized, got")
|
||||
t.Fatalf("Expected channel to be initialized, got")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,13 +61,35 @@ func TestSubscriptions(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected subscriptions to be empty")
|
||||
t.Fatalf("Expected subscriptions to be empty")
|
||||
}
|
||||
|
||||
c.Subscribe("sub1", "sub2", "sub3")
|
||||
c.Subscribe("sub1", "sub11", "sub2")
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
scenarios := []struct {
|
||||
prefixes []string
|
||||
expected []string
|
||||
}{
|
||||
{nil, []string{"sub1", "sub11", "sub2"}},
|
||||
{[]string{"missing"}, nil},
|
||||
{[]string{"sub1"}, []string{"sub1", "sub11"}},
|
||||
{[]string{"sub2"}, []string{"sub2"}}, // with extra query start char
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(strings.Join(s.prefixes, ","), func(t *testing.T) {
|
||||
subs := c.Subscriptions(s.prefixes...)
|
||||
|
||||
if len(subs) != len(s.expected) {
|
||||
t.Fatalf("Expected %d subscriptions, got %d", len(s.expected), len(subs))
|
||||
}
|
||||
|
||||
for _, s := range s.expected {
|
||||
if _, ok := subs[s]; !ok {
|
||||
t.Fatalf("Missing subscription %q in \n%v", s, subs)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +102,7 @@ func TestSubscribe(t *testing.T) {
|
||||
c.Subscribe(subs...) // empty string should be skipped
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
t.Fatalf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
}
|
||||
|
||||
for i, s := range expected {
|
||||
@@ -88,6 +112,44 @@ func TestSubscribe(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeOptions(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
sub1 := "test1"
|
||||
sub2 := `test2?options={"query":{"name":123},"headers":{"X-Token":456}}`
|
||||
|
||||
c.Subscribe(sub1, sub2)
|
||||
|
||||
subs := c.Subscriptions()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
expectedOptions string
|
||||
}{
|
||||
{sub1, `{"query":null,"headers":null}`},
|
||||
{sub2, `{"query":{"name":"123"},"headers":{"x_token":"456"}}`},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
options, ok := subs[s.name]
|
||||
if !ok {
|
||||
t.Fatalf("Missing subscription \n%q \nin \n%v", s.name, subs)
|
||||
}
|
||||
|
||||
rawBytes, err := json.Marshal(options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(rawBytes)
|
||||
|
||||
if rawStr != s.expectedOptions {
|
||||
t.Fatalf("Expected options \n%v \ngot \n%v", s.expectedOptions, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
@@ -96,12 +158,12 @@ func TestUnsubscribe(t *testing.T) {
|
||||
c.Unsubscribe("sub1")
|
||||
|
||||
if c.HasSubscription("sub1") {
|
||||
t.Error("Expected sub1 to be removed")
|
||||
t.Fatalf("Expected sub1 to be removed")
|
||||
}
|
||||
|
||||
c.Unsubscribe( /* all */ )
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
|
||||
t.Fatalf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user