[#7252] support ed25519 oidc id_token signature validation

This commit is contained in:
Gani Georgiev 2025-10-19 13:49:31 +03:00
parent 8acb48b884
commit 58da159641
4 changed files with 336 additions and 78 deletions

View File

@ -2,6 +2,8 @@
- Visualize presentable multiple `relation` fields ([#7260](https://github.com/pocketbase/pocketbase/issues/7260)). - Visualize presentable multiple `relation` fields ([#7260](https://github.com/pocketbase/pocketbase/issues/7260)).
- Support Ed25519 in the optional OIDC id_token signature validation ([#7252](https://github.com/pocketbase/pocketbase/issues/7252); thanks @shynome).
## v0.30.4 ## v0.30.4

View File

@ -0,0 +1,120 @@
// Package jwk implements some common utilities for interacting with JWKs
// (mostly used with OIDC providers).
package jwk
import (
"context"
"crypto/ed25519"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"strings"
)
type JWK struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
Alg string `json:"alg"`
// RS256 (RSA)
E string `json:"e"`
N string `json:"n"`
// Ed25519 (OKP)
Crv string `json:"crv"`
X string `json:"x"`
}
// PublicKey reconstructs and returns the public key from the current JWK.
func (key *JWK) PublicKey() (any, error) {
switch key.Kty {
case "RSA":
// RFC 7518
// https://datatracker.ietf.org/doc/html/rfc7518#section-6.3
// https://datatracker.ietf.org/doc/html/rfc7517#appendix-A.1
exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "="))
if err != nil {
return nil, err
}
modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "="))
if err != nil {
return nil, err
}
return &rsa.PublicKey{
E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
N: big.NewInt(0).SetBytes(modulus),
}, nil
case "OKP":
// RFC 8037
// https://datatracker.ietf.org/doc/html/rfc8037#section-2
// https://datatracker.ietf.org/doc/html/rfc8037#appendix-A
if key.Crv != "Ed25519" {
return nil, fmt.Errorf("unsupported OKP curve (must be Ed25519): %q", key.Crv)
}
x, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.X, "="))
if err != nil {
return nil, err
}
if l := len(x); l != ed25519.PublicKeySize {
return nil, fmt.Errorf("invalid Ed25519 key length: %d", l)
}
return ed25519.PublicKey(x), nil
default:
return nil, fmt.Errorf("unsupported kty (must be RSA or OKP): %q", key.Kty)
}
}
// Fetch retrieves the JSON Web Key Set located at jwksURL and returns
// the first key that matches the specified kid.
func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
rawBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
// http.Client.Get doesn't treat non 2xx responses as error
if res.StatusCode >= 400 {
return nil, fmt.Errorf(
"failed to JSON Web Key Set from %s (%d):\n%s",
jwksURL,
res.StatusCode,
string(rawBody),
)
}
jwks := struct {
Keys []*JWK
}{}
err = json.Unmarshal(rawBody, &jwks)
if err != nil {
return nil, err
}
for _, key := range jwks.Keys {
if key.Kid == kid {
return key, nil
}
}
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
}

View File

