Skip to content

Commit

Permalink
Adding canonical Keyfunc functions for RSA, ECDSA, EdDSA and HMAC
Browse files Browse the repository at this point in the history
This PR adds ready-to-use keyfunc functions for the various signing methods. This should simplify a lot of standard use-cases and also includes a proper signing method check.
  • Loading branch information
oxisto committed Feb 21, 2023
1 parent 5dc3299 commit 751179a
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 38 deletions.
7 changes: 7 additions & 0 deletions ecdsa.go
Expand Up @@ -140,3 +140,10 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
return "", err
}
}

// ECDSAPublicKey represents a [Keyfunc] that returns the ECDSA key specified in
// key. Furthermore, it checks, whether the signing method matches
// [SigningMethodECDSA].
func ECDSAPublicKey(key *ecdsa.PublicKey) Keyfunc {
return secureKeyFunc(key, []string{"ES256", "ES384", "ES512"})
}
7 changes: 7 additions & 0 deletions ed25519.go
Expand Up @@ -83,3 +83,10 @@ func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) (stri
}
return EncodeSegment(sig), nil
}

// Ed25519PublicKey represents a [Keyfunc] that returns the Ed25519 key
// specified in key. Furthermore, it checks, whether the signing method matches
// [SigningMethodEdDSA].
func Ed25519PublicKey(key ed25519.PublicKey) Keyfunc {
return secureKeyFunc(key, []string{"EdDSA"})
}
23 changes: 11 additions & 12 deletions example_test.go
Expand Up @@ -80,9 +80,7 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
})
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, jwt.PresharedKey([]byte("AllYourBase")))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
Expand All @@ -103,9 +101,11 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))
token, err := jwt.ParseWithClaims(
tokenString, &MyCustomClaims{},
jwt.PresharedKey([]byte("AllYourBase")),
jwt.WithLeeway(5*time.Second),
)

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
Expand Down Expand Up @@ -138,9 +138,10 @@ func (m MyCustomClaims) Validate() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))
token, err := jwt.ParseWithClaims(
tokenString, &MyCustomClaims{},
jwt.PresharedKey([]byte("AllYourBase")),
jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
Expand All @@ -156,9 +157,7 @@ func ExampleParse_errorChecking() {
// Token from another example. This token is expired
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
})
token, err := jwt.Parse(tokenString, jwt.PresharedKey([]byte("AllYourBase")))

if token.Valid {
fmt.Println("You look nice today")
Expand Down
5 changes: 5 additions & 0 deletions hmac.go
Expand Up @@ -93,3 +93,8 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string,

return "", ErrInvalidKeyType
}

// PresharedKey represents a [Keyfunc] that simply returns the key specified in the byte slice.
func PresharedKey(key []byte) Keyfunc {
return secureKeyFunc(key, []string{"HS256", "HS384", "HS512"})
}
14 changes: 2 additions & 12 deletions parser.go
Expand Up @@ -3,7 +3,6 @@ package jwt
import (
"bytes"
"encoding/json"
"fmt"
"strings"
)

Expand Down Expand Up @@ -55,17 +54,8 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf

// Verify signing method is in the required set
if p.validMethods != nil {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range p.validMethods {
if m == alg {
signingMethodValid = true
break
}
}
if !signingMethodValid {
// signing method is not in the listed set
return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
if err = token.hasValidSigningMethod(p.validMethods); err != nil {
return token, err
}
}

Expand Down
7 changes: 7 additions & 0 deletions rsa.go
Expand Up @@ -99,3 +99,10 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string,
return "", err
}
}

// RSAPublicKey represents a [Keyfunc] that returns the RSA key specified in
// key. Furthermore, it checks, whether the signing method matches
// [SigningMethodRSA].
func RSAPublicKey(key *rsa.PublicKey) Keyfunc {
return secureKeyFunc(key, []string{"RS256", "RS384", "RS512"})
}
3 changes: 2 additions & 1 deletion test/helpers.go
Expand Up @@ -2,6 +2,7 @@ package test

import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"os"

Expand Down Expand Up @@ -56,7 +57,7 @@ func LoadECPrivateKeyFromDisk(location string) crypto.PrivateKey {
return key
}

