121 lines
2.7 KiB
Go
121 lines
2.7 KiB
Go
// Package jwk implements common utilities for working with JSON Web Keys
// (JWKs), mostly used with OIDC providers.
|
|
package jwk
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// JWK is a single JSON Web Key (RFC 7517), restricted to the members this
// package needs to rebuild RSA and Ed25519 (OKP) public keys.
//
// Members that do not apply to a key's type are left empty when decoding.
// Because the tags carry no omitempty, they are still emitted as empty
// strings when encoding.
type JWK struct {
	// Kty is the key type, e.g. "RSA" or "OKP" (RFC 7517 §4.1).
	Kty string `json:"kty"`
	// Kid is the key ID used to pick a key out of a set (RFC 7517 §4.5).
	Kid string `json:"kid"`
	// Use is the intended public key use, e.g. "sig" (RFC 7517 §4.2).
	Use string `json:"use"`
	// Alg is the algorithm intended for use with the key (RFC 7517 §4.4).
	Alg string `json:"alg"`

	// RSA parameters (RFC 7518 §6.3.1), each a base64url-encoded unsigned
	// big-endian integer.

	// E is the RSA public exponent.
	E string `json:"e"`
	// N is the RSA modulus.
	N string `json:"n"`

	// OKP parameters (RFC 8037 §2).

	// Crv is the curve name; PublicKey only supports "Ed25519".
	Crv string `json:"crv"`
	// X is the base64url-encoded public key bytes.
	X   string `json:"x"`
}
|
|
|
|
// PublicKey reconstructs and returns the public key from the current JWK.
|
|
func (key *JWK) PublicKey() (any, error) {
|
|
switch key.Kty {
|
|
case "RSA":
|
|
// RFC 7518
|
|
// https://datatracker.ietf.org/doc/html/rfc7518#section-6.3
|
|
// https://datatracker.ietf.org/doc/html/rfc7517#appendix-A.1
|
|
exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "="))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "="))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &rsa.PublicKey{
|
|
E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
|
|
N: big.NewInt(0).SetBytes(modulus),
|
|
}, nil
|
|
case "OKP":
|
|
// RFC 8037
|
|
// https://datatracker.ietf.org/doc/html/rfc8037#section-2
|
|
// https://datatracker.ietf.org/doc/html/rfc8037#appendix-A
|
|
if key.Crv != "Ed25519" {
|
|
return nil, fmt.Errorf("unsupported OKP curve (must be Ed25519): %q", key.Crv)
|
|
}
|
|
|
|
x, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.X, "="))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if l := len(x); l != ed25519.PublicKeySize {
|
|
return nil, fmt.Errorf("invalid Ed25519 key length: %d", l)
|
|
}
|
|
|
|
return ed25519.PublicKey(x), nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported kty (must be RSA or OKP): %q", key.Kty)
|
|
}
|
|
}
|
|
|
|
// Fetch retrieves the JSON Web Key Set located at jwksURL and returns
|
|
// the first key that matches the specified kid.
|
|
func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
|
|
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
res, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
rawBody, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// http.Client.Get doesn't treat non 2xx responses as error
|
|
if res.StatusCode >= 400 {
|
|
return nil, fmt.Errorf(
|
|
"failed to JSON Web Key Set from %s (%d):\n%s",
|
|
jwksURL,
|
|
res.StatusCode,
|
|
string(rawBody),
|
|
)
|
|
}
|
|
|
|
jwks := struct {
|
|
Keys []*JWK
|
|
}{}
|
|
|
|
err = json.Unmarshal(rawBody, &jwks)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, key := range jwks.Keys {
|
|
if key.Kid == kid {
|
|
return key, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
|
|
}
|