Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Is(err) bool to support Go 1.13 style error checking #136

Merged
merged 3 commits into from Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions claims.go
Expand Up @@ -56,17 +56,17 @@ func (c RegisteredClaims) Valid() error {
// default value in Go, let's not fail the verification for them.
if !c.VerifyExpiresAt(now, false) {
delta := now.Sub(c.ExpiresAt.Time)
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired)
vErr.Errors |= ValidationErrorExpired
}

if !c.VerifyIssuedAt(now, false) {
vErr.Inner = fmt.Errorf("token used before issued")
vErr.Inner = ErrTokenUsedBeforeIssued
vErr.Errors |= ValidationErrorIssuedAt
}

if !c.VerifyNotBefore(now, false) {
vErr.Inner = fmt.Errorf("token is not valid yet")
vErr.Inner = ErrTokenNotValidYet
vErr.Errors |= ValidationErrorNotValidYet
}

Expand Down Expand Up @@ -143,17 +143,17 @@ func (c StandardClaims) Valid() error {
// default value in Go, let's not fail the verification for them.
if !c.VerifyExpiresAt(now, false) {
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired)
vErr.Errors |= ValidationErrorExpired
}

if !c.VerifyIssuedAt(now, false) {
vErr.Inner = fmt.Errorf("token used before issued")
vErr.Inner = ErrTokenUsedBeforeIssued
vErr.Errors |= ValidationErrorIssuedAt
}

if !c.VerifyNotBefore(now, false) {
vErr.Inner = fmt.Errorf("token is not valid yet")
vErr.Inner = ErrTokenNotValidYet
vErr.Errors |= ValidationErrorNotValidYet
}

