merge v0.23.0-rc changes

This commit is contained in:
Gani Georgiev
2024-09-29 19:23:19 +03:00
parent ad92992324
commit 844f18cac3
753 changed files with 85141 additions and 63396 deletions
+38 -33
View File
@@ -1,6 +1,7 @@
package security_test
import (
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/security"
@@ -17,30 +18,30 @@ func TestEncrypt(t *testing.T) {
{"123", "abcdabcdabcdabcdabcdabcdabcdabcd", false},
}
for i, scenario := range scenarios {
result, err := security.Encrypt([]byte(scenario.data), scenario.key)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.data), func(t *testing.T) {
result, err := security.Encrypt([]byte(s.data), s.key)
if scenario.expectError && err == nil {
t.Errorf("(%d) Expected error got nil", i)
}
if !scenario.expectError && err != nil {
t.Errorf("(%d) Expected nil got error %v", i, err)
}
hasErr := err != nil
if scenario.expectError && result != "" {
t.Errorf("(%d) Expected empty string, got %q", i, result)
}
if !scenario.expectError && result == "" {
t.Errorf("(%d) Expected non empty encrypted result string", i)
}
// try to decrypt
if result != "" {
decrypted, _ := security.Decrypt(result, scenario.key)
if string(decrypted) != scenario.data {
t.Errorf("(%d) Expected decrypted value to match with the data input, got %q", i, decrypted)
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
}
if hasErr {
if result != "" {
t.Fatalf("Expected empty Encrypt result on error, got %q", result)
}
return
}
// try to decrypt
decrypted, err := security.Decrypt(result, s.key)
if err != nil || string(decrypted) != s.data {
t.Fatalf("Expected decrypted value to match with the data input, got %q (%v)", decrypted, err)
}
})
}
}
@@ -57,19 +58,23 @@ func TestDecrypt(t *testing.T) {
{"8kcEqilvv+YKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", false, "123"},
}
for i, scenario := range scenarios {
result, err := security.Decrypt(scenario.cipher, scenario.key)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
result, err := security.Decrypt(s.cipher, s.key)
if scenario.expectError && err == nil {
t.Errorf("(%d) Expected error got nil", i)
}
if !scenario.expectError && err != nil {
t.Errorf("(%d) Expected nil got error %v", i, err)
}
hasErr := err != nil
resultStr := string(result)
if resultStr != scenario.expectedData {
t.Errorf("(%d) Expected %q, got %q", i, scenario.expectedData, resultStr)
}
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if str := string(result); str != s.expectedData {
t.Fatalf("Expected %q, got %q", s.expectedData, str)
}
})
}
}
+3 -12
View File
@@ -4,6 +4,7 @@ import (
"errors"
"time"
// @todo update to v5
"github.com/golang-jwt/jwt/v4"
)
@@ -43,11 +44,9 @@ func ParseJWT(token string, verificationKey string) (jwt.MapClaims, error) {
}
// NewJWT generates and returns new HS256 signed JWT.
func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
seconds := time.Duration(secondsDuration) * time.Second
func NewJWT(payload jwt.MapClaims, signingKey string, duration time.Duration) (string, error) {
claims := jwt.MapClaims{
"exp": time.Now().Add(seconds).Unix(),
"exp": time.Now().Add(duration).Unix(),
}
for k, v := range payload {
@@ -56,11 +55,3 @@ func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (st
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(signingKey))
}
// Deprecated:
// Consider replacing with NewJWT().
//
// NewToken is a legacy alias for NewJWT that generates a HS256 signed JWT.
func NewToken(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
return NewJWT(payload, signingKey, secondsDuration)
}
+61 -54
View File
@@ -1,7 +1,10 @@
package security_test
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/pocketbase/pocketbase/tools/security"
@@ -102,26 +105,30 @@ func TestParseJWT(t *testing.T) {
},
}
for i, scenario := range scenarios {
result, err := security.ParseJWT(scenario.token, scenario.secret)
if scenario.expectError && err == nil {
t.Errorf("(%d) Expected error got nil", i)
}
if !scenario.expectError && err != nil {
t.Errorf("(%d) Expected nil got error %v", i, err)
}
if len(result) != len(scenario.expectClaims) {
t.Errorf("(%d) Expected %v got %v", i, scenario.expectClaims, result)
}
for k, v := range scenario.expectClaims {
v2, ok := result[k]
if !ok {
t.Errorf("(%d) Missing expected claim %q", i, k)
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.token), func(t *testing.T) {
result, err := security.ParseJWT(s.token, s.secret)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if v != v2 {
t.Errorf("(%d) Expected %v for %q claim, got %v", i, v, k, v2)
if len(result) != len(s.expectClaims) {
t.Fatalf("Expected %v claims got %v", s.expectClaims, result)
}
}
for k, v := range s.expectClaims {
v2, ok := result[k]
if !ok {
t.Fatalf("Missing expected claim %q", k)
}
if v != v2 {
t.Fatalf("Expected %v for %q claim, got %v", v, k, v2)
}
}
})
}
}
@@ -129,51 +136,51 @@ func TestNewJWT(t *testing.T) {
scenarios := []struct {
claims jwt.MapClaims
key string
duration int64
duration time.Duration
expectError bool
}{
// empty, zero duration
{jwt.MapClaims{}, "", 0, true},
// empty, 10 seconds duration
{jwt.MapClaims{}, "", 10, false},
{jwt.MapClaims{}, "", 10 * time.Second, false},
// non-empty, 10 seconds duration
{jwt.MapClaims{"name": "test"}, "test", 10, false},
{jwt.MapClaims{"name": "test"}, "test", 10 * time.Second, false},
}
for i, scenario := range scenarios {
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
if tokenErr != nil {
t.Errorf("(%d) Expected NewJWT to succeed, got error %v", i, tokenErr)
continue
}
claims, parseErr := security.ParseJWT(token, scenario.key)
hasParseErr := parseErr != nil
if hasParseErr != scenario.expectError {
t.Errorf("(%d) Expected hasParseErr to be %v, got %v (%v)", i, scenario.expectError, hasParseErr, parseErr)
continue
}
if scenario.expectError {
continue
}
if _, ok := claims["exp"]; !ok {
t.Errorf("(%d) Missing required claim exp, got %v", i, claims)
}
// clear exp claim to match with the scenario ones
delete(claims, "exp")
if len(claims) != len(scenario.claims) {
t.Errorf("(%d) Expected %v claims, got %v", i, scenario.claims, claims)
}
for j, k := range claims {
if claims[j] != scenario.claims[j] {
t.Errorf("(%d) Expected %v for %q claim, got %v", i, claims[j], k, scenario.claims[j])
t.Run(strconv.Itoa(i), func(t *testing.T) {
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
if tokenErr != nil {
t.Fatalf("Expected NewJWT to succeed, got error %v", tokenErr)
}
}
claims, parseErr := security.ParseJWT(token, scenario.key)
hasParseErr := parseErr != nil
if hasParseErr != scenario.expectError {
t.Fatalf("Expected hasParseErr to be %v, got %v (%v)", scenario.expectError, hasParseErr, parseErr)
}
if scenario.expectError {
return
}
if _, ok := claims["exp"]; !ok {
t.Fatalf("Missing required claim exp, got %v", claims)
}
// clear exp claim to match with the scenario ones
delete(claims, "exp")
if len(claims) != len(scenario.claims) {
t.Fatalf("Expected %v claims, got %v", scenario.claims, claims)
}
for j, k := range claims {
if claims[j] != scenario.claims[j] {
t.Fatalf("Expected %v for %q claim, got %v", claims[j], k, scenario.claims[j])
}
}
})
}
}
+1 -6
View File
@@ -3,16 +3,11 @@ package security
import (
cryptoRand "crypto/rand"
"math/big"
mathRand "math/rand"
"time"
mathRand "math/rand" // @todo replace with rand/v2?
)
const defaultRandomAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
func init() {
mathRand.Seed(time.Now().UnixNano())
}
// RandomString generates a cryptographically random string with the specified length.
//
// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding.
+152
View File
@@ -0,0 +1,152 @@
package security
import (
cryptoRand "crypto/rand"
"fmt"
"math/big"
"regexp/syntax"
"strings"
)
const defaultMaxRepeat = 6
var anyCharNotNLPairs = []rune{'A', 'Z', 'a', 'z', '0', '9'}
// RandomStringByRegex generates a random string matching the regex pattern.
// If optFlags is not set, fallbacks to [syntax.Perl].
//
// NB! While the source of the randomness comes from [crypto/rand] this method
// is not recommended to be used on its own in critical secure contexts because
// the generated length could vary too much on the used pattern and may not be
// as secure as simply calling [security.RandomString].
// If you still insist on using it for such purposes, consider at least
// a large enough minimum length for the generated string, e.g. `[a-z0-9]{30}`.
//
// This function is inspired by github.com/pipe01/revregexp, github.com/lucasjones/reggen and other similar packages.
func RandomStringByRegex(pattern string, optFlags ...syntax.Flags) (string, error) {
var flags syntax.Flags
if len(optFlags) == 0 {
flags = syntax.Perl
} else {
for _, f := range optFlags {
flags |= f
}
}
r, err := syntax.Parse(pattern, flags)
if err != nil {
return "", err
}
var sb = new(strings.Builder)
err = writeRandomStringByRegex(r, sb)
if err != nil {
return "", err
}
return sb.String(), nil
}
func writeRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder) error {
// https://pkg.go.dev/regexp/syntax#Op
switch r.Op {
case syntax.OpCharClass:
c, err := randomRuneFromPairs(r.Rune)
if err != nil {
return err
}
_, err = sb.WriteRune(c)
return err
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
c, err := randomRuneFromPairs(anyCharNotNLPairs)
if err != nil {
return err
}
_, err = sb.WriteRune(c)
return err
case syntax.OpAlternate:
idx, err := randomNumber(len(r.Sub))
if err != nil {
return err
}
return writeRandomStringByRegex(r.Sub[idx], sb)
case syntax.OpConcat:
var err error
for _, sub := range r.Sub {
err = writeRandomStringByRegex(sub, sb)
if err != nil {
break
}
}
return err
case syntax.OpRepeat:
return repeatRandomStringByRegex(r.Sub[0], sb, r.Min, r.Max)
case syntax.OpQuest:
return repeatRandomStringByRegex(r.Sub[0], sb, 0, 1)
case syntax.OpPlus:
return repeatRandomStringByRegex(r.Sub[0], sb, 1, -1)
case syntax.OpStar:
return repeatRandomStringByRegex(r.Sub[0], sb, 0, -1)
case syntax.OpCapture:
return writeRandomStringByRegex(r.Sub[0], sb)
case syntax.OpLiteral:
_, err := sb.WriteString(string(r.Rune))
return err
default:
return fmt.Errorf("unsupported pattern operator %d", r.Op)
}
}
func repeatRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder, min int, max int) error {
if max < 0 {
max = defaultMaxRepeat
}
if max < min {
max = min
}
n := min
if max != min {
randRange, err := randomNumber(max - min)
if err != nil {
return err
}
n += randRange
}
var err error
for i := 0; i < n; i++ {
err = writeRandomStringByRegex(r, sb)
if err != nil {
return err
}
}
return nil
}
func randomRuneFromPairs(pairs []rune) (rune, error) {
idx, err := randomNumber(len(pairs) / 2)
if err != nil {
return 0, err
}
return randomRuneFromRange(pairs[idx*2], pairs[idx*2+1])
}
func randomRuneFromRange(min rune, max rune) (rune, error) {
offset, err := randomNumber(int(max - min + 1))
if err != nil {
return min, err
}
return min + rune(offset), nil
}
func randomNumber(maxSoft int) (int, error) {
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(maxSoft)))
return int(randRange.Int64()), err
}
+66
View File
@@ -0,0 +1,66 @@
package security_test
import (
"fmt"
"regexp"
"regexp/syntax"
"slices"
"testing"
"github.com/pocketbase/pocketbase/tools/security"
)
func TestRandomStringByRegex(t *testing.T) {
generated := []string{}
scenarios := []struct {
pattern string
flags []syntax.Flags
expectError bool
}{
{``, nil, true},
{`test`, nil, false},
{`\d+`, []syntax.Flags{syntax.POSIX}, true},
{`\d+`, nil, false},
{`\d*`, nil, false},
{`\d{1,10}`, nil, false},
{`\d{3}`, nil, false},
{`\d{0,}-abc`, nil, false},
{`[a-zA-Z]*`, nil, false},
{`[^a-zA-Z]{5,30}`, nil, false},
{`\w+_abc`, nil, false},
{`[a-zA-Z_]*`, nil, false},
{`[2-9]{5}-\w+`, nil, false},
{`(a|b|c)`, nil, false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%q", i, s.pattern), func(t *testing.T) {
str, err := security.RandomStringByRegex(s.pattern, s.flags...)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
r, err := regexp.Compile(s.pattern)
if err != nil {
t.Fatal(err)
}
if !r.Match([]byte(str)) {
t.Fatalf("Expected %q to match pattern %v", str, s.pattern)
}
if slices.Contains(generated, str) {
t.Fatalf("The generated string %q already exists in\n%v", str, generated)
}
generated = append(generated, str)
})
}
}