func LoadECPublicKeyFromDisk(location string) crypto.PublicKey {
func LoadECPublicKeyFromDisk(location string) *ecdsa.PublicKey {
keyData, e := os.ReadFile(location)
if e != nil {
panic(e.Error())
Expand Down
34 changes: 34 additions & 0 deletions token.go
Expand Up @@ -3,6 +3,7 @@ package jwt
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)

Expand Down Expand Up @@ -143,3 +144,36 @@ func DecodeSegment(seg string) ([]byte, error) {
}
return encoding.DecodeString(seg)
}

// secureKeyFunc returns a secure [Keyfunc] for the specified key that also
// includes a signing method check.
func secureKeyFunc(key any, validMethods []string) Keyfunc {
return func(t *Token) (interface{}, error) {
// Check, if the signing method matches
if err := t.hasValidSigningMethod(validMethods); err != nil {
return nil, err
}

return key, nil
}
}

// hasValidSigningMethod is a utility function that checks, if the signing
// method of the token is included in the validMethods slice.
func (token *Token) hasValidSigningMethod(validMethods []string) error {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range validMethods {
if m == alg {
signingMethodValid = true
break
}
}

if !signingMethodValid {
// signing method is not in the listed set
return newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
}

return nil
}
71 changes: 58 additions & 13 deletions token_test.go
@@ -1,17 +1,17 @@
package jwt_test
package jwt

import (
"errors"
"reflect"
"testing"

"github.com/golang-jwt/jwt/v5"
)

func TestToken_SigningString(t1 *testing.T) {
type fields struct {
Raw string
Method jwt.SigningMethod
Method SigningMethod
Header map[string]interface{}
Claims jwt.Claims
Claims Claims
Signature string
Valid bool
}
Expand All @@ -25,12 +25,12 @@ func TestToken_SigningString(t1 *testing.T) {
name: "",
fields: fields{
Raw: "",
Method: jwt.SigningMethodHS256,
Method: SigningMethodHS256,
Header: map[string]interface{}{
"typ": "JWT",
"alg": jwt.SigningMethodHS256.Alg(),
"alg": SigningMethodHS256.Alg(),
},
Claims: jwt.RegisteredClaims{},
Claims: RegisteredClaims{},
Signature: "",
Valid: false,
},
Expand All @@ -40,7 +40,7 @@ func TestToken_SigningString(t1 *testing.T) {
}
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
t := &jwt.Token{
t := &Token{
Raw: tt.fields.Raw,
Method: tt.fields.Method,
Header: tt.fields.Header,
Expand All @@ -61,13 +61,13 @@ func TestToken_SigningString(t1 *testing.T) {
}

func BenchmarkToken_SigningString(b *testing.B) {
t := &jwt.Token{
Method: jwt.SigningMethodHS256,
t := &Token{
Method: SigningMethodHS256,
Header: map[string]interface{}{
"typ": "JWT",
"alg": jwt.SigningMethodHS256.Alg(),
"alg": SigningMethodHS256.Alg(),
},
Claims: jwt.RegisteredClaims{},
Claims: RegisteredClaims{},
}
b.Run("BenchmarkToken_SigningString", func(b *testing.B) {
b.ResetTimer()
Expand All @@ -77,3 +77,48 @@ func BenchmarkToken_SigningString(b *testing.B) {
}
})
}

func Test_secureKeyFunc(t *testing.T) {
type fields struct {
token *Token
}
type args struct {
key any
validMethods []string
}
tests := []struct {
name string
fields fields
args args
wantKey any
wantErr error
}{
{
name: "invalid method",
fields: fields{&Token{Header: map[string]interface{}{"alg": "RS512"}, Method: SigningMethodRS512}},
args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}},
wantKey: nil,
wantErr: ErrTokenSignatureInvalid,
},
{
name: "correct method",
fields: fields{&Token{Header: map[string]interface{}{"alg": "HS256"}, Method: SigningMethodHS256}},
args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}},
wantKey: []byte("mysecret"),
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyfunc := secureKeyFunc(tt.args.key, tt.args.validMethods)
gotKey, gotErr := keyfunc(tt.fields.token)

if !reflect.DeepEqual(gotKey, tt.wantKey) {
t.Errorf("secureKeyFunc() key = %v, want %v", gotKey, tt.wantKey)
}
if (gotErr != nil) && !errors.Is(gotErr, tt.wantErr) {
t.Errorf("secureKeyFunc() err = %v, want %v", gotErr, tt.wantErr)
}
})
}
}

0 comments on commit 751179a

Please sign in to comment.