=added experimental expand, filter, fields, custom query and headers parameters support for the realtime subscriptions

This commit is contained in:
Gani Georgiev
2023-10-23 22:46:47 +03:00
parent e6f1b3dfe4
commit 79617e6d99
41 changed files with 553 additions and 257 deletions
+57 -35
View File
@@ -5,7 +5,10 @@ import (
"fmt"
"strings"
// experimental! (need more tests before replacing encoding/json entirely)
// Experimental!
//
// Need more tests before replacing encoding/json entirely.
// Test also encoding/json/v2 once released (see https://github.com/golang/go/discussions/63397)
goccy "github.com/goccy/go-json"
"github.com/labstack/echo/v5"
@@ -45,42 +48,11 @@ func (s *Serializer) Serialize(c echo.Context, i any, indent string) error {
return s.DefaultJSONSerializer.Serialize(c, i, indent)
}
parsedFields, err := parseFields(rawFields)
decoded, err := PickFields(i, rawFields)
if err != nil {
return err
}
// marshalize the provided data to ensure that the related json.Marshaler
// implementations are invoked, and then convert it back to a plain
// json value that we can further operate on.
//
// @todo research other approaches to avoid the double serialization
// ---
encoded, err := json.Marshal(i) // use the std json since goccy has several bugs reported with struct marshaling and it is not safe
if err != nil {
return err
}
var decoded any
if err := goccy.Unmarshal(encoded, &decoded); err != nil {
return err
}
// ---
var isSearchResult bool
switch i.(type) {
case search.Result, *search.Result:
isSearchResult = true
}
if isSearchResult {
if decodedMap, ok := decoded.(map[string]any); ok {
pickFields(decodedMap["items"], parsedFields)
}
} else {
pickFields(decoded, parsedFields)
}
enc := goccy.NewEncoder(c.Response())
if indent != "" {
enc.SetIndent("", indent)
@@ -89,6 +61,56 @@ func (s *Serializer) Serialize(c echo.Context, i any, indent string) error {
return enc.Encode(decoded)
}
// PickFields parses the provided fields string expression and
// returns a new subset of data with only the requested fields.
//
// Fields transformations with modifiers are also supported (see initModifer()).
//
// Example:
//
// data := map[string]any{"a": 1, "b": 2, "c": map[string]any{"c1": 11, "c2": 22}}
// PickFields(data, "a,c.c1") // map[string]any{"a": 1, "c": map[string]any{"c1": 11}}
func PickFields(data any, rawFields string) (any, error) {
parsedFields, err := parseFields(rawFields)
if err != nil {
return nil, err
}
// marshalize the provided data to ensure that the related json.Marshaler
// implementations are invoked, and then convert it back to a plain
// json value that we can further operate on.
//
// @todo research other approaches to avoid the double serialization
// ---
encoded, err := json.Marshal(data) // use the std json since goccy has several bugs reported with struct marshaling and it is not safe
if err != nil {
return nil, err
}
var decoded any
if err := goccy.Unmarshal(encoded, &decoded); err != nil {
return nil, err
}
// ---
// special cases to preserve the same fields format when used with single item or array data.
var isSearchResult bool
switch data.(type) {
case search.Result, *search.Result:
isSearchResult = true
}
if isSearchResult {
if decodedMap, ok := decoded.(map[string]any); ok {
pickParsedFields(decodedMap["items"], parsedFields)
}
} else {
pickParsedFields(decoded, parsedFields)
}
return decoded, nil
}
func parseFields(rawFields string) (map[string]FieldModifier, error) {
t := tokenizer.NewFromString(rawFields)
@@ -145,7 +167,7 @@ func initModifer(rawModifier string) (FieldModifier, error) {
return nil, fmt.Errorf("missing or invalid modifier %q", name)
}
func pickFields(data any, fields map[string]FieldModifier) error {
func pickParsedFields(data any, fields map[string]FieldModifier) error {
switch v := data.(type) {
case map[string]any:
pickMapFields(v, fields)
@@ -233,7 +255,7 @@ DataLoop:
matchingFields[remains] = m
}
if err := pickFields(data[k], matchingFields); err != nil {
if err := pickParsedFields(data[k], matchingFields); err != nil {
return err
}
}
+127 -59
View File
@@ -1,6 +1,7 @@
package rest_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
@@ -45,6 +46,22 @@ func TestSerialize(t *testing.T) {
"fields=missing",
`{}`,
},
{
">299 response",
rest.Serializer{},
300,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing",
`{"a":1,"b":2,"c":"test"}`,
},
{
"<200 response",
rest.Serializer{},
199,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing",
`{"a":1,"b":2,"c":"test"}`,
},
{
"non map response",
rest.Serializer{},
@@ -69,22 +86,6 @@ func TestSerialize(t *testing.T) {
"fields=missing", // test individual fields trim
`{}`,
},
{
">299 response",
rest.Serializer{},
300,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing",
`{"a":1,"b":2,"c":"test"}`,
},
{
"<200 response",
rest.Serializer{},
199,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=missing",
`{"a":1,"b":2,"c":"test"}`,
},
{
"map with existing and missing fields",
rest.Serializer{},
@@ -101,21 +102,96 @@ func TestSerialize(t *testing.T) {
"custom=a, c ,missing", // test individual fields trim
`{"a":1,"c":"test"}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.URL.RawQuery = s.query
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
c.Response().Status = s.statusCode
if err := s.serializer.Serialize(c, s.data, ""); err != nil {
t.Fatalf("Serialize failure: %v", err)
}
rawBody, err := io.ReadAll(rec.Result().Body)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
}
if v := strings.TrimSpace(string(rawBody)); v != s.expected {
t.Fatalf("Expected body\n%v \ngot \n%v", s.expected, v)
}
})
}
}
func TestPickFields(t *testing.T) {
scenarios := []struct {
name string
data any
fields string
expectError bool
result string
}{
{
"empty fields",
map[string]any{"a": 1, "b": 2, "c": "test"},
"",
false,
`{"a":1,"b":2,"c":"test"}`,
},
{
"missing fields",
map[string]any{"a": 1, "b": 2, "c": "test"},
"missing",
false,
`{}`,
},
{
"non map data",
"test",
"a,b,test",
false,
`"test"`,
},
{
"non slice of map data",
[]any{"a", "b", "test"},
"a,test",
false,
`["a","b","test"]`,
},
{
"map with no matching field",
map[string]any{"a": 1, "b": 2, "c": "test"},
"missing", // test individual fields trim
false,
`{}`,
},
{
"map with existing and missing fields",
map[string]any{"a": 1, "b": 2, "c": "test"},
"a, c ,missing", // test individual fields trim
false,
`{"a":1,"c":"test"}`,
},
{
"slice of maps with existing and missing fields",
rest.Serializer{},
200,
[]any{
map[string]any{"a": 11, "b": 11, "c": "test1"},
map[string]any{"a": 22, "b": 22, "c": "test2"},
},
"fields=a, c ,missing", // test individual fields trim
"a, c ,missing", // test individual fields trim
false,
`[{"a":11,"c":"test1"},{"a":22,"c":"test2"}]`,
},
{
"nested fields with mixed map and any slices",
rest.Serializer{},
200,
map[string]any{
"a": 1,
"b": 2,
@@ -167,13 +243,12 @@ func TestSerialize(t *testing.T) {
},
},
},
"fields=a, c, anySlice.A, mapSlice.C, mapSlice.D.DA, anySlice.D,fullMap",
"a, c, anySlice.A, mapSlice.C, mapSlice.D.DA, anySlice.D,fullMap",
false,
`{"a":1,"anySlice":[{"A":[1,2,3],"D":{"DA":1,"DB":2}},{"A":"test"}],"c":"test","fullMap":[{"A":[1,2,3],"B":["1","2",3],"C":"test"},{"B":["1","2",3],"D":[{"DA":2},{"DA":3}]}],"mapSlice":[{"C":"test","D":[{"DA":1}]},{"D":[{"DA":2},{"DA":3},{}]}]}`,
},
{
"SearchResult",
rest.Serializer{},
200,
search.Result{
Page: 1,
PerPage: 10,
@@ -184,13 +259,12 @@ func TestSerialize(t *testing.T) {
map[string]any{"a": 22, "b": 22, "c": "test2"},
},
},
"fields=a,c,missing",
"a,c,missing",
false,
`{"items":[{"a":11,"c":"test1"},{"a":22,"c":"test2"}],"page":1,"perPage":10,"totalItems":20,"totalPages":30}`,
},
{
"*SearchResult",
rest.Serializer{},
200,
&search.Result{
Page: 1,
PerPage: 10,
@@ -201,13 +275,12 @@ func TestSerialize(t *testing.T) {
map[string]any{"a": 22, "b": 22, "c": "test2"},
},
},
"fields=a,c",
"a,c",
false,
`{"items":[{"a":11,"c":"test1"},{"a":22,"c":"test2"}],"page":1,"perPage":10,"totalItems":20,"totalPages":30}`,
},
{
"root wildcard",
rest.Serializer{},
200,
&search.Result{
Page: 1,
PerPage: 10,
@@ -218,13 +291,12 @@ func TestSerialize(t *testing.T) {
map[string]any{"a": 22, "b": 22, "c": "test2"},
},
},
"fields=*",
"*",
false,
`{"items":[{"a":11,"b":11,"c":"test1"},{"a":22,"b":22,"c":"test2"}],"page":1,"perPage":10,"totalItems":20,"totalPages":30}`,
},
{
"root wildcard with nested exception",
rest.Serializer{},
200,
map[string]any{
"id": "123",
"title": "lorem",
@@ -233,13 +305,12 @@ func TestSerialize(t *testing.T) {
"title": "rel_title",
},
},
"fields=*,rel.id",
"*,rel.id",
false,
`{"id":"123","rel":{"id":"456"},"title":"lorem"}`,
},
{
"sub wildcard",
rest.Serializer{},
200,
map[string]any{
"id": "123",
"title": "lorem",
@@ -252,13 +323,12 @@ func TestSerialize(t *testing.T) {
},
},
},
"fields=id,rel.*",
"id,rel.*",
false,
`{"id":"123","rel":{"id":"456","sub":{"id":"789","title":"sub_title"},"title":"rel_title"}}`,
},
{
"sub wildcard with nested exception",
rest.Serializer{},
200,
map[string]any{
"id": "123",
"title": "lorem",
@@ -271,21 +341,19 @@ func TestSerialize(t *testing.T) {
},
},
},
"fields=id,rel.*,rel.sub.id",
"id,rel.*,rel.sub.id",
false,
`{"id":"123","rel":{"id":"456","sub":{"id":"789"},"title":"rel_title"}}`,
},
{
"invalid excerpt modifier",
rest.Serializer{},
400,
map[string]any{"a": 1, "b": 2, "c": "test"},
"fields=*:excerpt",
"*:excerpt",
true,
`{"a":1,"b":2,"c":"test"}`,
},
{
"valid excerpt modifier",
rest.Serializer{},
200,
map[string]any{
"id": "123",
"title": "lorem",
@@ -298,32 +366,32 @@ func TestSerialize(t *testing.T) {
},
},
},
"fields=*:excerpt(2),rel.title:excerpt(3, true)",
"*:excerpt(2),rel.title:excerpt(3, true)",
false,
`{"id":"12","rel":{"title":"rel..."},"title":"lo"}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.URL.RawQuery = s.query
rec := httptest.NewRecorder()
result, err := rest.PickFields(s.data, s.fields)
e := echo.New()
c := e.NewContext(req, rec)
c.Response().Status = s.statusCode
if err := s.serializer.Serialize(c, s.data, ""); err != nil {
t.Fatalf("Serialize failure: %v", err)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
rawBody, err := io.ReadAll(rec.Result().Body)
if hasErr {
return
}
serialized, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
t.Fatal(err)
}
if v := strings.TrimSpace(string(rawBody)); v != s.expected {
t.Fatalf("Expected body\n%v \ngot: \n%v", s.expected, v)
if v := string(serialized); v != s.result {
t.Fatalf("Expected body\n%s \ngot \n%s", s.result, v)
}
})
}
+82 -9
View File
@@ -1,15 +1,29 @@
package subscriptions
import (
"encoding/json"
"net/url"
"strings"
"sync"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
const optionsParam = "options"
// Message defines a client's channel data.
type Message struct {
Name string
Data []byte
Name string `json:"name"`
Data []byte `json:"data"`
}
type SubscriptionOptions struct {
// @todo after the requests handling refactoring consider
// changing to map[string]string or map[string][]string
Query map[string]any `json:"query"`
Headers map[string]any `json:"headers"`
}
// Client is an interface for a generic subscription client.
@@ -20,10 +34,20 @@ type Client interface {
// Channel returns the client's communication channel.
Channel() chan Message
// Subscriptions returns all subscriptions to which the client has subscribed to.
Subscriptions() map[string]struct{}
// Subscriptions returns a shallow copy of the the client subscriptions matching the prefixes.
// If no prefix is specified, returns all subscriptions.
Subscriptions(prefixes ...string) map[string]SubscriptionOptions
// Subscribe subscribes the client to the provided subscriptions list.
//
// Each subscription can also have "options" (json serialized SubscriptionOptions) as query parameter.
//
// Example:
//
// Subscribe(
// "subscriptionA",
// `subscriptionB?options={"query":{"a":1},"headers":{"x_token":"abc"}}`,
// )
Subscribe(subs ...string)
// Unsubscribe unsubscribes the client from the provided subscriptions list.
@@ -61,8 +85,8 @@ var _ Client = (*DefaultClient)(nil)
// DefaultClient defines a generic subscription client.
type DefaultClient struct {
store map[string]any
subscriptions map[string]SubscriptionOptions
channel chan Message
subscriptions map[string]struct{}
id string
mux sync.RWMutex
isDiscarded bool
@@ -74,7 +98,7 @@ func NewDefaultClient() *DefaultClient {
id: security.RandomString(40),
store: map[string]any{},
channel: make(chan Message),
subscriptions: make(map[string]struct{}),
subscriptions: map[string]SubscriptionOptions{},
}
}
@@ -95,11 +119,37 @@ func (c *DefaultClient) Channel() chan Message {
}
// Subscriptions implements the [Client.Subscriptions] interface method.
func (c *DefaultClient) Subscriptions() map[string]struct{} {
//
// It returns a shallow copy of the the client subscriptions matching the prefixes.
// If no prefix is specified, returns all subscriptions.
func (c *DefaultClient) Subscriptions(prefixes ...string) map[string]SubscriptionOptions {
c.mux.RLock()
defer c.mux.RUnlock()
return c.subscriptions
// no prefix -> return copy of all subscriptions
if len(prefixes) == 0 {
result := make(map[string]SubscriptionOptions, len(c.subscriptions))
for s, options := range c.subscriptions {
result[s] = options
}
return result
}
result := make(map[string]SubscriptionOptions)
for _, prefix := range prefixes {
for s, options := range c.subscriptions {
// "?" ensures that the options query start character is always there
// so that it can be used as an end separator when looking only for the main subscription topic
if strings.HasPrefix(s+"?", prefix) {
result[s] = options
}
}
}
return result
}
// Subscribe implements the [Client.Subscribe] interface method.
@@ -114,7 +164,30 @@ func (c *DefaultClient) Subscribe(subs ...string) {
continue // skip empty
}
c.subscriptions[s] = struct{}{}
// extract subscription options (if any)
options := SubscriptionOptions{}
u, err := url.Parse(s)
if err == nil {
rawOptions := u.Query().Get(optionsParam)
if rawOptions != "" {
json.Unmarshal([]byte(rawOptions), &options)
}
}
// normalize query
// (currently only single string values are supported for consistency with the default routes handling)
for k, v := range options.Query {
options.Query[k] = cast.ToString(v)
}
// normalize headers name and values, eg. "X-Token" is converted to "x_token"
// (currently only single string values are supported for consistency with the default routes handling)
for k, v := range options.Headers {
delete(options.Headers, k)
options.Headers[inflector.Snakecase(k)] = cast.ToString(v)
}
c.subscriptions[s] = options
}
}
+70 -8
View File
@@ -1,6 +1,8 @@
package subscriptions_test
import (
"encoding/json"
"strings"
"testing"
"time"
@@ -51,7 +53,7 @@ func TestChannel(t *testing.T) {
c := subscriptions.NewDefaultClient()
if c.Channel() == nil {
t.Errorf("Expected channel to be initialized, got")
t.Fatalf("Expected channel to be initialized, got")
}
}
@@ -59,13 +61,35 @@ func TestSubscriptions(t *testing.T) {
c := subscriptions.NewDefaultClient()
if len(c.Subscriptions()) != 0 {
t.Errorf("Expected subscriptions to be empty")
t.Fatalf("Expected subscriptions to be empty")
}
c.Subscribe("sub1", "sub2", "sub3")
c.Subscribe("sub1", "sub11", "sub2")
if len(c.Subscriptions()) != 3 {
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
scenarios := []struct {
prefixes []string
expected []string
}{
{nil, []string{"sub1", "sub11", "sub2"}},
{[]string{"missing"}, nil},
{[]string{"sub1"}, []string{"sub1", "sub11"}},
{[]string{"sub2"}, []string{"sub2"}}, // with extra query start char
}
for _, s := range scenarios {
t.Run(strings.Join(s.prefixes, ","), func(t *testing.T) {
subs := c.Subscriptions(s.prefixes...)
if len(subs) != len(s.expected) {
t.Fatalf("Expected %d subscriptions, got %d", len(s.expected), len(subs))
}
for _, s := range s.expected {
if _, ok := subs[s]; !ok {
t.Fatalf("Missing subscription %q in \n%v", s, subs)
}
}
})
}
}
@@ -78,7 +102,7 @@ func TestSubscribe(t *testing.T) {
c.Subscribe(subs...) // empty string should be skipped
if len(c.Subscriptions()) != 3 {
t.Errorf("Expected 3 subscriptions, got %v", c.Subscriptions())
t.Fatalf("Expected 3 subscriptions, got %v", c.Subscriptions())
}
for i, s := range expected {
@@ -88,6 +112,44 @@ func TestSubscribe(t *testing.T) {
}
}
func TestSubscribeOptions(t *testing.T) {
c := subscriptions.NewDefaultClient()
sub1 := "test1"
sub2 := `test2?options={"query":{"name":123},"headers":{"X-Token":456}}`
c.Subscribe(sub1, sub2)
subs := c.Subscriptions()
scenarios := []struct {
name string
expectedOptions string
}{
{sub1, `{"query":null,"headers":null}`},
{sub2, `{"query":{"name":"123"},"headers":{"x_token":"456"}}`},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
options, ok := subs[s.name]
if !ok {
t.Fatalf("Missing subscription \n%q \nin \n%v", s.name, subs)
}
rawBytes, err := json.Marshal(options)
if err != nil {
t.Fatal(err)
}
rawStr := string(rawBytes)
if rawStr != s.expectedOptions {
t.Fatalf("Expected options \n%v \ngot \n%v", s.expectedOptions, rawStr)
}
})
}
}
func TestUnsubscribe(t *testing.T) {
c := subscriptions.NewDefaultClient()
@@ -96,12 +158,12 @@ func TestUnsubscribe(t *testing.T) {
c.Unsubscribe("sub1")
if c.HasSubscription("sub1") {
t.Error("Expected sub1 to be removed")
t.Fatalf("Expected sub1 to be removed")
}
c.Unsubscribe( /* all */ )
if len(c.Subscriptions()) != 0 {
t.Errorf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
t.Fatalf("Expected all subscriptions to be removed, got %v", c.Subscriptions())
}
}