=added experimental expand, filter, fields, custom query and headers parameters support for the realtime subscriptions

This commit is contained in:
Gani Georgiev
2023-10-23 22:46:47 +03:00
parent e6f1b3dfe4
commit 79617e6d99
41 changed files with 553 additions and 257 deletions
+1 -1
View File
@@ -28,7 +28,7 @@ func InitApi(app core.App) (*echo.Echo, error) {
e := echo.New()
e.Debug = app.IsDebug()
e.JSONSerializer = &rest.Serializer{
FieldsParam: "fields",
FieldsParam: fieldsQueryParam,
}
// configure a custom router
+139 -99
View File
@@ -15,9 +15,11 @@ import (
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/spf13/cast"
)
// bindRealtimeApi registers the realtime api endpoints.
@@ -326,58 +328,16 @@ func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection
return collection
}
// canAccessRecord checks if the subscription client has access to the specified record model.
func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool {
admin, _ := client.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
// admins can access everything
return true
}
if accessRule == nil {
// only admins can access this record
return false
}
ruleFunc := func(q *dbx.SelectQuery) error {
if *accessRule == "" {
return nil // empty public rule
}
// mock request data
requestInfo := &models.RequestInfo{
Method: "GET",
}
requestInfo.AuthRecord, _ = client.Get(ContextAuthRecordKey).(*models.Record)
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), record.Collection(), requestInfo, true)
expr, err := search.FilterData(*accessRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
return nil
}
foundRecord, err := api.app.Dao().FindRecordById(record.Collection().Id, record.Id, ruleFunc)
if err == nil && foundRecord != nil {
return true
}
return false
}
// recordData represents the broadcasted record subscrition message data.
type recordData struct {
Record *models.Record `json:"record"`
Action string `json:"action"`
Record any `json:"record"` /* map or models.Record */
Action string `json:"action"`
}
func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dryCache bool) error {
collection := record.Collection()
if collection == nil {
return errors.New("Record collection not set.")
return errors.New("[broadcastRecord] Record collection not set.")
}
clients := api.app.SubscriptionsBroker().Clients()
@@ -385,67 +345,106 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
return nil // no subscribers
}
// create a clean record copy without expand and unknown fields
// because we don't know if the clients have permissions to view them
cleanRecord := record.CleanCopy()
subscriptionRuleMap := map[string]*string{
(collection.Name + "/" + cleanRecord.Id): collection.ViewRule,
(collection.Id + "/" + cleanRecord.Id): collection.ViewRule,
(collection.Name + "/*"): collection.ListRule,
(collection.Id + "/*"): collection.ListRule,
// @deprecated: the same as the wildcard topic but kept for backward compatibility
collection.Name: collection.ListRule,
collection.Id: collection.ListRule,
(collection.Name + "/" + record.Id + "?"): collection.ViewRule,
(collection.Id + "/" + record.Id + "?"): collection.ViewRule,
(collection.Name + "/*?"): collection.ListRule,
(collection.Id + "/*?"): collection.ListRule,
// @deprecated: the same as the wildcard topic but kept for backward compatibility
(collection.Name + "?"): collection.ListRule,
(collection.Id + "?"): collection.ListRule,
}
data := &recordData{
Action: action,
Record: cleanRecord,
}
dataBytes, err := json.Marshal(data)
if err != nil {
return err
}
dryCacheKey := action + "/" + record.Id
for _, client := range clients {
client := client
for subscription, rule := range subscriptionRuleMap {
if !client.HasSubscription(subscription) {
// note: not executed concurrently to avoid races and to ensure
// that the access checks are applied for the current record db state
for prefix, rule := range subscriptionRuleMap {
subs := client.Subscriptions(prefix)
if len(subs) == 0 {
continue
}
if !api.canAccessRecord(client, data.Record, rule) {
continue
}
for sub, options := range subs {
// create a clean record copy without expand and unknown fields
// because we don't know yet which exact fields the client subscription has permissions to access
cleanRecord := record.CleanCopy()
msg := subscriptions.Message{
Name: subscription,
Data: dataBytes,
}
// ignore the auth record email visibility checks for
// auth owner, admin or manager
if collection.IsAuth() {
authId := extractAuthIdFromGetter(client)
if authId == data.Record.Id ||
api.canAccessRecord(client, data.Record, collection.AuthOptions().ManageRule) {
data.Record.IgnoreEmailVisibility(true) // ignore
if newData, err := json.Marshal(data); err == nil {
msg.Data = newData
}
data.Record.IgnoreEmailVisibility(false) // restore
// mock request data
requestInfo := &models.RequestInfo{
Method: "GET",
Query: options.Query,
Headers: options.Headers,
}
}
requestInfo.Admin, _ = client.Get(ContextAdminKey).(*models.Admin)
requestInfo.AuthRecord, _ = client.Get(ContextAuthRecordKey).(*models.Record)
if dryCache {
client.Set(action+"/"+data.Record.Id, msg)
} else {
routine.FireAndForget(func() {
client.Send(msg)
})
if !api.canAccessRecord(cleanRecord, requestInfo, rule) {
continue
}
rawExpand := cast.ToString(options.Query[expandQueryParam])
if rawExpand != "" {
expandErrs := api.app.Dao().ExpandRecord(cleanRecord, strings.Split(rawExpand, ","), expandFetch(api.app.Dao(), requestInfo))
if api.app.IsDebug() && len(expandErrs) > 0 {
log.Println("[broadcastRecord] expand errors", expandErrs)
}
}
// ignore the auth record email visibility checks
// for auth owner, admin or manager
if collection.IsAuth() {
authId := extractAuthIdFromGetter(client)
if authId == cleanRecord.Id {
if api.canAccessRecord(cleanRecord, requestInfo, collection.AuthOptions().ManageRule) {
cleanRecord.IgnoreEmailVisibility(true)
}
}
}
data := &recordData{
Action: action,
Record: cleanRecord,
}
// check fields
rawFields := cast.ToString(options.Query[fieldsQueryParam])
if rawFields != "" {
decoded, err := rest.PickFields(cleanRecord, rawFields)
if err == nil {
data.Record = decoded
} else if api.app.IsDebug() {
log.Println(err)
}
}
dataBytes, err := json.Marshal(data)
if err != nil && api.app.IsDebug() {
log.Println("[broadcastRecord] data marshal error", err)
continue
}
msg := subscriptions.Message{
Name: sub,
Data: dataBytes,
}
if dryCache {
messages, ok := client.Get(dryCacheKey).([]subscriptions.Message)
if !ok {
messages = []subscriptions.Message{msg}
} else {
messages = append(messages, msg)
}
client.Set(dryCacheKey, messages)
} else {
routine.FireAndForget(func() {
client.Send(msg)
})
}
}
}
}
@@ -453,14 +452,14 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
return nil
}
// broadcastDryCachedRecord broadcasts record if it is cached in the client context.
// broadcastDryCachedRecord broadcasts all cached record related messages.
func (api *realtimeApi) broadcastDryCachedRecord(action string, record *models.Record) error {
key := action + "/" + record.Id
clients := api.app.SubscriptionsBroker().Clients()
for _, client := range clients {
key := action + "/" + record.Id
msg, ok := client.Get(key).(subscriptions.Message)
messages, ok := client.Get(key).([]subscriptions.Message)
if !ok {
continue
}
@@ -470,9 +469,12 @@ func (api *realtimeApi) broadcastDryCachedRecord(action string, record *models.R
client := client
routine.FireAndForget(func() {
client.Send(msg)
for _, msg := range messages {
client.Send(msg)
}
})
}
return nil
}
@@ -493,3 +495,41 @@ func extractAuthIdFromGetter(val getter) string {
return ""
}
// canAccessRecord checks if the subscription client has access to the specified record model.
func (api *realtimeApi) canAccessRecord(
record *models.Record,
requestInfo *models.RequestInfo,
accessRule *string,
) bool {
// check the access rule
// ---
if ok, _ := api.app.Dao().CanAccessRecord(record, requestInfo, accessRule); !ok {
return false
}
// check the subscription client-side filter (if any)
// ---
filter := cast.ToString(requestInfo.Query[search.FilterQueryParam])
if filter == "" {
return true // no further checks needed
}
ruleFunc := func(q *dbx.SelectQuery) error {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), record.Collection(), requestInfo, false)
expr, err := search.FilterData(filter).BuildExpr(resolver)
if err != nil {
return err
}
q.AndWhere(expr)
resolver.UpdateQuery(q)
return nil
}
_, err := api.app.Dao().FindRecordById(record.Collection().Id, record.Id, ruleFunc)
return err == nil
}
-2
View File
@@ -16,8 +16,6 @@ import (
"github.com/pocketbase/pocketbase/tools/search"
)
const expandQueryParam = "expand"
// bindRecordCrudApi registers the record crud api endpoints and
// the corresponding handlers.
func bindRecordCrudApi(app core.App, rg *echo.Group) {
+5 -1
View File
@@ -13,12 +13,16 @@ import (
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/search"
)
const ContextRequestInfoKey = "requestInfo"
const expandQueryParam = "expand"
const fieldsQueryParam = "fields"
// Deprecated: Use RequestInfo instead.
func RequestData(c echo.Context) *models.RequestInfo {
log.Println("RequestData(c) is deprecated and will be removed in the future! You can replace it with RequestInfo(c).")
@@ -49,7 +53,7 @@ func RequestInfo(c echo.Context) *models.RequestInfo {
// ("X-Token" is converted to "x_token")
for k, v := range c.Request().Header {
if len(v) > 0 {
result.Headers[strings.ToLower(strings.ReplaceAll(k, "-", "_"))] = v[0]
result.Headers[inflector.Snakecase(k)] = v[0]
}
}