diff --git a/ecdsa.go b/ecdsa.go index eac023fc..feea13b1 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -140,3 +140,10 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string return "", err } } + +// ECDSAPublicKey represents a [Keyfunc] that returns the ECDSA key specified in +// key. Furthermore, it checks, whether the signing method matches +// [SigningMethodECDSA]. +func ECDSAPublicKey(key *ecdsa.PublicKey) Keyfunc { + return secureKeyFunc(key, []string{"ES256", "ES384", "ES512"}) +} diff --git a/ed25519.go b/ed25519.go index 07d3aacd..aaeca2d9 100644 --- a/ed25519.go +++ b/ed25519.go @@ -83,3 +83,10 @@ func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) (stri } return EncodeSegment(sig), nil } + +// Ed25519PublicKey represents a [Keyfunc] that returns the Ed25519 key +// specified in key. Furthermore, it checks, whether the signing method matches +// [SigningMethodEdDSA]. +func Ed25519PublicKey(key ed25519.PublicKey) Keyfunc { + return secureKeyFunc(key, []string{"EdDSA"}) +} diff --git a/example_test.go b/example_test.go index 58fdea43..176a4cdc 100644 --- a/example_test.go +++ b/example_test.go @@ -80,9 +80,7 @@ func ExampleParseWithClaims_customClaimsType() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }) + token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, jwt.PresharedKey([]byte("AllYourBase"))) if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) @@ -103,9 +101,11 @@ func ExampleParseWithClaims_validationOptions() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }, jwt.WithLeeway(5*time.Second)) + token, err := jwt.ParseWithClaims( + tokenString, &MyCustomClaims{}, + jwt.PresharedKey([]byte("AllYourBase")), + jwt.WithLeeway(5*time.Second), + ) if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) @@ -138,9 +138,10 @@ func (m MyCustomClaims) Validate() error { func ExampleParseWithClaims_customValidation() { tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }, jwt.WithLeeway(5*time.Second)) + token, err := jwt.ParseWithClaims( + tokenString, &MyCustomClaims{}, + jwt.PresharedKey([]byte("AllYourBase")), + jwt.WithLeeway(5*time.Second)) if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) @@ -156,9 +157,7 @@ func ExampleParse_errorChecking() { // Token from another example. This token is expired var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }) + token, err := jwt.Parse(tokenString, jwt.PresharedKey([]byte("AllYourBase"))) if token.Valid { fmt.Println("You look nice today") diff --git a/hmac.go b/hmac.go index 011f68a2..e55d19be 100644 --- a/hmac.go +++ b/hmac.go @@ -93,3 +93,8 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, return "", ErrInvalidKeyType } + +// PresharedKey represents a [Keyfunc] that simply returns the key specified in the byte slice. +func PresharedKey(key []byte) Keyfunc { + return secureKeyFunc(key, []string{"HS256", "HS384", "HS512"}) +} diff --git a/parser.go b/parser.go index 46b67931..239da30c 100644 --- a/parser.go +++ b/parser.go @@ -3,7 +3,6 @@ package jwt import ( "bytes" "encoding/json" - "fmt" "strings" ) @@ -55,17 +54,8 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf // Verify signing method is in the required set if p.validMethods != nil { - var signingMethodValid = false - var alg = token.Method.Alg() - for _, m := range p.validMethods { - if m == alg { - signingMethodValid = true - break - } - } - if !signingMethodValid { - // signing method is not in the listed set - return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid) + if err = token.hasValidSigningMethod(p.validMethods); err != nil { + return token, err } } diff --git a/rsa.go b/rsa.go index b910b19c..d0d8bf6c 100644 --- a/rsa.go +++ b/rsa.go @@ -99,3 +99,10 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, return "", err } } + +// RSAPublicKey represents a [Keyfunc] that returns the RSA key specified in +// key. Furthermore, it checks, whether the signing method matches +// [SigningMethodRSA]. +func RSAPublicKey(key *rsa.PublicKey) Keyfunc { + return secureKeyFunc(key, []string{"RS256", "RS384", "RS512"}) +} diff --git a/test/helpers.go b/test/helpers.go index 381c5f8a..a27c1d56 100644 --- a/test/helpers.go +++ b/test/helpers.go @@ -2,6 +2,7 @@ package test import ( "crypto" + "crypto/ecdsa" "crypto/rsa" "os" @@ -56,7 +57,7 @@ func LoadECPrivateKeyFromDisk(location string) crypto.PrivateKey { return key } -func LoadECPublicKeyFromDisk(location string) crypto.PublicKey { +func LoadECPublicKeyFromDisk(location string) *ecdsa.PublicKey { keyData, e := os.ReadFile(location) if e != nil { panic(e.Error()) diff --git a/token.go b/token.go index b3459427..5dbfd668 100644 --- a/token.go +++ b/token.go @@ -3,6 +3,7 @@ package jwt import ( "encoding/base64" "encoding/json" + "fmt" "strings" ) @@ -143,3 +144,36 @@ func DecodeSegment(seg string) ([]byte, error) { } return encoding.DecodeString(seg) } + +// secureKeyFunc returns a secure [Keyfunc] for the specified key that also +// includes a signing method check. +func secureKeyFunc(key any, validMethods []string) Keyfunc { + return func(t *Token) (interface{}, error) { + // Check, if the signing method matches + if err := t.hasValidSigningMethod(validMethods); err != nil { + return nil, err + } + + return key, nil + } +} + +// hasValidSigningMethod is a utility function that checks, if the signing +// method of the token is included in the validMethods slice. +func (token *Token) hasValidSigningMethod(validMethods []string) error { + var signingMethodValid = false + var alg = token.Method.Alg() + for _, m := range validMethods { + if m == alg { + signingMethodValid = true + break + } + } + + if !signingMethodValid { + // signing method is not in the listed set + return newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid) + } + + return nil +} diff --git a/token_test.go b/token_test.go index 52a00212..ee64a982 100644 --- a/token_test.go +++ b/token_test.go @@ -1,17 +1,17 @@ -package jwt_test +package jwt import ( + "errors" + "reflect" "testing" - - "github.com/golang-jwt/jwt/v5" ) func TestToken_SigningString(t1 *testing.T) { type fields struct { Raw string - Method jwt.SigningMethod + Method SigningMethod Header map[string]interface{} - Claims jwt.Claims + Claims Claims Signature string Valid bool } @@ -25,12 +25,12 @@ func TestToken_SigningString(t1 *testing.T) { name: "", fields: fields{ Raw: "", - Method: jwt.SigningMethodHS256, + Method: SigningMethodHS256, Header: map[string]interface{}{ "typ": "JWT", - "alg": jwt.SigningMethodHS256.Alg(), + "alg": SigningMethodHS256.Alg(), }, - Claims: jwt.RegisteredClaims{}, + Claims: RegisteredClaims{}, Signature: "", Valid: false, }, @@ -40,7 +40,7 @@ func TestToken_SigningString(t1 *testing.T) { } for _, tt := range tests { t1.Run(tt.name, func(t1 *testing.T) { - t := &jwt.Token{ + t := &Token{ Raw: tt.fields.Raw, Method: tt.fields.Method, Header: tt.fields.Header, @@ -61,13 +61,13 @@ func TestToken_SigningString(t1 *testing.T) { } func BenchmarkToken_SigningString(b *testing.B) { - t := &jwt.Token{ - Method: jwt.SigningMethodHS256, + t := &Token{ + Method: SigningMethodHS256, Header: map[string]interface{}{ "typ": "JWT", - "alg": jwt.SigningMethodHS256.Alg(), + "alg": SigningMethodHS256.Alg(), }, - Claims: jwt.RegisteredClaims{}, + Claims: RegisteredClaims{}, } b.Run("BenchmarkToken_SigningString", func(b *testing.B) { b.ResetTimer() @@ -77,3 +77,48 @@ func BenchmarkToken_SigningString(b *testing.B) { } }) } + +func Test_secureKeyFunc(t *testing.T) { + type fields struct { + token *Token + } + type args struct { + key any + validMethods []string + } + tests := []struct { + name string + fields fields + args args + wantKey any + wantErr error + }{ + { + name: "invalid method", + fields: fields{&Token{Header: map[string]interface{}{"alg": "RS512"}, Method: SigningMethodRS512}}, + args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}}, + wantKey: nil, + wantErr: ErrTokenSignatureInvalid, + }, + { + name: "correct method", + fields: fields{&Token{Header: map[string]interface{}{"alg": "HS256"}, Method: SigningMethodHS256}}, + args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}}, + wantKey: []byte("mysecret"), + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyfunc := secureKeyFunc(tt.args.key, tt.args.validMethods) + gotKey, gotErr := keyfunc(tt.fields.token) + + if !reflect.DeepEqual(gotKey, tt.wantKey) { + t.Errorf("secureKeyFunc() key = %v, want %v", gotKey, tt.wantKey) + } + if (gotErr != nil) && !errors.Is(gotErr, tt.wantErr) { + t.Errorf("secureKeyFunc() err = %v, want %v", gotErr, tt.wantErr) + } + }) + } +}