added RateLimitRule.Audience field

This commit is contained in:
Gani Georgiev
2024-11-08 18:04:13 +02:00
parent 0e56521e8a
commit f6aef4471d
37 changed files with 387 additions and 141 deletions
+13
View File
@@ -106,6 +106,19 @@ func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection,
// @todo consider exporting as RateLimit helper?
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
switch rule.Audience {
case core.RateLimitRuleAudienceAll:
// valid for both guest and regular users
case core.RateLimitRuleAudienceGuest:
if e.Auth != nil {
return nil
}
case core.RateLimitRuleAudienceAuth:
if e.Auth == nil {
return nil
}
}
rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
return initRateLimitersStore(e.App)
}).(*store.Store[*rateLimiter])
+76 -20
View File
@@ -31,6 +31,18 @@ func TestDefaultRateLimitMiddleware(t *testing.T) {
MaxRequests: 1,
Duration: 1,
},
{
Label: "/rate/guest",
MaxRequests: 1,
Duration: 1,
Audience: core.RateLimitRuleAudienceGuest,
},
{
Label: "/rate/auth",
MaxRequests: 1,
Duration: 1,
Audience: core.RateLimitRuleAudienceAuth,
},
}
pbRouter, err := apis.NewRouter(app)
@@ -48,6 +60,12 @@ func TestDefaultRateLimitMiddleware(t *testing.T) {
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
return e.String(200, "b")
})
pbRouter.GET("/rate/guest", func(e *core.RequestEvent) error {
return e.String(200, "guest")
})
pbRouter.GET("/rate/auth", func(e *core.RequestEvent) error {
return e.String(200, "auth")
})
mux, err := pbRouter.BuildMux()
if err != nil {
@@ -57,30 +75,53 @@ func TestDefaultRateLimitMiddleware(t *testing.T) {
scenarios := []struct {
url string
wait float64
authenticated bool
expectedStatus int
}{
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 429},
{"/rate/a", 0, 429},
{"/rate/a", 1.1, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 429},
{"/rate/a", 0, false, 200},
{"/rate/a", 0, false, 200},
{"/rate/a", 0, false, 429},
{"/rate/a", 0, false, 429},
{"/rate/a", 1.1, false, 200},
{"/rate/a", 0, false, 200},
{"/rate/a", 0, false, 429},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 429},
{"/rate/b", 1.1, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 429},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 429},
{"/rate/b", 1.1, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 429},
// "auth" with guest (should be ignored)
{"/rate/auth", 0, false, 200},
{"/rate/auth", 0, false, 200},
{"/rate/auth", 0, false, 200},
{"/rate/auth", 0, false, 200},
// "auth" rule with regular user
{"/rate/auth", 0, true, 200},
{"/rate/auth", 0, true, 429},
{"/rate/auth", 0, true, 429},
// "guest" with guest
{"/rate/guest", 0, false, 200},
{"/rate/guest", 0, false, 429},
{"/rate/guest", 0, false, 429},
// "guest" rule with regular user (should be ignored)
{"/rate/guest", 0, true, 200},
{"/rate/guest", 0, true, 200},
{"/rate/guest", 0, true, 200},
{"/rate/guest", 0, true, 200},
}
for _, s := range scenarios {
@@ -91,6 +132,21 @@ func TestDefaultRateLimitMiddleware(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", s.url, nil)
if s.authenticated {
auth, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
token, err := auth.NewAuthToken()
if err != nil {
t.Fatal(err)
}
req.Header.Add("Authorization", token)
}
mux.ServeHTTP(rec, req)
result := rec.Result()