From 58da159641b5d31fd1a68399a641b2841c3888d3 Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Sun, 19 Oct 2025 13:49:31 +0300 Subject: [PATCH] [#7252] support ed25519 oidc id_token signature validation --- CHANGELOG.md | 2 + tools/auth/internal/jwk/jwk.go | 120 ++++++++++++++++ tools/auth/internal/jwk/jwk_test.go | 211 ++++++++++++++++++++++++++++ tools/auth/oidc.go | 81 +---------- 4 files changed, 336 insertions(+), 78 deletions(-) create mode 100644 tools/auth/internal/jwk/jwk.go create mode 100644 tools/auth/internal/jwk/jwk_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index e2916d8a..e7d0a4b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ - Visualize presentable multiple `relation` fields ([#7260](https://github.com/pocketbase/pocketbase/issues/7260)). +- Support Ed25519 in the optional OIDC id_token signature validation ([#7252](https://github.com/pocketbase/pocketbase/issues/7252); thanks @shynome). + ## v0.30.4 diff --git a/tools/auth/internal/jwk/jwk.go b/tools/auth/internal/jwk/jwk.go new file mode 100644 index 00000000..b81bf129 --- /dev/null +++ b/tools/auth/internal/jwk/jwk.go @@ -0,0 +1,120 @@ +// Package jwk implements some common utilities for interacting with JWKs +// (mostly used with OIDC providers). +package jwk + +import ( + "context" + "crypto/ed25519" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "strings" +) + +type JWK struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + Alg string `json:"alg"` + // RS256 (RSA) + E string `json:"e"` + N string `json:"n"` + // Ed25519 (OKP) + Crv string `json:"crv"` + X string `json:"x"` +} + +// PublicKey reconstructs and returns the public key from the current JWK. +func (key *JWK) PublicKey() (any, error) { + switch key.Kty { + case "RSA": + // RFC 7518 + // https://datatracker.ietf.org/doc/html/rfc7518#section-6.3 + // https://datatracker.ietf.org/doc/html/rfc7517#appendix-A.1 + exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "=")) + if err != nil { + return nil, err + } + + modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "=")) + if err != nil { + return nil, err + } + + return &rsa.PublicKey{ + E: int(big.NewInt(0).SetBytes(exponent).Uint64()), + N: big.NewInt(0).SetBytes(modulus), + }, nil + case "OKP": + // RFC 8037 + // https://datatracker.ietf.org/doc/html/rfc8037#section-2 + // https://datatracker.ietf.org/doc/html/rfc8037#appendix-A + if key.Crv != "Ed25519" { + return nil, fmt.Errorf("unsupported OKP curve (must be Ed25519): %q", key.Crv) + } + + x, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.X, "=")) + if err != nil { + return nil, err + } + + if l := len(x); l != ed25519.PublicKeySize { + return nil, fmt.Errorf("invalid Ed25519 key length: %d", l) + } + + return ed25519.PublicKey(x), nil + default: + return nil, fmt.Errorf("unsupported kty (must be RSA or OKP): %q", key.Kty) + } +} + +// Fetch retrieves the JSON Web Key Set located at jwksURL and returns +// the first key that matches the specified kid. +func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) { + req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil) + if err != nil { + return nil, err + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + rawBody, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + // http.Client.Get doesn't treat non 2xx responses as error + if res.StatusCode >= 400 { + return nil, fmt.Errorf( + "failed to JSON Web Key Set from %s (%d):\n%s", + jwksURL, + res.StatusCode, + string(rawBody), + ) + } + + jwks := struct { + Keys []*JWK + }{} + + err = json.Unmarshal(rawBody, &jwks) + if err != nil { + return nil, err + } + + for _, key := range jwks.Keys { + if key.Kid == kid { + return key, nil + } + } + + return nil, fmt.Errorf("jwk with kid %q was not found", kid) +} diff --git a/tools/auth/internal/jwk/jwk_test.go b/tools/auth/internal/jwk/jwk_test.go new file mode 100644 index 00000000..07e780f9 --- /dev/null +++ b/tools/auth/internal/jwk/jwk_test.go @@ -0,0 +1,211 @@ +package jwk_test + +import ( + "context" + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/pocketbase/pocketbase/tools/auth/internal/jwk" +) + +type publicKey interface { + Equal(x crypto.PublicKey) bool +} + +func TestJWK_PublicKey(t *testing.T) { + t.Parallel() + + rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatalf("failed to generate test RSA private key: %v", err) + } + + scenarios := []struct { + name string + key *jwk.JWK + expectError bool + expectKey crypto.PublicKey + }{ + { + "empty", + &jwk.JWK{}, + true, + nil, + }, + { + "invalid kty", + &jwk.JWK{ + Kty: "invalid", + Alg: "RS256", + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()), + N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()), + }, + true, + nil, + }, + { + "RSA", + &jwk.JWK{ + Kty: "RSA", + Alg: "RS256", + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()), + N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()), + }, + false, + &rsaPrivate.PublicKey, + }, + { + "OKP with unsupported curve", + &jwk.JWK{ + Kty: "OKP", + Crv: "invalid", + X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize))), + }, + true, + nil, + }, + { + "OKP with invalid public key length", + &jwk.JWK{ + Kty: "OKP", + Crv: "Ed25519", + X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize-1))), + }, + true, + nil, + }, + { + "valid OKP", + &jwk.JWK{ + Kty: "OKP", + Crv: "Ed25519", + X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize))), + }, + false, + ed25519.PublicKey([]byte(strings.Repeat("a", ed25519.PublicKeySize))), + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + result, err := s.key.PublicKey() + + hasErr := err != nil + if hasErr != s.expectError { + t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err) + } + + if hasErr && result == nil { + return + } + + k, ok := result.(publicKey) + if !ok { + t.Fatalf("The returned public key %T doesn't satisfy the expected common interface", k) + } + + if !k.Equal(s.expectKey) { + t.Fatalf("The returned public key doesn't match with the expected one:\n%v\n%v", k, s.expectKey) + } + }) + } +} + +func TestFetch(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + if req.URL.Query().Has("error") { + res.WriteHeader(http.StatusBadRequest) + } + + fmt.Fprintf(res, `{ + "keys": [ + { + "kid": "abc", + "kty": "OKP", + "crv": "Ed25519", + "x": "test_x" + }, + { + "kid": "def", + "kty": "RSA", + "alg": "RS256", + "n": "test_n", + "e": "test_e" + } + ] + }`) + })) + defer server.Close() + + scenarios := []struct { + name string + kid string + expectError bool + contains []string + }{ + { + "error response", + "def", + true, + nil, + }, + { + "non-matching kid", + "missing", + true, + nil, + }, + { + "matching kid", + "def", + false, + []string{ + `"kid":"def"`, + `"kty":"RSA"`, + `"alg":"RS256"`, + `"n":"test_n"`, + `"e":"test_e"`, + }, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + url := server.URL + if s.expectError { + url += "?error" + } + + key, err := jwk.Fetch(context.Background(), url, s.kid) + + hasErr := err != nil + if hasErr != s.expectError { + t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err) + } + + raw, err := json.Marshal(key) + if err != nil { + t.Fatal(err) + } + rawStr := string(raw) + + for _, substr := range s.contains { + if !strings.Contains(rawStr, substr) { + t.Fatalf("Missing expected substring\n%s\nin\n%s", substr, rawStr) + } + } + }) + } +} diff --git a/tools/auth/oidc.go b/tools/auth/oidc.go index 81529ad6..43e093e3 100644 --- a/tools/auth/oidc.go +++ b/tools/auth/oidc.go @@ -2,20 +2,15 @@ package auth import ( "context" - "crypto/rsa" - "encoding/base64" "encoding/json" "errors" "fmt" - "io" - "math/big" - "net/http" "os" "strconv" - "strings" "time" "github.com/golang-jwt/jwt/v5" + "github.com/pocketbase/pocketbase/tools/auth/internal/jwk" "github.com/pocketbase/pocketbase/tools/security" "github.com/pocketbase/pocketbase/tools/types" "github.com/spf13/cast" @@ -198,36 +193,17 @@ func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL strin return errors.New("missing kid header value") } - key, err := fetchJWK(ctx, jwksURL, kid) + key, err := jwk.Fetch(ctx, jwksURL, kid) if err != nil { return err } - // decode the key params per RFC 7518 (https://tools.ietf.org/html/rfc7518#section-6.3) - // and construct a valid publicKey from them - // --- - exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "=")) - if err != nil { - return err - } - - modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "=")) - if err != nil { - return err - } - - publicKey := &rsa.PublicKey{ - // https://tools.ietf.org/html/rfc7517#appendix-A.1 - E: int(big.NewInt(0).SetBytes(exponent).Uint64()), - N: big.NewInt(0).SetBytes(modulus), - } - // verify the signiture // --- parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg})) parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) { - return publicKey, nil + return key.PublicKey() }) if err != nil { return err @@ -239,54 +215,3 @@ func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL strin return nil } - -type jwk struct { - Kty string - Kid string - Use string - Alg string - N string - E string -} - -func fetchJWK(ctx context.Context, jwksURL string, kid string) (*jwk, error) { - req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil) - if err != nil { - return nil, err - } - - res, err := http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - - rawBody, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } - - // http.Client.Get doesn't treat non 2xx responses as error - if res.StatusCode >= 400 { - return nil, fmt.Errorf( - "failed to verify the provided id_token (%d):\n%s", - res.StatusCode, - string(rawBody), - ) - } - - jwks := struct { - Keys []*jwk - }{} - if err := json.Unmarshal(rawBody, &jwks); err != nil { - return nil, err - } - - for _, key := range jwks.Keys { - if key.Kid == kid { - return key, nil - } - } - - return nil, fmt.Errorf("jwk with kid %q was not found", kid) -}