@ -0,0 +1,211 @@
package jwk_test
import (
"context"
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
)
type publicKey interface {
Equal(x crypto.PublicKey) bool
}
func TestJWK_PublicKey(t *testing.T) {
t.Parallel()
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatalf("failed to generate test RSA private key: %v", err)
}
scenarios := []struct {
name string
key *jwk.JWK
expectError bool
expectKey crypto.PublicKey
}{
{
"empty",
&jwk.JWK{},
true,
nil,
},
{
"invalid kty",
&jwk.JWK{
Kty: "invalid",
Alg: "RS256",
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()),
N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()),
},
true,
nil,
},
{
"RSA",
&jwk.JWK{
Kty: "RSA",
Alg: "RS256",
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()),
N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()),
},
false,
&rsaPrivate.PublicKey,
},
{
"OKP with unsupported curve",
&jwk.JWK{
Kty: "OKP",
Crv: "invalid",
X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize))),
},
true,
nil,
},
{
"OKP with invalid public key length",
&jwk.JWK{
Kty: "OKP",
Crv: "Ed25519",
X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize-1))),
},
true,
nil,
},
{
"valid OKP",
&jwk.JWK{
Kty: "OKP",
Crv: "Ed25519",
X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize))),
},
false,
ed25519.PublicKey([]byte(strings.Repeat("a", ed25519.PublicKeySize))),
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result, err := s.key.PublicKey()
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr && result == nil {
return
}
k, ok := result.(publicKey)
if !ok {
t.Fatalf("The returned public key %T doesn't satisfy the expected common interface", k)
}
if !k.Equal(s.expectKey) {
t.Fatalf("The returned public key doesn't match with the expected one:\n%v\n%v", k, s.expectKey)
}
})
}
}
func TestFetch(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
if req.URL.Query().Has("error") {
res.WriteHeader(http.StatusBadRequest)
}
fmt.Fprintf(res, `{
"keys": [
{
"kid": "abc",
"kty": "OKP",
"crv": "Ed25519",
"x": "test_x"
},
{
"kid": "def",
"kty": "RSA",
"alg": "RS256",
"n": "test_n",
"e": "test_e"
}
]
}`)
}))
defer server.Close()
scenarios := []struct {
name string
kid string
expectError bool
contains []string
}{
{
"error response",
"def",
true,
nil,
},
{
"non-matching kid",
"missing",
true,
nil,
},
{
"matching kid",
"def",
false,
[]string{
`"kid":"def"`,
`"kty":"RSA"`,
`"alg":"RS256"`,
`"n":"test_n"`,
`"e":"test_e"`,
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
url := server.URL
if s.expectError {
url += "?error"
}
key, err := jwk.Fetch(context.Background(), url, s.kid)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
raw, err := json.Marshal(key)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
for _, substr := range s.contains {
if !strings.Contains(rawStr, substr) {
t.Fatalf("Missing expected substring\n%s\nin\n%s", substr, rawStr)
}
}
})
}
}

View File

@ -2,20 +2,15 @@ package auth
import ( import (
"context" "context"
"crypto/rsa"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"math/big"
"net/http"
"os" "os"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
"github.com/pocketbase/pocketbase/tools/security" "github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types" "github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast" "github.com/spf13/cast"
@ -198,36 +193,17 @@ func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL strin
return errors.New("missing kid header value") return errors.New("missing kid header value")
} }
key, err := fetchJWK(ctx, jwksURL, kid) key, err := jwk.Fetch(ctx, jwksURL, kid)
if err != nil { if err != nil {
return err return err
} }
// decode the key params per RFC 7518 (https://tools.ietf.org/html/rfc7518#section-6.3)
// and construct a valid publicKey from them
// ---
exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "="))
if err != nil {
return err
}
modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "="))
if err != nil {
return err
}
publicKey := &rsa.PublicKey{
// https://tools.ietf.org/html/rfc7517#appendix-A.1
E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
N: big.NewInt(0).SetBytes(modulus),
}
// verify the signiture // verify the signiture
// --- // ---
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg})) parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) { parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
return publicKey, nil return key.PublicKey()
}) })
if err != nil { if err != nil {
return err return err
@ -239,54 +215,3 @@ func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL strin
return nil return nil
} }
type jwk struct {
Kty string
Kid string
Use string
Alg string
N string
E string
}
func fetchJWK(ctx context.Context, jwksURL string, kid string) (*jwk, error) {
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
rawBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
// http.Client.Get doesn't treat non 2xx responses as error
if res.StatusCode >= 400 {
return nil, fmt.Errorf(
"failed to verify the provided id_token (%d):\n%s",
res.StatusCode,
string(rawBody),
)
}
jwks := struct {
Keys []*jwk
}{}
if err := json.Unmarshal(rawBody, &jwks); err != nil {
return nil, err
}
for _, key := range jwks.Keys {
if key.Kid == kid {
return key, nil
}
}
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
}