initial public commit
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
crand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// S256Challenge creates base64 encoded sha256 challenge string derived from code.
|
||||
// The padding of the result base64 string is stripped per [RFC 7636].
|
||||
//
|
||||
// [RFC 7636]: https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
|
||||
func S256Challenge(code string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(code))
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(h.Sum(nil)), "=")
|
||||
}
|
||||
|
||||
// Encrypt encrypts data with key (must be valid 32 char aes key).
|
||||
func Encrypt(data []byte, key string) (string, error) {
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
|
||||
// populates the nonce with a cryptographically secure random sequence
|
||||
if _, err := io.ReadFull(crand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cipherByte := gcm.Seal(nonce, nonce, data, nil)
|
||||
|
||||
result := base64.StdEncoding.EncodeToString(cipherByte)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts encrypted text with key (must be valid 32 chars aes key).
|
||||
func Decrypt(cipherText string, key string) ([]byte, error) {
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
|
||||
cipherByte, err := base64.StdEncoding.DecodeString(cipherText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce, cipherByteClean := cipherByte[:nonceSize], cipherByte[nonceSize:]
|
||||
plainData, err := gcm.Open(nil, nonce, cipherByteClean, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plainData, nil
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestS256Challenge(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{"", "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU"},
|
||||
{"123", "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := security.S256Challenge(scenario.code)
|
||||
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
data string
|
||||
key string
|
||||
expectError bool
|
||||
}{
|
||||
{"", "", true},
|
||||
{"123", "test", true}, // key must be valid 32 char aes string
|
||||
{"123", "abcdabcdabcdabcdabcdabcdabcdabcd", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Encrypt([]byte(scenario.data), scenario.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
|
||||
if scenario.expectError && result != "" {
|
||||
t.Errorf("(%d) Expected empty string, got %q", i, result)
|
||||
}
|
||||
if !scenario.expectError && result == "" {
|
||||
t.Errorf("(%d) Expected non empty encrypted result string", i)
|
||||
}
|
||||
|
||||
// try to decrypt
|
||||
if result != "" {
|
||||
decrypted, _ := security.Decrypt(result, scenario.key)
|
||||
if string(decrypted) != scenario.data {
|
||||
t.Errorf("(%d) Expected decrypted value to match with the data input, got %q", i, decrypted)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
cipher string
|
||||
key string
|
||||
expectError bool
|
||||
expectedData string
|
||||
}{
|
||||
{"", "", true, ""},
|
||||
{"123", "test", true, ""}, // key must be valid 32 char aes string
|
||||
{"8kcEqilvvYKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", true, ""}, // illegal base64 encoded cipherText
|
||||
{"8kcEqilvv+YKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", false, "123"},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.Decrypt(scenario.cipher, scenario.key)
|
||||
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
|
||||
resultStr := string(result)
|
||||
if resultStr != scenario.expectedData {
|
||||
t.Errorf("(%d) Expected %q, got %q", i, scenario.expectedData, resultStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// ParseUnverifiedJWT parses JWT token and returns its claims
|
||||
// but DOES NOT verify the signature.
|
||||
func ParseUnverifiedJWT(token string) (jwt.MapClaims, error) {
|
||||
claims := jwt.MapClaims{}
|
||||
|
||||
parser := &jwt.Parser{}
|
||||
_, _, err := parser.ParseUnverified(token, claims)
|
||||
|
||||
if err == nil {
|
||||
err = claims.Valid()
|
||||
}
|
||||
|
||||
return claims, err
|
||||
}
|
||||
|
||||
// ParseJWT verifies and parses JWT token and returns its claims.
|
||||
func ParseJWT(token string, verificationKey string) (jwt.MapClaims, error) {
|
||||
parser := &jwt.Parser{
|
||||
ValidMethods: []string{"HS256"},
|
||||
}
|
||||
|
||||
parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) {
|
||||
return []byte(verificationKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok && parsedToken.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("Unable to parse token.")
|
||||
}
|
||||
|
||||
// NewToken generates and returns new HS256 signed JWT token.
|
||||
func NewToken(payload jwt.MapClaims, signingKey string, secondsDuration int64) (string, error) {
|
||||
seconds := time.Duration(secondsDuration) * time.Second
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"exp": time.Now().Add(seconds).Unix(),
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
for k, v := range payload {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(signingKey))
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestParseUnverifiedJWT(t *testing.T) {
|
||||
// invalid formatted JWT token
|
||||
result1, err1 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9")
|
||||
if err1 == nil {
|
||||
t.Error("Expected error got nil")
|
||||
}
|
||||
if len(result1) > 0 {
|
||||
t.Error("Expected no parsed claims, got", result1)
|
||||
}
|
||||
|
||||
// properly formatted JWT token with INVALID claims
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
result2, err2 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU")
|
||||
if err2 == nil {
|
||||
t.Error("Expected error got nil")
|
||||
}
|
||||
if len(result2) != 2 || result2["name"] != "test" {
|
||||
t.Errorf("Expected to have 2 claims, got %v", result2)
|
||||
}
|
||||
|
||||
// properly formatted JWT token with VALID claims
|
||||
// {"name": "test"}
|
||||
result3, err3 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9.ml0QsTms3K9wMygTu41ZhKlTyjmW9zHQtoS8FUsCCjU")
|
||||
if err3 != nil {
|
||||
t.Error("Expected nil, got", err3)
|
||||
}
|
||||
if len(result3) != 1 || result3["name"] != "test" {
|
||||
t.Errorf("Expected to have 2 claims, got %v", result3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJWT(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
token string
|
||||
secret string
|
||||
expectError bool
|
||||
expectClaims jwt.MapClaims
|
||||
}{
|
||||
// invalid formatted JWT token
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with INVALID claims and INVALID secret
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU",
|
||||
"invalid",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with INVALID claims and VALID secret
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with VALID claims and INVALID secret
|
||||
// {"name": "test", "exp": 1898636137}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTg5ODYzNjEzN30.gqRkHjpK5s1PxxBn9qPaWEWxTbpc1PPSD-an83TsXRY",
|
||||
"invalid",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted EXPIRED JWT token with VALID secret
|
||||
// {"name": "test", "exp": 1652097610}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6OTU3ODczMzc0fQ.0oUUKUnsQHs4nZO1pnxQHahKtcHspHu4_AplN2sGC4A",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT token with VALID claims and VALID secret
|
||||
// {"name": "test", "exp": 1898636137}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTg5ODYzNjEzN30.gqRkHjpK5s1PxxBn9qPaWEWxTbpc1PPSD-an83TsXRY",
|
||||
"test",
|
||||
false,
|
||||
jwt.MapClaims{"name": "test", "exp": 1898636137.0},
|
||||
},
|
||||
// properly formatted JWT token with VALID claims (without exp) and VALID secret
|
||||
// {"name": "test"}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9.ml0QsTms3K9wMygTu41ZhKlTyjmW9zHQtoS8FUsCCjU",
|
||||
"test",
|
||||
false,
|
||||
jwt.MapClaims{"name": "test"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result, err := security.ParseJWT(scenario.token, scenario.secret)
|
||||
if scenario.expectError && err == nil {
|
||||
t.Errorf("(%d) Expected error got nil", i)
|
||||
}
|
||||
if !scenario.expectError && err != nil {
|
||||
t.Errorf("(%d) Expected nil got error %v", i, err)
|
||||
}
|
||||
if len(result) != len(scenario.expectClaims) {
|
||||
t.Errorf("(%d) Expected %v got %v", i, scenario.expectClaims, result)
|
||||
}
|
||||
for k, v := range scenario.expectClaims {
|
||||
v2, ok := result[k]
|
||||
if !ok {
|
||||
t.Errorf("(%d) Missing expected claim %q", i, k)
|
||||
}
|
||||
if v != v2 {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, v, k, v2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewToken(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
claims jwt.MapClaims
|
||||
key string
|
||||
duration int64
|
||||
expectError bool
|
||||
}{
|
||||
// empty, zero duration
|
||||
{jwt.MapClaims{}, "", 0, true},
|
||||
// empty, 10 seconds duration
|
||||
{jwt.MapClaims{}, "", 10, false},
|
||||
// non-empty, 10 seconds duration
|
||||
{jwt.MapClaims{"name": "test"}, "test", 10, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
token, tokenErr := security.NewToken(scenario.claims, scenario.key, scenario.duration)
|
||||
if tokenErr != nil {
|
||||
t.Errorf("(%d) Expected NewToken to succeed, got error %v", i, tokenErr)
|
||||
continue
|
||||
}
|
||||
|
||||
claims, parseErr := security.ParseJWT(token, scenario.key)
|
||||
|
||||
hasParseErr := parseErr != nil
|
||||
if hasParseErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasParseErr to be %v, got %v (%v)", i, scenario.expectError, hasParseErr, parseErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
t.Errorf("(%d) Missing required claim exp, got %v", i, claims)
|
||||
}
|
||||
|
||||
// clear exp claim to match with the scenario ones
|
||||
delete(claims, "exp")
|
||||
|
||||
if len(claims) != len(scenario.claims) {
|
||||
t.Errorf("(%d) Expected %v claims, got %v", i, scenario.claims, claims)
|
||||
}
|
||||
|
||||
for j, k := range claims {
|
||||
if claims[j] != scenario.claims[j] {
|
||||
t.Errorf("(%d) Expected %v for %q claim, got %v", i, claims[j], k, scenario.claims[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
)
|
||||
|
||||
// RandomString generates a random string of specified length.
|
||||
//
|
||||
// The generated string is cryptographically random and matches
|
||||
// [A-Za-z0-9]+ (aka. it's transparent to URL-encoding).
|
||||
func RandomString(length int) string {
|
||||
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
for i, b := range bytes {
|
||||
bytes[i] = alphabet[b%byte(len(alphabet))]
|
||||
}
|
||||
|
||||
return string(bytes)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestRandomString(t *testing.T) {
|
||||
generated := []string{}
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
length := 5 + i
|
||||
result := security.RandomString(length)
|
||||
|
||||
if len(result) != length {
|
||||
t.Errorf("(%d) Expected the length of the string to be %d, got %d", i, length, len(result))
|
||||
}
|
||||
|
||||
if match, _ := regexp.MatchString("[a-zA-Z0-9]+", result); !match {
|
||||
t.Errorf("(%d) The generated strings should have only [a-zA-Z0-9]+ characters, got %q", i, result)
|
||||
}
|
||||
|
||||
for _, str := range generated {
|
||||
if str == result {
|
||||
t.Errorf("(%d) Repeating random string - found %q in \n%v", i, result, generated)
|
||||
}
|
||||
}
|
||||
|
||||
generated = append(generated, result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user