Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: port clockskew support #139

Merged
merged 11 commits into from Mar 8, 2022
63 changes: 40 additions & 23 deletions claims.go
Expand Up @@ -9,7 +9,7 @@ import (
// Claims must just have a Valid method that determines
// if the token is invalid for any supported reason
type Claims interface {
Valid() error
Valid(options ...*ClaimsValidationOptions) error
oxisto marked this conversation as resolved.
Show resolved Hide resolved
}

// RegisteredClaims are a structured version of the JWT Claims Set,
Expand Down Expand Up @@ -48,13 +48,13 @@ type RegisteredClaims struct {
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (c RegisteredClaims) Valid() error {
func (c RegisteredClaims) Valid(opts ...*ClaimsValidationOptions) error {
vErr := new(ValidationError)
now := TimeFunc()

// The claims below are optional, by default, so if they are set to the
// default value in Go, let's not fail the verification for them.
if !c.VerifyExpiresAt(now, false) {
if !c.VerifyExpiresAt(now, false, opts...) {
delta := now.Sub(c.ExpiresAt.Time)
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
vErr.Errors |= ValidationErrorExpired
Expand All @@ -65,7 +65,7 @@ func (c RegisteredClaims) Valid() error {
vErr.Errors |= ValidationErrorIssuedAt
}

if !c.VerifyNotBefore(now, false) {
if !c.VerifyNotBefore(now, false, opts...) {
vErr.Inner = fmt.Errorf("token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
Expand All @@ -85,12 +85,16 @@ func (c *RegisteredClaims) VerifyAudience(cmp string, req bool) bool {

// VerifyExpiresAt compares the exp claim against cmp (cmp < exp).
// If req is false, it will return true, if exp is unset.
func (c *RegisteredClaims) VerifyExpiresAt(cmp time.Time, req bool) bool {
func (c *RegisteredClaims) VerifyExpiresAt(cmp time.Time, req bool, opts ...*ClaimsValidationOptions) bool {
var s time.Duration
if len(opts) > 0 && opts[0] != nil {
s = opts[0].Leeway
}
if c.ExpiresAt == nil {
return verifyExp(nil, cmp, req)
return verifyExp(nil, cmp, req, s)
}

return verifyExp(&c.ExpiresAt.Time, cmp, req)
return verifyExp(&c.ExpiresAt.Time, cmp, req, s)
}

// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
Expand All @@ -105,12 +109,16 @@ func (c *RegisteredClaims) VerifyIssuedAt(cmp time.Time, req bool) bool {

// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (c *RegisteredClaims) VerifyNotBefore(cmp time.Time, req bool) bool {
func (c *RegisteredClaims) VerifyNotBefore(cmp time.Time, req bool, opts ...*ClaimsValidationOptions) bool {
var s time.Duration
if len(opts) > 0 && opts[0] != nil {
s = opts[0].Leeway
}
if c.NotBefore == nil {
return verifyNbf(nil, cmp, req)
return verifyNbf(nil, cmp, req, s)
}

return verifyNbf(&c.NotBefore.Time, cmp, req)
return verifyNbf(&c.NotBefore.Time, cmp, req, s)
}

// VerifyIssuer compares the iss claim against cmp.
Expand Down Expand Up @@ -141,13 +149,13 @@ type StandardClaims struct {
// Valid validates time based claims "exp, iat, nbf". There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (c StandardClaims) Valid() error {
func (c StandardClaims) Valid(opts ...*ClaimsValidationOptions) error {
vErr := new(ValidationError)
now := TimeFunc().Unix()

// The claims below are optional, by default, so if they are set to the
// default value in Go, let's not fail the verification for them.
if !c.VerifyExpiresAt(now, false) {
if !c.VerifyExpiresAt(now, false, opts...) {
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
vErr.Errors |= ValidationErrorExpired
Expand All @@ -158,7 +166,7 @@ func (c StandardClaims) Valid() error {
vErr.Errors |= ValidationErrorIssuedAt
}

if !c.VerifyNotBefore(now, false) {
if !c.VerifyNotBefore(now, false, opts...) {
vErr.Inner = fmt.Errorf("token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
Expand All @@ -178,13 +186,17 @@ func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool {

// VerifyExpiresAt compares the exp claim against cmp (cmp < exp).
// If req is false, it will return true, if exp is unset.
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool {
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool, opts ...*ClaimsValidationOptions) bool {
var s time.Duration
if len(opts) > 0 && opts[0] != nil {
s = opts[0].Leeway
}
if c.ExpiresAt == 0 {
return verifyExp(nil, time.Unix(cmp, 0), req)
return verifyExp(nil, time.Unix(cmp, 0), req, s)
}

t := time.Unix(c.ExpiresAt, 0)
return verifyExp(&t, time.Unix(cmp, 0), req)
return verifyExp(&t, time.Unix(cmp, 0), req, s)
}

// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
Expand All @@ -200,13 +212,17 @@ func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool {

// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool, opts ...*ClaimsValidationOptions) bool {
var s time.Duration
if len(opts) > 0 && opts[0] != nil {
s = opts[0].Leeway
}
if c.NotBefore == 0 {
return verifyNbf(nil, time.Unix(cmp, 0), req)
return verifyNbf(nil, time.Unix(cmp, 0), req, s)
}

t := time.Unix(c.NotBefore, 0)
return verifyNbf(&t, time.Unix(cmp, 0), req)
return verifyNbf(&t, time.Unix(cmp, 0), req, s)
}

// VerifyIssuer compares the iss claim against cmp.
Expand Down Expand Up @@ -240,11 +256,11 @@ func verifyAud(aud []string, cmp string, required bool) bool {
return result
}

func verifyExp(exp *time.Time, now time.Time, required bool) bool {
func verifyExp(exp *time.Time, now time.Time, required bool, clockSkew time.Duration) bool {
if exp == nil {
return !required
}
return now.Before(*exp)
return now.Before((*exp).Add(+clockSkew))
oxisto marked this conversation as resolved.
Show resolved Hide resolved
}

func verifyIat(iat *time.Time, now time.Time, required bool) bool {
Expand All @@ -254,11 +270,12 @@ func verifyIat(iat *time.Time, now time.Time, required bool) bool {
return now.After(*iat) || now.Equal(*iat)
}

func verifyNbf(nbf *time.Time, now time.Time, required bool) bool {
func verifyNbf(nbf *time.Time, now time.Time, required bool, clockSkew time.Duration) bool {
oxisto marked this conversation as resolved.
Show resolved Hide resolved
if nbf == nil {
return !required
}
return now.After(*nbf) || now.Equal(*nbf)
t := (*nbf).Add(-clockSkew)
return now.After(t) || now.Equal(t)
}

func verifyIss(iss string, cmp string, required bool) bool {
Expand Down
31 changes: 31 additions & 0 deletions claims_option.go
@@ -0,0 +1,31 @@
package jwt

import "time"

// ClaimsValidationOptions represents options that can be used for claims validation
type ClaimsValidationOptions struct {
oxisto marked this conversation as resolved.
Show resolved Hide resolved
Leeway time.Duration
}

func ClaimsValidation() *ClaimsValidationOptions {
return &ClaimsValidationOptions{}
}

func (c *ClaimsValidationOptions) SetClockSkew(d time.Duration) {
c.Leeway = d
}

// MergeClaimsValidationOptions combines the given ClaimsValidationOptions instancs into a single ClaimsValidationOptions
// in a last-one-wins fashion
func MergeClaimsValidationOptions(opts ...*ClaimsValidationOptions) *ClaimsValidationOptions {
c := ClaimsValidation()
for _, opt := range opts {
if opt == nil {
continue
}
if opt.Leeway != 0 {
c.Leeway = opt.Leeway
}
}
return c
}
32 changes: 21 additions & 11 deletions map_claims.go
Expand Up @@ -34,25 +34,30 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool {

// VerifyExpiresAt compares the exp claim against cmp (cmp <= exp).
// If req is false, it will return true, if exp is unset.
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool, opts ...*ClaimsValidationOptions) bool {
cmpTime := time.Unix(cmp, 0)

v, ok := m["exp"]
if !ok {
return !req
}

var s time.Duration
if len(opts) > 0 && opts[0] != nil {
s = opts[0].Leeway
}

switch exp := v.(type) {
case float64:
if exp == 0 {
return verifyExp(nil, cmpTime, req)
return verifyExp(nil, cmpTime, req, s)
}

return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req)
return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req, s)
case json.Number:
v, _ := exp.Float64()

return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req)
return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req, s)
}

return false
Expand Down Expand Up @@ -86,25 +91,30 @@ func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {

// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
func (m MapClaims) VerifyNotBefore(cmp int64, req bool, opts ...*ClaimsValidationOptions) bool {
cmpTime := time.Unix(cmp, 0)

v, ok := m["nbf"]
if !ok {
return !req
}

var s time.Duration
if len(opts) > 0 && opts[0] != nil {
s = opts[0].Leeway
}

switch nbf := v.(type) {
case float64:
if nbf == 0 {
return verifyNbf(nil, cmpTime, req)
return verifyNbf(nil, cmpTime, req, s)
}

return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req)
return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req, s)
case json.Number:
v, _ := nbf.Float64()

return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req)
return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req, s)
}

return false
Expand All @@ -121,11 +131,11 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (m MapClaims) Valid() error {
func (m MapClaims) Valid(opts ...*ClaimsValidationOptions) error {
vErr := new(ValidationError)
now := TimeFunc().Unix()

if !m.VerifyExpiresAt(now, false) {
if !m.VerifyExpiresAt(now, false, opts...) {
vErr.Inner = errors.New("Token is expired")
vErr.Errors |= ValidationErrorExpired
}
Expand All @@ -135,7 +145,7 @@ func (m MapClaims) Valid() error {
vErr.Errors |= ValidationErrorIssuedAt
}

if !m.VerifyNotBefore(now, false) {
if !m.VerifyNotBefore(now, false, opts...) {
vErr.Inner = errors.New("Token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
Expand Down
4 changes: 3 additions & 1 deletion parser.go
Expand Up @@ -22,6 +22,8 @@ type Parser struct {
//
// Deprecated: In future releases, this field will not be exported anymore and should be set with an option to NewParser instead.
SkipClaimsValidation bool

options []*ClaimsValidationOptions
}

// NewParser creates a new Parser with the specified options
Expand Down Expand Up @@ -82,7 +84,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf

// Validate Claims
if !p.SkipClaimsValidation {
if err := token.Claims.Valid(); err != nil {
if err := token.Claims.Valid(MergeClaimsValidationOptions(p.options...)); err != nil {

// If the Claims Valid returned an error, check if it is a validation error,
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
Expand Down
8 changes: 8 additions & 0 deletions parser_option.go
@@ -1,5 +1,7 @@
package jwt

import "time"

// ParserOption is used to implement functional-style options that modify the behaviour of the parser. To add
// new options, just create a function (ideally beginning with With or Without) that returns an anonymous function that
// takes a *Parser type as input and manipulates its configuration accordingly.
Expand Down Expand Up @@ -27,3 +29,9 @@ func WithoutClaimsValidation() ParserOption {
p.SkipClaimsValidation = true
}
}

func WithLeeway(d time.Duration) ParserOption {
return func(p *Parser) {
p.options = append(p.options, &ClaimsValidationOptions{Leeway: d})
}
}
40 changes: 40 additions & 0 deletions parser_test.go
Expand Up @@ -74,6 +74,26 @@ var jwtTestData = []struct {
nil,
jwt.SigningMethodRS256,
},
{
"basic expired with 60s skew",
"", // autogen
defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
false,
jwt.ValidationErrorExpired,
jwt.NewParser(jwt.WithLeeway(time.Minute)),
jwt.SigningMethodRS256,
},
{
"basic expired with 120s skew",
"", // autogen
defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
true,
0,
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
jwt.SigningMethodRS256,
},
{
"basic nbf",
"", // autogen
Expand All @@ -84,6 +104,26 @@ var jwtTestData = []struct {
nil,
jwt.SigningMethodRS256,
},
{
"basic nbf with 60s skew",
"", // autogen
defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
false,
jwt.ValidationErrorNotValidYet,
jwt.NewParser(jwt.WithLeeway(time.Minute)),
jwt.SigningMethodRS256,
},
{
"basic nbf with 120s skew",
"", // autogen
defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
true,
0,
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
jwt.SigningMethodRS256,
},
{
"expired and nbf",
"", // autogen
Expand Down