[#7252] support ed25519 oidc id_token signature validation
This commit is contained in:
parent
8acb48b884
commit
58da159641
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
- Visualize presentable multiple `relation` fields ([#7260](https://github.com/pocketbase/pocketbase/issues/7260)).
|
||||
|
||||
- Support Ed25519 in the optional OIDC id_token signature validation ([#7252](https://github.com/pocketbase/pocketbase/issues/7252); thanks @shynome).
|
||||
|
||||
|
||||
## v0.30.4
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,120 @@
|
|||
// Package jwk implements some common utilities for interacting with JWKs
|
||||
// (mostly used with OIDC providers).
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
Use string `json:"use"`
|
||||
Alg string `json:"alg"`
|
||||
// RS256 (RSA)
|
||||
E string `json:"e"`
|
||||
N string `json:"n"`
|
||||
// Ed25519 (OKP)
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
}
|
||||
|
||||
// PublicKey reconstructs and returns the public key from the current JWK.
|
||||
func (key *JWK) PublicKey() (any, error) {
|
||||
switch key.Kty {
|
||||
case "RSA":
|
||||
// RFC 7518
|
||||
// https://datatracker.ietf.org/doc/html/rfc7518#section-6.3
|
||||
// https://datatracker.ietf.org/doc/html/rfc7517#appendix-A.1
|
||||
exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "="))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "="))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rsa.PublicKey{
|
||||
E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
|
||||
N: big.NewInt(0).SetBytes(modulus),
|
||||
}, nil
|
||||
case "OKP":
|
||||
// RFC 8037
|
||||
// https://datatracker.ietf.org/doc/html/rfc8037#section-2
|
||||
// https://datatracker.ietf.org/doc/html/rfc8037#appendix-A
|
||||
if key.Crv != "Ed25519" {
|
||||
return nil, fmt.Errorf("unsupported OKP curve (must be Ed25519): %q", key.Crv)
|
||||
}
|
||||
|
||||
x, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.X, "="))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if l := len(x); l != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid Ed25519 key length: %d", l)
|
||||
}
|
||||
|
||||
return ed25519.PublicKey(x), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported kty (must be RSA or OKP): %q", key.Kty)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch retrieves the JSON Web Key Set located at jwksURL and returns
|
||||
// the first key that matches the specified kid.
|
||||
func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
rawBody, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// http.Client.Get doesn't treat non 2xx responses as error
|
||||
if res.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to JSON Web Key Set from %s (%d):\n%s",
|
||||
jwksURL,
|
||||
res.StatusCode,
|
||||
string(rawBody),
|
||||
)
|
||||
}
|
||||
|
||||
jwks := struct {
|
||||
Keys []*JWK
|
||||
}{}
|
||||
|
||||
err = json.Unmarshal(rawBody, &jwks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
|
||||
}
|
||||
|
|
@ -0,0 +1,211 @@
|
|||
package jwk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
|
||||
)
|
||||
|
||||
type publicKey interface {
|
||||
Equal(x crypto.PublicKey) bool
|
||||
}
|
||||
|
||||
func TestJWK_PublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate test RSA private key: %v", err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
key *jwk.JWK
|
||||
expectError bool
|
||||
expectKey crypto.PublicKey
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
&jwk.JWK{},
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"invalid kty",
|
||||
&jwk.JWK{
|
||||
Kty: "invalid",
|
||||
Alg: "RS256",
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()),
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()),
|
||||
},
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"RSA",
|
||||
&jwk.JWK{
|
||||
Kty: "RSA",
|
||||
Alg: "RS256",
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()),
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()),
|
||||
},
|
||||
false,
|
||||
&rsaPrivate.PublicKey,
|
||||
},
|
||||
{
|
||||
"OKP with unsupported curve",
|
||||
&jwk.JWK{
|
||||
Kty: "OKP",
|
||||
Crv: "invalid",
|
||||
X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize))),
|
||||
},
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"OKP with invalid public key length",
|
||||
&jwk.JWK{
|
||||
Kty: "OKP",
|
||||
Crv: "Ed25519",
|
||||
X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize-1))),
|
||||
},
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid OKP",
|
||||
&jwk.JWK{
|
||||
Kty: "OKP",
|
||||
Crv: "Ed25519",
|
||||
X: base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat("a", ed25519.PublicKeySize))),
|
||||
},
|
||||
false,
|
||||
ed25519.PublicKey([]byte(strings.Repeat("a", ed25519.PublicKeySize))),
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
result, err := s.key.PublicKey()
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr && result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
k, ok := result.(publicKey)
|
||||
if !ok {
|
||||
t.Fatalf("The returned public key %T doesn't satisfy the expected common interface", k)
|
||||
}
|
||||
|
||||
if !k.Equal(s.expectKey) {
|
||||
t.Fatalf("The returned public key doesn't match with the expected one:\n%v\n%v", k, s.expectKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Query().Has("error") {
|
||||
res.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
fmt.Fprintf(res, `{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "abc",
|
||||
"kty": "OKP",
|
||||
"crv": "Ed25519",
|
||||
"x": "test_x"
|
||||
},
|
||||
{
|
||||
"kid": "def",
|
||||
"kty": "RSA",
|
||||
"alg": "RS256",
|
||||
"n": "test_n",
|
||||
"e": "test_e"
|
||||
}
|
||||
]
|
||||
}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
kid string
|
||||
expectError bool
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
"error response",
|
||||
"def",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"non-matching kid",
|
||||
"missing",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"matching kid",
|
||||
"def",
|
||||
false,
|
||||
[]string{
|
||||
`"kid":"def"`,
|
||||
`"kty":"RSA"`,
|
||||
`"alg":"RS256"`,
|
||||
`"n":"test_n"`,
|
||||
`"e":"test_e"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
url := server.URL
|
||||
if s.expectError {
|
||||
url += "?error"
|
||||
}
|
||||
|
||||
key, err := jwk.Fetch(context.Background(), url, s.kid)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
for _, substr := range s.contains {
|
||||
if !strings.Contains(rawStr, substr) {
|
||||
t.Fatalf("Missing expected substring\n%s\nin\n%s", substr, rawStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -2,20 +2,15 @@ package auth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
|
|
@ -198,36 +193,17 @@ func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL strin
|
|||
return errors.New("missing kid header value")
|
||||
}
|
||||
|
||||
key, err := fetchJWK(ctx, jwksURL, kid)
|
||||
key, err := jwk.Fetch(ctx, jwksURL, kid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// decode the key params per RFC 7518 (https://tools.ietf.org/html/rfc7518#section-6.3)
|
||||
// and construct a valid publicKey from them
|
||||
// ---
|
||||
exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "="))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "="))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey := &rsa.PublicKey{
|
||||
// https://tools.ietf.org/html/rfc7517#appendix-A.1
|
||||
E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
|
||||
N: big.NewInt(0).SetBytes(modulus),
|
||||
}
|
||||
|
||||
// verify the signiture
|
||||
// ---
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
|
||||
|
||||
parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
|
||||
return publicKey, nil
|
||||
return key.PublicKey()
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -239,54 +215,3 @@ func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL strin
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
type jwk struct {
|
||||
Kty string
|
||||
Kid string
|
||||
Use string
|
||||
Alg string
|
||||
N string
|
||||
E string
|
||||
}
|
||||
|
||||
func fetchJWK(ctx context.Context, jwksURL string, kid string) (*jwk, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
rawBody, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// http.Client.Get doesn't treat non 2xx responses as error
|
||||
if res.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to verify the provided id_token (%d):\n%s",
|
||||
res.StatusCode,
|
||||
string(rawBody),
|
||||
)
|
||||
}
|
||||
|
||||
jwks := struct {
|
||||
Keys []*jwk
|
||||
}{}
|
||||
if err := json.Unmarshal(rawBody, &jwks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue