From 99736c34b5dfc9b816210b875a77b2d8fc4c5116 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Wed, 19 Jan 2022 22:55:19 +0100 Subject: [PATCH] Implementing `Is(err) bool` to support Go 1.13 style error checking (#136) --- claims.go | 12 +++++------ errors.go | 48 +++++++++++++++++++++++++++++++++++++++++++ example_test.go | 15 ++++++-------- map_claims.go | 3 +++ parser_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++------ 5 files changed, 111 insertions(+), 21 deletions(-) diff --git a/claims.go b/claims.go index 41cc8265..4f00db2f 100644 --- a/claims.go +++ b/claims.go @@ -56,17 +56,17 @@ func (c RegisteredClaims) Valid() error { // default value in Go, let's not fail the verification for them. if !c.VerifyExpiresAt(now, false) { delta := now.Sub(c.ExpiresAt.Time) - vErr.Inner = fmt.Errorf("token is expired by %v", delta) + vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired) vErr.Errors |= ValidationErrorExpired } if !c.VerifyIssuedAt(now, false) { - vErr.Inner = fmt.Errorf("token used before issued") + vErr.Inner = ErrTokenUsedBeforeIssued vErr.Errors |= ValidationErrorIssuedAt } if !c.VerifyNotBefore(now, false) { - vErr.Inner = fmt.Errorf("token is not valid yet") + vErr.Inner = ErrTokenNotValidYet vErr.Errors |= ValidationErrorNotValidYet } @@ -149,17 +149,17 @@ func (c StandardClaims) Valid() error { // default value in Go, let's not fail the verification for them. if !c.VerifyExpiresAt(now, false) { delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) - vErr.Inner = fmt.Errorf("token is expired by %v", delta) + vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired) vErr.Errors |= ValidationErrorExpired } if !c.VerifyIssuedAt(now, false) { - vErr.Inner = fmt.Errorf("token used before issued") + vErr.Inner = ErrTokenUsedBeforeIssued vErr.Errors |= ValidationErrorIssuedAt } if !c.VerifyNotBefore(now, false) { - vErr.Inner = fmt.Errorf("token is not valid yet") + vErr.Inner = ErrTokenNotValidYet vErr.Errors |= ValidationErrorNotValidYet } diff --git a/errors.go b/errors.go index b9d18e49..10ac8835 100644 --- a/errors.go +++ b/errors.go @@ -9,6 +9,18 @@ var ( ErrInvalidKey = errors.New("key is invalid") ErrInvalidKeyType = errors.New("key is of invalid type") ErrHashUnavailable = errors.New("the requested hash function is unavailable") + + ErrTokenMalformed = errors.New("token is malformed") + ErrTokenUnverifiable = errors.New("token is unverifiable") + ErrTokenSignatureInvalid = errors.New("token signature is invalid") + + ErrTokenInvalidAudience = errors.New("token has invalid audience") + ErrTokenExpired = errors.New("token is expired") + ErrTokenUsedBeforeIssued = errors.New("token used before issued") + ErrTokenInvalidIssuer = errors.New("token has invalid issuer") + ErrTokenNotValidYet = errors.New("token is not valid yet") + ErrTokenInvalidId = errors.New("token has invalid id") + ErrTokenInvalidClaims = errors.New("token has invalid claims") ) // The errors that might occur when parsing and validating a token @@ -62,3 +74,39 @@ func (e *ValidationError) Unwrap() error { func (e *ValidationError) valid() bool { return e.Errors == 0 } + +// Is checks if this ValidationError is of the supplied error. We are first checking for the exact error message +// by comparing the inner error message. If that fails, we compare using the error flags. This way we can use +// custom error messages (mainly for backwards compatability) and still leverage errors.Is using the global error variables. +func (e *ValidationError) Is(err error) bool { + // Check, if our inner error is a direct match + if errors.Is(errors.Unwrap(e), err) { + return true + } + + // Otherwise, we need to match using our error flags + switch err { + case ErrTokenMalformed: + return e.Errors&ValidationErrorMalformed != 0 + case ErrTokenUnverifiable: + return e.Errors&ValidationErrorUnverifiable != 0 + case ErrTokenSignatureInvalid: + return e.Errors&ValidationErrorSignatureInvalid != 0 + case ErrTokenInvalidAudience: + return e.Errors&ValidationErrorAudience != 0 + case ErrTokenExpired: + return e.Errors&ValidationErrorExpired != 0 + case ErrTokenUsedBeforeIssued: + return e.Errors&ValidationErrorIssuedAt != 0 + case ErrTokenInvalidIssuer: + return e.Errors&ValidationErrorIssuer != 0 + case ErrTokenNotValidYet: + return e.Errors&ValidationErrorNotValidYet != 0 + case ErrTokenInvalidId: + return e.Errors&ValidationErrorId != 0 + case ErrTokenInvalidClaims: + return e.Errors&ValidationErrorClaimsInvalid != 0 + } + + return false +} diff --git a/example_test.go b/example_test.go index 7815757b..ddf49ccb 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package jwt_test import ( + "errors" "fmt" "time" @@ -103,15 +104,11 @@ func ExampleParse_errorChecking() { if token.Valid { fmt.Println("You look nice today") - } else if ve, ok := err.(*jwt.ValidationError); ok { - if ve.Errors&jwt.ValidationErrorMalformed != 0 { - fmt.Println("That's not even a token") - } else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { - // Token is either expired or not active yet - fmt.Println("Timing is everything") - } else { - fmt.Println("Couldn't handle this token:", err) - } + } else if errors.Is(err, jwt.ErrTokenMalformed) { + fmt.Println("That's not even a token") + } else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) { + // Token is either expired or not active yet + fmt.Println("Timing is everything") } else { fmt.Println("Couldn't handle this token:", err) } diff --git a/map_claims.go b/map_claims.go index e7da633b..2700d64a 100644 --- a/map_claims.go +++ b/map_claims.go @@ -126,16 +126,19 @@ func (m MapClaims) Valid() error { now := TimeFunc().Unix() if !m.VerifyExpiresAt(now, false) { + // TODO(oxisto): this should be replaced with ErrTokenExpired vErr.Inner = errors.New("Token is expired") vErr.Errors |= ValidationErrorExpired } if !m.VerifyIssuedAt(now, false) { + // TODO(oxisto): this should be replaced with ErrTokenUsedBeforeIssued vErr.Inner = errors.New("Token used before issued") vErr.Errors |= ValidationErrorIssuedAt } if !m.VerifyNotBefore(now, false) { + // TODO(oxisto): this should be replaced with ErrTokenNotValidYet vErr.Inner = errors.New("Token is not valid yet") vErr.Errors |= ValidationErrorNotValidYet } diff --git a/parser_test.go b/parser_test.go index 7a7bf0ab..68aa6a93 100644 --- a/parser_test.go +++ b/parser_test.go @@ -4,6 +4,7 @@ import ( "crypto" "crypto/rsa" "encoding/json" + "errors" "fmt" "reflect" "testing" @@ -51,6 +52,7 @@ var jwtTestData = []struct { claims jwt.Claims valid bool errors uint32 + err []error parser *jwt.Parser signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose }{ @@ -62,6 +64,7 @@ var jwtTestData = []struct { true, 0, nil, + nil, jwt.SigningMethodRS256, }, { @@ -71,6 +74,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, jwt.ValidationErrorExpired, + []error{jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, }, @@ -81,6 +85,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, jwt.ValidationErrorNotValidYet, + []error{jwt.ErrTokenNotValidYet}, nil, jwt.SigningMethodRS256, }, @@ -91,6 +96,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + []error{jwt.ErrTokenNotValidYet}, nil, jwt.SigningMethodRS256, }, @@ -101,6 +107,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, nil, jwt.SigningMethodRS256, }, @@ -111,6 +118,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorUnverifiable, + []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, }, @@ -121,6 +129,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, }, @@ -131,6 +140,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorUnverifiable, + []error{jwt.ErrTokenUnverifiable, errKeyFuncError}, nil, jwt.SigningMethodRS256, }, @@ -141,6 +151,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + []error{jwt.ErrTokenSignatureInvalid}, &jwt.Parser{ValidMethods: []string{"HS256"}}, jwt.SigningMethodRS256, }, @@ -151,6 +162,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, true, 0, + nil, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, jwt.SigningMethodRS256, }, @@ -161,6 +173,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + []error{jwt.ErrTokenSignatureInvalid}, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, jwt.SigningMethodES256, }, @@ -171,6 +184,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, true, 0, + nil, &jwt.Parser{ValidMethods: []string{"HS256", "ES256"}}, jwt.SigningMethodES256, }, @@ -181,6 +195,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": json.Number("123.4")}, true, 0, + nil, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -193,6 +208,7 @@ var jwtTestData = []struct { }, true, 0, + nil, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -203,6 +219,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, jwt.ValidationErrorExpired, + []error{jwt.ErrTokenExpired}, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -213,6 +230,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, false, jwt.ValidationErrorNotValidYet, + []error{jwt.ErrTokenNotValidYet}, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -223,6 +241,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + []error{jwt.ErrTokenNotValidYet}, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -233,6 +252,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, true, 0, + nil, &jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true}, jwt.SigningMethodRS256, }, @@ -245,6 +265,7 @@ var jwtTestData = []struct { }, true, 0, + nil, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -257,6 +278,7 @@ var jwtTestData = []struct { }, true, 0, + nil, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -269,6 +291,7 @@ var jwtTestData = []struct { }, true, 0, + nil, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -281,6 +304,7 @@ var jwtTestData = []struct { }, false, jwt.ValidationErrorMalformed, + []error{jwt.ErrTokenMalformed}, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -293,6 +317,7 @@ var jwtTestData = []struct { }, false, jwt.ValidationErrorMalformed, + []error{jwt.ErrTokenMalformed}, &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, @@ -325,6 +350,7 @@ func TestParser_Parse(t *testing.T) { // Parse the token var token *jwt.Token + var ve *jwt.ValidationError var err error var parser = data.parser if parser == nil { @@ -361,18 +387,34 @@ func TestParser_Parse(t *testing.T) { if err == nil { t.Errorf("[%v] Expecting error. Didn't get one.", data.name) } else { + if errors.As(err, &ve) { + // compare the bitfield part of the error + if e := ve.Errors; e != data.errors { + t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) + } + + if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError { + t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError) + } + } + } + } - ve := err.(*jwt.ValidationError) - // compare the bitfield part of the error - if e := ve.Errors; e != data.errors { - t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) + if data.err != nil { + if err == nil { + t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name) + } else { + var all = false + for _, e := range data.err { + all = errors.Is(err, e) } - if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError { - t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError) + if !all { + t.Errorf("[%v] Errors don't match expectation. %v should contain all of %v", data.name, err, data.err) } } } + if data.valid { if token.Signature == "" { t.Errorf("[%v] Signature is left unpopulated after parsing", data.name)