initial public commit
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Broker defines a struct for managing subscriptions clients.
|
||||
type Broker struct {
|
||||
mux sync.RWMutex
|
||||
clients map[string]Client
|
||||
}
|
||||
|
||||
// NewBroker initializes and returns a new Broker instance.
|
||||
func NewBroker() *Broker {
|
||||
return &Broker{
|
||||
clients: make(map[string]Client),
|
||||
}
|
||||
}
|
||||
|
||||
// Clients returns all registered clients.
|
||||
func (b *Broker) Clients() map[string]Client {
|
||||
return b.clients
|
||||
}
|
||||
|
||||
// ClientById finds a registered client by its id.
|
||||
//
|
||||
// Returns non-nil error when client with clientId is not registered.
|
||||
func (b *Broker) ClientById(clientId string) (Client, error) {
|
||||
client, ok := b.clients[clientId]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("No client associated with connection ID %q", clientId)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Register adds a new client to the broker instance.
|
||||
func (b *Broker) Register(client Client) {
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
|
||||
b.clients[client.Id()] = client
|
||||
}
|
||||
|
||||
// Unregister removes a single client by its id.
|
||||
//
|
||||
// If client with clientId doesn't exist, this method does nothing.
|
||||
func (b *Broker) Unregister(clientId string) {
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
|
||||
// Note:
|
||||
// There is no need to explicitly close the client's channel since it will be GC-ed anyway.
|
||||
// Addinitionally, closing the channel explicitly could panic when there are several
|
||||
// subscriptions attached to the client that needs to receive the same event.
|
||||
delete(b.clients, clientId)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package subscriptions_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
func TestNewBroker(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
if b.Clients() == nil {
|
||||
t.Fatal("Expected clients map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
if total := len(b.Clients()); total != 0 {
|
||||
t.Fatalf("Expected no clients, got %v", total)
|
||||
}
|
||||
|
||||
b.Register(subscriptions.NewDefaultClient())
|
||||
b.Register(subscriptions.NewDefaultClient())
|
||||
|
||||
if total := len(b.Clients()); total != 2 {
|
||||
t.Fatalf("Expected 2 clients, got %v", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientById(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
clientA := subscriptions.NewDefaultClient()
|
||||
clientB := subscriptions.NewDefaultClient()
|
||||
b.Register(clientA)
|
||||
b.Register(clientB)
|
||||
|
||||
resultClient, err := b.ClientById(clientA.Id())
|
||||
if err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", clientA.Id(), err)
|
||||
}
|
||||
if resultClient.Id() != clientA.Id() {
|
||||
t.Fatalf("Expected client %s, got %s", clientA.Id(), resultClient.Id())
|
||||
}
|
||||
|
||||
if c, err := b.ClientById("missing"); err == nil {
|
||||
t.Fatalf("Expected error, found client %v", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
b.Register(client)
|
||||
|
||||
if _, err := b.ClientById(client.Id()); err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", client.Id(), err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnregister(t *testing.T) {
|
||||
b := subscriptions.NewBroker()
|
||||
|
||||
clientA := subscriptions.NewDefaultClient()
|
||||
clientB := subscriptions.NewDefaultClient()
|
||||
b.Register(clientA)
|
||||
b.Register(clientB)
|
||||
|
||||
if _, err := b.ClientById(clientA.Id()); err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", clientA.Id(), err)
|
||||
}
|
||||
|
||||
b.Unregister(clientA.Id())
|
||||
|
||||
if c, err := b.ClientById(clientA.Id()); err == nil {
|
||||
t.Fatalf("Expected error, found client %v", c)
|
||||
}
|
||||
|
||||
// clientB shouldn't have been removed
|
||||
if _, err := b.ClientById(clientB.Id()); err != nil {
|
||||
t.Fatalf("Expected client with id %s, got error %v", clientB.Id(), err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// Message defines a client's channel data.
|
||||
type Message struct {
|
||||
Name string
|
||||
Data string
|
||||
}
|
||||
|
||||
// Client is an interface for a generic subscription client.
|
||||
type Client interface {
|
||||
// Id Returns the unique id of the client.
|
||||
Id() string
|
||||
|
||||
// Channel returns the client's communication channel.
|
||||
Channel() chan Message
|
||||
|
||||
// Subscriptions returns all subscriptions to which the client has subscribed to.
|
||||
Subscriptions() map[string]struct{}
|
||||
|
||||
// Subscribe subscribes the client to the provided subscriptions list.
|
||||
Subscribe(subs ...string)
|
||||
|
||||
// Unsubscribe unsubscribes the client from the provided subscriptions list.
|
||||
Unsubscribe(subs ...string)
|
||||
|
||||
// HasSubscription checks if the client is subscribed to `sub`.
|
||||
HasSubscription(sub string) bool
|
||||
|
||||
// Set stores any value to the client's context.
|
||||
Set(key string, value any)
|
||||
|
||||
// Get retrieves the key value from the client's context.
|
||||
Get(key string) any
|
||||
}
|
||||
|
||||
// ensures that DefaultClient satisfies the Client interface
|
||||
var _ Client = (*DefaultClient)(nil)
|
||||
|
||||
// DefaultClient defines a generic subscription client.
|
||||
type DefaultClient struct {
|
||||
mux sync.RWMutex
|
||||
id string
|
||||
store map[string]any
|
||||
channel chan Message
|
||||
subscriptions map[string]struct{}
|
||||
}
|
||||
|
||||
// NewDefaultClient creates and returns a new DefaultClient instance.
|
||||
func NewDefaultClient() *DefaultClient {
|
||||
return &DefaultClient{
|
||||
id: security.RandomString(40),
|
||||
store: map[string]any{},
|
||||
channel: make(chan Message),
|
||||
subscriptions: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Id implements the Client.Id interface method.
|
||||
func (c *DefaultClient) Id() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
// Channel implements the Client.Channel interface method.
|
||||
func (c *DefaultClient) Channel() chan Message {
|
||||
return c.channel
|
||||
}
|
||||
|
||||
// Subscriptions implements the Client.Subscriptions interface method.
|
||||
func (c *DefaultClient) Subscriptions() map[string]struct{} {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
return c.subscriptions
|
||||
}
|
||||
|
||||
// Subscribe implements the Client.Subscribe interface method.
|
||||
//
|
||||
// Empty subscriptions (aka. "") are ignored.
|
||||
func (c *DefaultClient) Subscribe(subs ...string) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
for _, s := range subs {
|
||||
if s == "" {
|
||||
continue // skip empty
|
||||
}
|
||||
|
||||
c.subscriptions[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe implements the Client.Unsubscribe interface method.
|
||||
//
|
||||
// If subs is not set, this method removes all registered client's subscriptions.
|
||||
func (c *DefaultClient) Unsubscribe(subs ...string) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if len(subs) > 0 {
|
||||
for _, s := range subs {
|
||||
delete(c.subscriptions, s)
|
||||
}
|
||||
} else {
|
||||
// unsubsribe all
|
||||
for s := range c.subscriptions {
|
||||
delete(c.subscriptions, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasSubscription implements the Client.HasSubscription interface method.
|
||||
func (c *DefaultClient) HasSubscription(sub string) bool {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
_, ok := c.subscriptions[sub]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// Get implements the Client.Get interface method.
|
||||
func (c *DefaultClient) Get(key string) any {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
return c.store[key]
|
||||
}
|
||||
|
||||
// Set implements the Client.Set interface method.
|
||||
func (c *DefaultClient) Set(key string, value any) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.store[key] = value
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package subscriptions_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
func TestNewDefaultClient(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.Channel() == nil {
|
||||
t.Errorf("Expected channel to be initialized")
|
||||
}
|
||||
|
||||
if c.Subscriptions() == nil {
|
||||
t.Errorf("Expected subscriptions map to be initialized")
|
||||
}
|
||||
|
||||
if c.Id() == "" {
|
||||
t.Errorf("Expected unique id to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestId(t *testing.T) {
|
||||
clients := []*subscriptions.DefaultClient{
|
||||
subscriptions.NewDefaultClient(),
|
||||
subscriptions.NewDefaultClient(),
|
||||
subscriptions.NewDefaultClient(),
|
||||
subscriptions.NewDefaultClient(),
|
||||
}
|
||||
|
||||
ids := map[string]struct{}{}
|
||||
for i, c := range clients {
|
||||
// check uniqueness
|
||||
if _, ok := ids[c.Id()]; ok {
|
||||
t.Errorf("(%d) Expected unique id, got %v", i, c.Id())
|
||||
} else {
|
||||
ids[c.Id()] = struct{}{}
|
||||
}
|
||||
|
||||
// check length
|
||||
if len(c.Id()) != 40 {
|
||||
t.Errorf("(%d) Expected unique id to have 40 chars length, got %v", i, c.Id())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannel(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.Channel() == nil {
|
||||
t.Errorf("Expected channel to be initialized, got")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptions(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected subscriptions to be empty")
|
||||
}
|
||||
|
||||
c.Subscribe("sub1", "sub2", "sub3")
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
subs := []string{"", "sub1", "sub2", "sub3"}
|
||||
expected := []string{"sub1", "sub2", "sub3"}
|
||||
|
||||
c.Subscribe(subs...) // empty string should be skipped
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
}
|
||||
|
||||
for i, s := range expected {
|
||||
if !c.HasSubscription(s) {
|
||||
t.Errorf("(%d) Expected sub %s", i, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
c.Subscribe("sub1", "sub2", "sub3")
|
||||
|
||||
c.Unsubscribe("sub1")
|
||||
|
||||
if c.HasSubscription("sub1") {
|
||||
t.Error("Expected sub1 to be removed")
|
||||
}
|
||||
|
||||
c.Unsubscribe( /* all */ )
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSubscription(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.HasSubscription("missing") {
|
||||
t.Error("Expected false, got true")
|
||||
}
|
||||
|
||||
c.Subscribe("sub")
|
||||
|
||||
if !c.HasSubscription("sub") {
|
||||
t.Error("Expected true, got false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAndGet(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
c.Set("demo", 1)
|
||||
|
||||
result, _ := c.Get("demo").(int)
|
||||
|
||||
if result != 1 {
|
||||
t.Errorf("Expected 1, got %v", result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user