merge v0.23.0-rc changes
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
@@ -17,30 +18,30 @@ func TestEncrypt(t *testing.T) {
|
||||
{"123", "abcdabcdabcdabcdabcdabcdabcdabcd", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Encrypt([]byte(scenario.data), scenario.key)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.data), func(t *testing.T) {
|
||||
result, err := security.Encrypt([]byte(s.data), s.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
hasErr := err != nil
|
||||
|
||||
if scenario.expectError && result != "" {
|
||||
t.Errorf("(%d) Expected empty string, got %q", i, result)
|
||||
}
|
||||
if !scenario.expectError && result == "" {
|
||||
t.Errorf("(%d) Expected non empty encrypted result string", i)
|
||||
}
|
||||
|
||||
// try to decrypt
|
||||
if result != "" {
|
||||
decrypted, _ := security.Decrypt(result, scenario.key)
|
||||
if string(decrypted) != scenario.data {
|
||||
t.Errorf("(%d) Expected decrypted value to match with the data input, got %q", i, decrypted)
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
if result != "" {
|
||||
t.Fatalf("Expected empty Encrypt result on error, got %q", result)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// try to decrypt
|
||||
decrypted, err := security.Decrypt(result, s.key)
|
||||
if err != nil || string(decrypted) != s.data {
|
||||
t.Fatalf("Expected decrypted value to match with the data input, got %q (%v)", decrypted, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,19 +58,23 @@ func TestDecrypt(t *testing.T) {
|
||||
{"8kcEqilvv+YKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", false, "123"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Decrypt(scenario.cipher, scenario.key)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
|
||||
result, err := security.Decrypt(s.cipher, s.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
hasErr := err != nil
|
||||
|
||||
resultStr := string(result)
|
||||
if resultStr != scenario.expectedData {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expectedData, resultStr)
|
||||
}
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if str := string(result); str != s.expectedData {
|
||||
t.Fatalf("Expected %q, got %q", s.expectedData, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+3
-12
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
// @todo update to v5
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
@@ -43,11 +44,9 @@ func ParseJWT(token string, verificationKey string) (jwt.MapClaims, error) {
|
||||
}
|
||||
|
||||
// NewJWT generates and returns new HS256 signed JWT.
|
||||
func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
|
||||
seconds := time.Duration(secondsDuration) * time.Second
|
||||
|
||||
func NewJWT(payload jwt.MapClaims, signingKey string, duration time.Duration) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"exp": time.Now().Add(seconds).Unix(),
|
||||
"exp": time.Now().Add(duration).Unix(),
|
||||
}
|
||||
|
||||
for k, v := range payload {
|
||||
@@ -56,11 +55,3 @@ func NewJWT(payload jwt.MapClaims, signingKey string, secondsDuration int64) (st
|
||||
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(signingKey))
|
||||
}
|
||||
|
||||
// Deprecated:
|
||||
// Consider replacing with NewJWT().
|
||||
//
|
||||
// NewToken is a legacy alias for NewJWT that generates a HS256 signed JWT.
|
||||
func NewToken(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
|
||||
return NewJWT(payload, signingKey, secondsDuration)
|
||||
}
|
||||
|
||||
+61
-54
@@ -1,7 +1,10 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
@@ -102,26 +105,30 @@ func TestParseJWT(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.ParseJWT(scenario.token, scenario.secret)
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
if len(result) != len(scenario.expectClaims) {
|
||||
t.Errorf("(%d) Expected %v got %v", i, scenario.expectClaims, result)
|
||||
}
|
||||
for k, v := range scenario.expectClaims {
|
||||
v2, ok := result[k]
|
||||
if !ok {
|
||||
t.Errorf("(%d) Missing expected claim %q", i, k)
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.token), func(t *testing.T) {
|
||||
result, err := security.ParseJWT(s.token, s.secret)
|
||||
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
if v != v2 {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, v, k, v2)
|
||||
|
||||
if len(result) != len(s.expectClaims) {
|
||||
t.Fatalf("Expected %v claims got %v", s.expectClaims, result)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range s.expectClaims {
|
||||
v2, ok := result[k]
|
||||
if !ok {
|
||||
t.Fatalf("Missing expected claim %q", k)
|
||||
}
|
||||
if v != v2 {
|
||||
t.Fatalf("Expected %v for %q claim, got %v", v, k, v2)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,51 +136,51 @@ func TestNewJWT(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
claims jwt.MapClaims
|
||||
key string
|
||||
duration int64
|
||||
duration time.Duration
|
||||
expectError bool
|
||||
}{
|
||||
// empty, zero duration
|
||||
{jwt.MapClaims{}, "", 0, true},
|
||||
// empty, 10 seconds duration
|
||||
{jwt.MapClaims{}, "", 10, false},
|
||||
{jwt.MapClaims{}, "", 10 * time.Second, false},
|
||||
// non-empty, 10 seconds duration
|
||||
{jwt.MapClaims{"name": "test"}, "test", 10, false},
|
||||
{jwt.MapClaims{"name": "test"}, "test", 10 * time.Second, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
|
||||
if tokenErr != nil {
|
||||
t.Errorf("(%d) Expected NewJWT to succeed, got error %v", i, tokenErr)
|
||||
continue
|
||||
}
|
||||
|
||||
claims, parseErr := security.ParseJWT(token, scenario.key)
|
||||
|
||||
hasParseErr := parseErr != nil
|
||||
if hasParseErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasParseErr to be %v, got %v (%v)", i, scenario.expectError, hasParseErr, parseErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
t.Errorf("(%d) Missing required claim exp, got %v", i, claims)
|
||||
}
|
||||
|
||||
// clear exp claim to match with the scenario ones
|
||||
delete(claims, "exp")
|
||||
|
||||
if len(claims) != len(scenario.claims) {
|
||||
t.Errorf("(%d) Expected %v claims, got %v", i, scenario.claims, claims)
|
||||
}
|
||||
|
||||
for j, k := range claims {
|
||||
if claims[j] != scenario.claims[j] {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, claims[j], k, scenario.claims[j])
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
|
||||
if tokenErr != nil {
|
||||
t.Fatalf("Expected NewJWT to succeed, got error %v", tokenErr)
|
||||
}
|
||||
}
|
||||
|
||||
claims, parseErr := security.ParseJWT(token, scenario.key)
|
||||
|
||||
hasParseErr := parseErr != nil
|
||||
if hasParseErr != scenario.expectError {
|
||||
t.Fatalf("Expected hasParseErr to be %v, got %v (%v)", scenario.expectError, hasParseErr, parseErr)
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
t.Fatalf("Missing required claim exp, got %v", claims)
|
||||
}
|
||||
|
||||
// clear exp claim to match with the scenario ones
|
||||
delete(claims, "exp")
|
||||
|
||||
if len(claims) != len(scenario.claims) {
|
||||
t.Fatalf("Expected %v claims, got %v", scenario.claims, claims)
|
||||
}
|
||||
|
||||
for j, k := range claims {
|
||||
if claims[j] != scenario.claims[j] {
|
||||
t.Fatalf("Expected %v for %q claim, got %v", claims[j], k, scenario.claims[j])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,11 @@ package security
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"math/big"
|
||||
mathRand "math/rand"
|
||||
"time"
|
||||
mathRand "math/rand" // @todo replace with rand/v2?
|
||||
)
|
||||
|
||||
const defaultRandomAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
func init() {
|
||||
mathRand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// RandomString generates a cryptographically random string with the specified length.
|
||||
//
|
||||
// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding.
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"regexp/syntax"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultMaxRepeat = 6
|
||||
|
||||
var anyCharNotNLPairs = []rune{'A', 'Z', 'a', 'z', '0', '9'}
|
||||
|
||||
// RandomStringByRegex generates a random string matching the regex pattern.
|
||||
// If optFlags is not set, fallbacks to [syntax.Perl].
|
||||
//
|
||||
// NB! While the source of the randomness comes from [crypto/rand] this method
|
||||
// is not recommended to be used on its own in critical secure contexts because
|
||||
// the generated length could vary too much on the used pattern and may not be
|
||||
// as secure as simply calling [security.RandomString].
|
||||
// If you still insist on using it for such purposes, consider at least
|
||||
// a large enough minimum length for the generated string, e.g. `[a-z0-9]{30}`.
|
||||
//
|
||||
// This function is inspired by github.com/pipe01/revregexp, github.com/lucasjones/reggen and other similar packages.
|
||||
func RandomStringByRegex(pattern string, optFlags ...syntax.Flags) (string, error) {
|
||||
var flags syntax.Flags
|
||||
if len(optFlags) == 0 {
|
||||
flags = syntax.Perl
|
||||
} else {
|
||||
for _, f := range optFlags {
|
||||
flags |= f
|
||||
}
|
||||
}
|
||||
|
||||
r, err := syntax.Parse(pattern, flags)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var sb = new(strings.Builder)
|
||||
|
||||
err = writeRandomStringByRegex(r, sb)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func writeRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder) error {
|
||||
// https://pkg.go.dev/regexp/syntax#Op
|
||||
switch r.Op {
|
||||
case syntax.OpCharClass:
|
||||
c, err := randomRuneFromPairs(r.Rune)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sb.WriteRune(c)
|
||||
return err
|
||||
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
|
||||
c, err := randomRuneFromPairs(anyCharNotNLPairs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sb.WriteRune(c)
|
||||
return err
|
||||
case syntax.OpAlternate:
|
||||
idx, err := randomNumber(len(r.Sub))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeRandomStringByRegex(r.Sub[idx], sb)
|
||||
case syntax.OpConcat:
|
||||
var err error
|
||||
for _, sub := range r.Sub {
|
||||
err = writeRandomStringByRegex(sub, sb)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return err
|
||||
case syntax.OpRepeat:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, r.Min, r.Max)
|
||||
case syntax.OpQuest:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 0, 1)
|
||||
case syntax.OpPlus:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 1, -1)
|
||||
case syntax.OpStar:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 0, -1)
|
||||
case syntax.OpCapture:
|
||||
return writeRandomStringByRegex(r.Sub[0], sb)
|
||||
case syntax.OpLiteral:
|
||||
_, err := sb.WriteString(string(r.Rune))
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("unsupported pattern operator %d", r.Op)
|
||||
}
|
||||
}
|
||||
|
||||
func repeatRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder, min int, max int) error {
|
||||
if max < 0 {
|
||||
max = defaultMaxRepeat
|
||||
}
|
||||
|
||||
if max < min {
|
||||
max = min
|
||||
}
|
||||
|
||||
n := min
|
||||
if max != min {
|
||||
randRange, err := randomNumber(max - min)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n += randRange
|
||||
}
|
||||
|
||||
var err error
|
||||
for i := 0; i < n; i++ {
|
||||
err = writeRandomStringByRegex(r, sb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomRuneFromPairs(pairs []rune) (rune, error) {
|
||||
idx, err := randomNumber(len(pairs) / 2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return randomRuneFromRange(pairs[idx*2], pairs[idx*2+1])
|
||||
}
|
||||
|
||||
func randomRuneFromRange(min rune, max rune) (rune, error) {
|
||||
offset, err := randomNumber(int(max - min + 1))
|
||||
if err != nil {
|
||||
return min, err
|
||||
}
|
||||
|
||||
return min + rune(offset), nil
|
||||
}
|
||||
|
||||
func randomNumber(maxSoft int) (int, error) {
|
||||
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(maxSoft)))
|
||||
|
||||
return int(randRange.Int64()), err
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"regexp/syntax"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestRandomStringByRegex(t *testing.T) {
|
||||
generated := []string{}
|
||||
|
||||
scenarios := []struct {
|
||||
pattern string
|
||||
flags []syntax.Flags
|
||||
expectError bool
|
||||
}{
|
||||
{``, nil, true},
|
||||
{`test`, nil, false},
|
||||
{`\d+`, []syntax.Flags{syntax.POSIX}, true},
|
||||
{`\d+`, nil, false},
|
||||
{`\d*`, nil, false},
|
||||
{`\d{1,10}`, nil, false},
|
||||
{`\d{3}`, nil, false},
|
||||
{`\d{0,}-abc`, nil, false},
|
||||
{`[a-zA-Z]*`, nil, false},
|
||||
{`[^a-zA-Z]{5,30}`, nil, false},
|
||||
{`\w+_abc`, nil, false},
|
||||
{`[a-zA-Z_]*`, nil, false},
|
||||
{`[2-9]{5}-\w+`, nil, false},
|
||||
{`(a|b|c)`, nil, false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, s.pattern), func(t *testing.T) {
|
||||
str, err := security.RandomStringByRegex(s.pattern, s.flags...)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
r, err := regexp.Compile(s.pattern)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !r.Match([]byte(str)) {
|
||||
t.Fatalf("Expected %q to match pattern %v", str, s.pattern)
|
||||
}
|
||||
|
||||
if slices.Contains(generated, str) {
|
||||
t.Fatalf("The generated string %q already exists in\n%v", str, generated)
|
||||
}
|
||||
|
||||
generated = append(generated, str)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user