=added experimental expand, filter, fields, custom query and headers parameters support for the realtime subscriptions
This commit is contained in:
@@ -5,7 +5,10 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
// experimental! (need more tests before replacing encoding/json entirely)
|
||||
// Experimental!
|
||||
//
|
||||
// Need more tests before replacing encoding/json entirely.
|
||||
// Test also encoding/json/v2 once released (see https://github.com/golang/go/discussions/63397)
|
||||
goccy "github.com/goccy/go-json"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
@@ -45,42 +48,11 @@ func (s *Serializer) Serialize(c echo.Context, i any, indent string) error {
|
||||
return s.DefaultJSONSerializer.Serialize(c, i, indent)
|
||||
}
|
||||
|
||||
parsedFields, err := parseFields(rawFields)
|
||||
decoded, err := PickFields(i, rawFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// marshalize the provided data to ensure that the related json.Marshaler
|
||||
// implementations are invoked, and then convert it back to a plain
|
||||
// json value that we can further operate on.
|
||||
//
|
||||
// @todo research other approaches to avoid the double serialization
|
||||
// ---
|
||||
encoded, err := json.Marshal(i) // use the std json since goccy has several bugs reported with struct marshaling and it is not safe
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := goccy.Unmarshal(encoded, &decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
// ---
|
||||
|
||||
var isSearchResult bool
|
||||
switch i.(type) {
|
||||
case search.Result, *search.Result:
|
||||
isSearchResult = true
|
||||
}
|
||||
|
||||
if isSearchResult {
|
||||
if decodedMap, ok := decoded.(map[string]any); ok {
|
||||
pickFields(decodedMap["items"], parsedFields)
|
||||
}
|
||||
} else {
|
||||
pickFields(decoded, parsedFields)
|
||||
}
|
||||
|
||||
enc := goccy.NewEncoder(c.Response())
|
||||
if indent != "" {
|
||||
enc.SetIndent("", indent)
|
||||
@@ -89,6 +61,56 @@ func (s *Serializer) Serialize(c echo.Context, i any, indent string) error {
|
||||
return enc.Encode(decoded)
|
||||
}
|
||||
|
||||
// PickFields parses the provided fields string expression and
|
||||
// returns a new subset of data with only the requested fields.
|
||||
//
|
||||
// Fields transformations with modifiers are also supported (see initModifer()).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data := map[string]any{"a": 1, "b": 2, "c": map[string]any{"c1": 11, "c2": 22}}
|
||||
// PickFields(data, "a,c.c1") // map[string]any{"a": 1, "c": map[string]any{"c1": 11}}
|
||||
func PickFields(data any, rawFields string) (any, error) {
|
||||
parsedFields, err := parseFields(rawFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// marshalize the provided data to ensure that the related json.Marshaler
|
||||
// implementations are invoked, and then convert it back to a plain
|
||||
// json value that we can further operate on.
|
||||
//
|
||||
// @todo research other approaches to avoid the double serialization
|
||||
// ---
|
||||
encoded, err := json.Marshal(data) // use the std json since goccy has several bugs reported with struct marshaling and it is not safe
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := goccy.Unmarshal(encoded, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ---
|
||||
|
||||
// special cases to preserve the same fields format when used with single item or array data.
|
||||
var isSearchResult bool
|
||||
switch data.(type) {
|
||||
case search.Result, *search.Result:
|
||||
isSearchResult = true
|
||||
}
|
||||
|
||||
if isSearchResult {
|
||||
if decodedMap, ok := decoded.(map[string]any); ok {
|
||||
pickParsedFields(decodedMap["items"], parsedFields)
|
||||
}
|
||||
} else {
|
||||
pickParsedFields(decoded, parsedFields)
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func parseFields(rawFields string) (map[string]FieldModifier, error) {
|
||||
t := tokenizer.NewFromString(rawFields)
|
||||
|
||||
@@ -145,7 +167,7 @@ func initModifer(rawModifier string) (FieldModifier, error) {
|
||||
return nil, fmt.Errorf("missing or invalid modifier %q", name)
|
||||
}
|
||||
|
||||
func pickFields(data any, fields map[string]FieldModifier) error {
|
||||
func pickParsedFields(data any, fields map[string]FieldModifier) error {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
pickMapFields(v, fields)
|
||||
@@ -233,7 +255,7 @@ DataLoop:
|
||||
matchingFields[remains] = m
|
||||
}
|
||||
|
||||
if err := pickFields(data[k], matchingFields); err != nil {
|
||||
if err := pickParsedFields(data[k], matchingFields); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -45,6 +46,22 @@ func TestSerialize(t *testing.T) {
|
||||
"fields=missing",
|
||||
`{}`,
|
||||
},
|
||||
{
|
||||
">299 response",
|
||||
rest.Serializer{},
|
||||
300,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"<200 response",
|
||||
rest.Serializer{},
|
||||
199,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"non map response",
|
||||
rest.Serializer{},
|
||||
@@ -69,22 +86,6 @@ func TestSerialize(t *testing.T) {
|
||||
"fields=missing", // test individual fields trim
|
||||
`{}`,
|
||||
},
|
||||
{
|
||||
">299 response",
|
||||
rest.Serializer{},
|
||||
300,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"<200 response",
|
||||
rest.Serializer{},
|
||||
199,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=missing",
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"map with existing and missing fields",
|
||||
rest.Serializer{},
|
||||
@@ -101,21 +102,96 @@ func TestSerialize(t *testing.T) {
|
||||
"custom=a, c ,missing", // test individual fields trim
|
||||
`{"a":1,"c":"test"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.URL.RawQuery = s.query
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Status = s.statusCode
|
||||
|
||||
if err := s.serializer.Serialize(c, s.data, ""); err != nil {
|
||||
t.Fatalf("Serialize failure: %v", err)
|
||||
}
|
||||
|
||||
rawBody, err := io.ReadAll(rec.Result().Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read request body: %v", err)
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(string(rawBody)); v != s.expected {
|
||||
t.Fatalf("Expected body\n%v \ngot \n%v", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPickFields(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data any
|
||||
fields string
|
||||
expectError bool
|
||||
result string
|
||||
}{
|
||||
{
|
||||
"empty fields",
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"",
|
||||
false,
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"missing fields",
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"missing",
|
||||
false,
|
||||
`{}`,
|
||||
},
|
||||
{
|
||||
"non map data",
|
||||
"test",
|
||||
"a,b,test",
|
||||
false,
|
||||
`"test"`,
|
||||
},
|
||||
{
|
||||
"non slice of map data",
|
||||
[]any{"a", "b", "test"},
|
||||
"a,test",
|
||||
false,
|
||||
`["a","b","test"]`,
|
||||
},
|
||||
{
|
||||
"map with no matching field",
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"missing", // test individual fields trim
|
||||
false,
|
||||
`{}`,
|
||||
},
|
||||
{
|
||||
"map with existing and missing fields",
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"a, c ,missing", // test individual fields trim
|
||||
false,
|
||||
`{"a":1,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"slice of maps with existing and missing fields",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
[]any{
|
||||
map[string]any{"a": 11, "b": 11, "c": "test1"},
|
||||
map[string]any{"a": 22, "b": 22, "c": "test2"},
|
||||
},
|
||||
"fields=a, c ,missing", // test individual fields trim
|
||||
"a, c ,missing", // test individual fields trim
|
||||
false,
|
||||
`[{"a":11,"c":"test1"},{"a":22,"c":"test2"}]`,
|
||||
},
|
||||
{
|
||||
"nested fields with mixed map and any slices",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
@@ -167,13 +243,12 @@ func TestSerialize(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"fields=a, c, anySlice.A, mapSlice.C, mapSlice.D.DA, anySlice.D,fullMap",
|
||||
"a, c, anySlice.A, mapSlice.C, mapSlice.D.DA, anySlice.D,fullMap",
|
||||
false,
|
||||
`{"a":1,"anySlice":[{"A":[1,2,3],"D":{"DA":1,"DB":2}},{"A":"test"}],"c":"test","fullMap":[{"A":[1,2,3],"B":["1","2",3],"C":"test"},{"B":["1","2",3],"D":[{"DA":2},{"DA":3}]}],"mapSlice":[{"C":"test","D":[{"DA":1}]},{"D":[{"DA":2},{"DA":3},{}]}]}`,
|
||||
},
|
||||
{
|
||||
"SearchResult",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
search.Result{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
@@ -184,13 +259,12 @@ func TestSerialize(t *testing.T) {
|
||||
map[string]any{"a": 22, "b": 22, "c": "test2"},
|
||||
},
|
||||
},
|
||||
"fields=a,c,missing",
|
||||
"a,c,missing",
|
||||
false,
|
||||
`{"items":[{"a":11,"c":"test1"},{"a":22,"c":"test2"}],"page":1,"perPage":10,"totalItems":20,"totalPages":30}`,
|
||||
},
|
||||
{
|
||||
"*SearchResult",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
&search.Result{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
@@ -201,13 +275,12 @@ func TestSerialize(t *testing.T) {
|
||||
map[string]any{"a": 22, "b": 22, "c": "test2"},
|
||||
},
|
||||
},
|
||||
"fields=a,c",
|
||||
"a,c",
|
||||
false,
|
||||
`{"items":[{"a":11,"c":"test1"},{"a":22,"c":"test2"}],"page":1,"perPage":10,"totalItems":20,"totalPages":30}`,
|
||||
},
|
||||
{
|
||||
"root wildcard",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
&search.Result{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
@@ -218,13 +291,12 @@ func TestSerialize(t *testing.T) {
|
||||
map[string]any{"a": 22, "b": 22, "c": "test2"},
|
||||
},
|
||||
},
|
||||
"fields=*",
|
||||
"*",
|
||||
false,
|
||||
`{"items":[{"a":11,"b":11,"c":"test1"},{"a":22,"b":22,"c":"test2"}],"page":1,"perPage":10,"totalItems":20,"totalPages":30}`,
|
||||
},
|
||||
{
|
||||
"root wildcard with nested exception",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{
|
||||
"id": "123",
|
||||
"title": "lorem",
|
||||
@@ -233,13 +305,12 @@ func TestSerialize(t *testing.T) {
|
||||
"title": "rel_title",
|
||||
},
|
||||
},
|
||||
"fields=*,rel.id",
|
||||
"*,rel.id",
|
||||
false,
|
||||
`{"id":"123","rel":{"id":"456"},"title":"lorem"}`,
|
||||
},
|
||||
{
|
||||
"sub wildcard",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{
|
||||
"id": "123",
|
||||
"title": "lorem",
|
||||
@@ -252,13 +323,12 @@ func TestSerialize(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"fields=id,rel.*",
|
||||
"id,rel.*",
|
||||
false,
|
||||
`{"id":"123","rel":{"id":"456","sub":{"id":"789","title":"sub_title"},"title":"rel_title"}}`,
|
||||
},
|
||||
{
|
||||
"sub wildcard with nested exception",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{
|
||||
"id": "123",
|
||||
"title": "lorem",
|
||||
@@ -271,21 +341,19 @@ func TestSerialize(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"fields=id,rel.*,rel.sub.id",
|
||||
"id,rel.*,rel.sub.id",
|
||||
false,
|
||||
`{"id":"123","rel":{"id":"456","sub":{"id":"789"},"title":"rel_title"}}`,
|
||||
},
|
||||
{
|
||||
"invalid excerpt modifier",
|
||||
rest.Serializer{},
|
||||
400,
|
||||
map[string]any{"a": 1, "b": 2, "c": "test"},
|
||||
"fields=*:excerpt",
|
||||
"*:excerpt",
|
||||
true,
|
||||
`{"a":1,"b":2,"c":"test"}`,
|
||||
},
|
||||
{
|
||||
"valid excerpt modifier",
|
||||
rest.Serializer{},
|
||||
200,
|
||||
map[string]any{
|
||||
"id": "123",
|
||||
"title": "lorem",
|
||||
@@ -298,32 +366,32 @@ func TestSerialize(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"fields=*:excerpt(2),rel.title:excerpt(3, true)",
|
||||
"*:excerpt(2),rel.title:excerpt(3, true)",
|
||||
false,
|
||||
`{"id":"12","rel":{"title":"rel..."},"title":"lo"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.URL.RawQuery = s.query
|
||||
rec := httptest.NewRecorder()
|
||||
result, err := rest.PickFields(s.data, s.fields)
|
||||
|
||||
e := echo.New()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Response().Status = s.statusCode
|
||||
|
||||
if err := s.serializer.Serialize(c, s.data, ""); err != nil {
|
||||
t.Fatalf("Serialize failure: %v", err)
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
rawBody, err := io.ReadAll(rec.Result().Body)
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
serialized, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read request body: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(string(rawBody)); v != s.expected {
|
||||
t.Fatalf("Expected body\n%v \ngot: \n%v", s.expected, v)
|
||||
if v := string(serialized); v != s.result {
|
||||
t.Fatalf("Expected body\n%s \ngot \n%s", s.result, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,15 +1,29 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
const optionsParam = "options"
|
||||
|
||||
// Message defines a client's channel data.
|
||||
type Message struct {
|
||||
Name string
|
||||
Data []byte
|
||||
Name string `json:"name"`
|
||||
Data []byte `json:"data"`
|
||||
}
|
||||
|
||||
type SubscriptionOptions struct {
|
||||
// @todo after the requests handling refactoring consider
|
||||
// changing to map[string]string or map[string][]string
|
||||
Query map[string]any `json:"query"`
|
||||
Headers map[string]any `json:"headers"`
|
||||
}
|
||||
|
||||
// Client is an interface for a generic subscription client.
|
||||
@@ -20,10 +34,20 @@ type Client interface {
|
||||
// Channel returns the client's communication channel.
|
||||
Channel() chan Message
|
||||
|
||||
// Subscriptions returns all subscriptions to which the client has subscribed to.
|
||||
Subscriptions() map[string]struct{}
|
||||
// Subscriptions returns a shallow copy of the the client subscriptions matching the prefixes.
|
||||
// If no prefix is specified, returns all subscriptions.
|
||||
Subscriptions(prefixes ...string) map[string]SubscriptionOptions
|
||||
|
||||
// Subscribe subscribes the client to the provided subscriptions list.
|
||||
//
|
||||
// Each subscription can also have "options" (json serialized SubscriptionOptions) as query parameter.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// Subscribe(
|
||||
// "subscriptionA",
|
||||
// `subscriptionB?options={"query":{"a":1},"headers":{"x_token":"abc"}}`,
|
||||
// )
|
||||
Subscribe(subs ...string)
|
||||
|
||||
// Unsubscribe unsubscribes the client from the provided subscriptions list.
|
||||
@@ -61,8 +85,8 @@ var _ Client = (*DefaultClient)(nil)
|
||||
// DefaultClient defines a generic subscription client.
|
||||
type DefaultClient struct {
|
||||
store map[string]any
|
||||
subscriptions map[string]SubscriptionOptions
|
||||
channel chan Message
|
||||
subscriptions map[string]struct{}
|
||||
id string
|
||||
mux sync.RWMutex
|
||||
isDiscarded bool
|
||||
@@ -74,7 +98,7 @@ func NewDefaultClient() *DefaultClient {
|
||||
id: security.RandomString(40),
|
||||
store: map[string]any{},
|
||||
channel: make(chan Message),
|
||||
subscriptions: make(map[string]struct{}),
|
||||
subscriptions: map[string]SubscriptionOptions{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,11 +119,37 @@ func (c *DefaultClient) Channel() chan Message {
|
||||
}
|
||||
|
||||
// Subscriptions implements the [Client.Subscriptions] interface method.
|
||||
func (c *DefaultClient) Subscriptions() map[string]struct{} {
|
||||
//
|
||||
// It returns a shallow copy of the the client subscriptions matching the prefixes.
|
||||
// If no prefix is specified, returns all subscriptions.
|
||||
func (c *DefaultClient) Subscriptions(prefixes ...string) map[string]SubscriptionOptions {
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
return c.subscriptions
|
||||
// no prefix -> return copy of all subscriptions
|
||||
if len(prefixes) == 0 {
|
||||
result := make(map[string]SubscriptionOptions, len(c.subscriptions))
|
||||
|
||||
for s, options := range c.subscriptions {
|
||||
result[s] = options
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
result := make(map[string]SubscriptionOptions)
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
for s, options := range c.subscriptions {
|
||||
// "?" ensures that the options query start character is always there
|
||||
// so that it can be used as an end separator when looking only for the main subscription topic
|
||||
if strings.HasPrefix(s+"?", prefix) {
|
||||
result[s] = options
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Subscribe implements the [Client.Subscribe] interface method.
|
||||
@@ -114,7 +164,30 @@ func (c *DefaultClient) Subscribe(subs ...string) {
|
||||
continue // skip empty
|
||||
}
|
||||
|
||||
c.subscriptions[s] = struct{}{}
|
||||
// extract subscription options (if any)
|
||||
options := SubscriptionOptions{}
|
||||
u, err := url.Parse(s)
|
||||
if err == nil {
|
||||
rawOptions := u.Query().Get(optionsParam)
|
||||
if rawOptions != "" {
|
||||
json.Unmarshal([]byte(rawOptions), &options)
|
||||
}
|
||||
}
|
||||
|
||||
// normalize query
|
||||
// (currently only single string values are supported for consistency with the default routes handling)
|
||||
for k, v := range options.Query {
|
||||
options.Query[k] = cast.ToString(v)
|
||||
}
|
||||
|
||||
// normalize headers name and values, eg. "X-Token" is converted to "x_token"
|
||||
// (currently only single string values are supported for consistency with the default routes handling)
|
||||
for k, v := range options.Headers {
|
||||
delete(options.Headers, k)
|
||||
options.Headers[inflector.Snakecase(k)] = cast.ToString(v)
|
||||
}
|
||||
|
||||
c.subscriptions[s] = options
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package subscriptions_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -51,7 +53,7 @@ func TestChannel(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if c.Channel() == nil {
|
||||
t.Errorf("Expected channel to be initialized, got")
|
||||
t.Fatalf("Expected channel to be initialized, got")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,13 +61,35 @@ func TestSubscriptions(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected subscriptions to be empty")
|
||||
t.Fatalf("Expected subscriptions to be empty")
|
||||
}
|
||||
|
||||
c.Subscribe("sub1", "sub2", "sub3")
|
||||
c.Subscribe("sub1", "sub11", "sub2")
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
scenarios := []struct {
|
||||
prefixes []string
|
||||
expected []string
|
||||
}{
|
||||
{nil, []string{"sub1", "sub11", "sub2"}},
|
||||
{[]string{"missing"}, nil},
|
||||
{[]string{"sub1"}, []string{"sub1", "sub11"}},
|
||||
{[]string{"sub2"}, []string{"sub2"}}, // with extra query start char
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(strings.Join(s.prefixes, ","), func(t *testing.T) {
|
||||
subs := c.Subscriptions(s.prefixes...)
|
||||
|
||||
if len(subs) != len(s.expected) {
|
||||
t.Fatalf("Expected %d subscriptions, got %d", len(s.expected), len(subs))
|
||||
}
|
||||
|
||||
for _, s := range s.expected {
|
||||
if _, ok := subs[s]; !ok {
|
||||
t.Fatalf("Missing subscription %q in \n%v", s, subs)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +102,7 @@ func TestSubscribe(t *testing.T) {
|
||||
c.Subscribe(subs...) // empty string should be skipped
|
||||
|
||||
if len(c.Subscriptions()) != 3 {
|
||||
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
t.Fatalf("Expected 3 subscriptions, got %v", c.Subscriptions())
|
||||
}
|
||||
|
||||
for i, s := range expected {
|
||||
@@ -88,6 +112,44 @@ func TestSubscribe(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeOptions(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
sub1 := "test1"
|
||||
sub2 := `test2?options={"query":{"name":123},"headers":{"X-Token":456}}`
|
||||
|
||||
c.Subscribe(sub1, sub2)
|
||||
|
||||
subs := c.Subscriptions()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
expectedOptions string
|
||||
}{
|
||||
{sub1, `{"query":null,"headers":null}`},
|
||||
{sub2, `{"query":{"name":"123"},"headers":{"x_token":"456"}}`},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
options, ok := subs[s.name]
|
||||
if !ok {
|
||||
t.Fatalf("Missing subscription \n%q \nin \n%v", s.name, subs)
|
||||
}
|
||||
|
||||
rawBytes, err := json.Marshal(options)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(rawBytes)
|
||||
|
||||
if rawStr != s.expectedOptions {
|
||||
t.Fatalf("Expected options \n%v \ngot \n%v", s.expectedOptions, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
c := subscriptions.NewDefaultClient()
|
||||
|
||||
@@ -96,12 +158,12 @@ func TestUnsubscribe(t *testing.T) {
|
||||
c.Unsubscribe("sub1")
|
||||
|
||||
if c.HasSubscription("sub1") {
|
||||
t.Error("Expected sub1 to be removed")
|
||||
t.Fatalf("Expected sub1 to be removed")
|
||||
}
|
||||
|
||||
c.Unsubscribe( /* all */ )
|
||||
if len(c.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
|
||||
t.Fatalf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user