Expand Down
48 changes: 48 additions & 0 deletions errors.go
Expand Up @@ -9,6 +9,18 @@ var (
ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable")

ErrTokenMalformed = errors.New("token is malformed")
ErrTokenUnverifiable = errors.New("token is unverifiable")
ErrTokenSignatureInvalid = errors.New("token signature is invalid")

ErrTokenInvalidAudience = errors.New("token has invalid audience")
ErrTokenExpired = errors.New("token is expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrTokenInvalidIssuer = errors.New("token has invalid issuer")
ErrTokenNotValidYet = errors.New("token is not valid yet")
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")
)

// The errors that might occur when parsing and validating a token
Expand Down Expand Up @@ -62,3 +74,39 @@ func (e *ValidationError) Unwrap() error {
func (e *ValidationError) valid() bool {
return e.Errors == 0
}

// Is checks if this ValidationError is of the supplied error. We are first checking for the exact error message
// by comparing the inner error message. If that fails, we compare using the error flags. This way we can use
// custom error messages (mainly for backwards compatability) and still leverage errors.Is using the global error variables.
func (e *ValidationError) Is(err error) bool {
// Check, if our inner error is a direct match
if errors.Is(errors.Unwrap(e), err) {
return true
}

// Otherwise, we need to match using our error flags
switch err {
case ErrTokenMalformed:
return e.Errors&ValidationErrorMalformed != 0
case ErrTokenUnverifiable:
return e.Errors&ValidationErrorUnverifiable != 0
case ErrTokenSignatureInvalid:
return e.Errors&ValidationErrorSignatureInvalid != 0
case ErrTokenInvalidAudience:
return e.Errors&ValidationErrorAudience != 0
case ErrTokenExpired:
return e.Errors&ValidationErrorExpired != 0
case ErrTokenUsedBeforeIssued:
return e.Errors&ValidationErrorIssuedAt != 0
case ErrTokenInvalidIssuer:
return e.Errors&ValidationErrorIssuer != 0
case ErrTokenNotValidYet:
return e.Errors&ValidationErrorNotValidYet != 0
case ErrTokenInvalidId:
return e.Errors&ValidationErrorId != 0
case ErrTokenInvalidClaims:
return e.Errors&ValidationErrorClaimsInvalid != 0
}

return false
}
15 changes: 6 additions & 9 deletions example_test.go
@@ -1,6 +1,7 @@
package jwt_test

import (
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -103,15 +104,11 @@ func ExampleParse_errorChecking() {

if token.Valid {
fmt.Println("You look nice today")
} else if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
fmt.Println("That's not even a token")
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
// Token is either expired or not active yet
fmt.Println("Timing is everything")
} else {
fmt.Println("Couldn't handle this token:", err)
}
} else if errors.Is(err, jwt.ErrTokenMalformed) {
fmt.Println("That's not even a token")
} else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) {
// Token is either expired or not active yet
fmt.Println("Timing is everything")
} else {
fmt.Println("Couldn't handle this token:", err)
}
Expand Down
3 changes: 3 additions & 0 deletions map_claims.go
Expand Up @@ -126,16 +126,19 @@ func (m MapClaims) Valid() error {
now := TimeFunc().Unix()

if !m.VerifyExpiresAt(now, false) {
// TODO(oxisto): this should be replaced with ErrTokenExpired
vErr.Inner = errors.New("Token is expired")
vErr.Errors |= ValidationErrorExpired
}

if !m.VerifyIssuedAt(now, false) {
// TODO(oxisto): this should be replaced with ErrTokenUsedBeforeIssued
vErr.Inner = errors.New("Token used before issued")
vErr.Errors |= ValidationErrorIssuedAt
}

if !m.VerifyNotBefore(now, false) {
// TODO(oxisto): this should be replaced with ErrTokenNotValidYet
vErr.Inner = errors.New("Token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
Expand Down
54 changes: 48 additions & 6 deletions parser_test.go
Expand Up @@ -4,6 +4,7 @@ import (
"crypto"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"reflect"
"testing"
Expand Down Expand Up @@ -51,6 +52,7 @@ var jwtTestData = []struct {
claims jwt.Claims
valid bool
errors uint32
err []error
parser *jwt.Parser
signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose
}{
Expand All @@ -62,6 +64,7 @@ var jwtTestData = []struct {
true,
0,
nil,
nil,
jwt.SigningMethodRS256,
},
{
Expand All @@ -71,6 +74,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
false,
jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenExpired},
nil,
jwt.SigningMethodRS256,
},
Expand All @@ -81,6 +85,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
false,
jwt.ValidationErrorNotValidYet,
[]error{jwt.ErrTokenNotValidYet},
nil,
jwt.SigningMethodRS256,
},
Expand All @@ -91,6 +96,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
false,
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenNotValidYet},
nil,
jwt.SigningMethodRS256,
},
Expand All @@ -101,6 +107,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification},
nil,
jwt.SigningMethodRS256,
},
Expand All @@ -111,6 +118,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
false,
jwt.ValidationErrorUnverifiable,
[]error{jwt.ErrTokenUnverifiable},
nil,
jwt.SigningMethodRS256,
},
Expand All @@ -121,6 +129,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid},
nil,
jwt.SigningMethodRS256,
},
Expand All @@ -131,6 +140,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
false,
jwt.ValidationErrorUnverifiable,
[]error{jwt.ErrTokenUnverifiable, errKeyFuncError},
nil,
jwt.SigningMethodRS256,
},
Expand All @@ -141,6 +151,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid},
&jwt.Parser{ValidMethods: []string{"HS256"}},
jwt.SigningMethodRS256,
},
Expand All @@ -151,6 +162,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
true,
0,
nil,
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
jwt.SigningMethodRS256,
},
Expand All @@ -161,6 +173,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid},
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
jwt.SigningMethodES256,
},
Expand All @@ -171,6 +184,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
true,
0,
nil,
&jwt.Parser{ValidMethods: []string{"HS256", "ES256"}},
jwt.SigningMethodES256,
},
Expand All @@ -181,6 +195,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": json.Number("123.4")},
true,
0,
nil,
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -193,6 +208,7 @@ var jwtTestData = []struct {
},
true,
0,
nil,
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -203,6 +219,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
false,
jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenExpired},
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -213,6 +230,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
false,
jwt.ValidationErrorNotValidYet,
[]error{jwt.ErrTokenNotValidYet},
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -223,6 +241,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
false,
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenNotValidYet},
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -233,6 +252,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
true,
0,
nil,
&jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true},
jwt.SigningMethodRS256,
},
Expand All @@ -245,6 +265,7 @@ var jwtTestData = []struct {
},
true,
0,
nil,
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -257,6 +278,7 @@ var jwtTestData = []struct {
},
true,
0,
nil,
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -269,6 +291,7 @@ var jwtTestData = []struct {
},
true,
0,
nil,
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -281,6 +304,7 @@ var jwtTestData = []struct {
},
false,
jwt.ValidationErrorMalformed,
[]error{jwt.ErrTokenMalformed},
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand All @@ -293,6 +317,7 @@ var jwtTestData = []struct {
},
false,
jwt.ValidationErrorMalformed,
[]error{jwt.ErrTokenMalformed},
&jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256,
},
Expand Down Expand Up @@ -325,6 +350,7 @@ func TestParser_Parse(t *testing.T) {

// Parse the token
var token *jwt.Token
var ve *jwt.ValidationError
var err error
var parser = data.parser
if parser == nil {
Expand Down Expand Up @@ -361,18 +387,34 @@ func TestParser_Parse(t *testing.T) {
if err == nil {
t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
} else {
if errors.As(err, &ve) {
// compare the bitfield part of the error
if e := ve.Errors; e != data.errors {
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
}

if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError {
t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError)
}
}
}
}

ve := err.(*jwt.ValidationError)
// compare the bitfield part of the error
if e := ve.Errors; e != data.errors {
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
if data.err != nil {
if err == nil {
t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name)
} else {
var all = false
for _, e := range data.err {
all = errors.Is(err, e)
}

if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError {
t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError)
if !all {
t.Errorf("[%v] Errors don't match expectation. %v should contain all of %v", data.name, err, data.err)
}
}
}

if data.valid {
if token.Signature == "" {
t.Errorf("[%v] Signature is left unpopulated after parsing", data.name)
Expand Down