Skip to content

Commit

Permalink
Fix typo in jwt package (#1070)
Browse files Browse the repository at this point in the history
* Fix typo in jwt package

    * change JWTTokenContextKey to JWTContextKey

    * revise errors as well as comments

* Revision of typo fixing in jwt package

	* revert the API identifier to its previous value

        * add JWTTokenContextKey side by side of JWTContextKey for historical compatibility

* Fixed a typo in JWT package (#1071)

	* define JWTContextKey as a new constant

        * mark JWTTokenContextKey as a deprecated constant

        * revise corresponding error messages

* Fixed a typo in JWT package (#1071)

	* define JWTContextKey as a new constant

        * mark JWTTokenContextKey as a deprecated constant

        * revise corresponding error messages

Co-authored-by: Amid <amid.dev@protonmail.com>
  • Loading branch information
amidam and Amid committed Jun 19, 2021
1 parent 40533ee commit a5c7103
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 48 deletions.
4 changes: 2 additions & 2 deletions auth/jwt/README.md
Expand Up @@ -7,7 +7,7 @@ through [JSON Web Tokens](https://jwt.io/).

NewParser takes a key function and an expected signing method and returns an
`endpoint.Middleware`. The middleware will parse a token passed into the
context via the `jwt.JWTTokenContextKey`. If the token is valid, any claims
context via the `jwt.JWTContextKey`. If the token is valid, any claims
will be added to the context via the `jwt.JWTClaimsContextKey`.

```go
Expand All @@ -30,7 +30,7 @@ func main() {

NewSigner takes a JWT key ID header, the signing key, signing method, and a
claims object. It returns an `endpoint.Middleware`. The middleware will build
the token string and add it to the context via the `jwt.JWTTokenContextKey`.
the token string and add it to the context via the `jwt.JWTContextKey`.

```go
import (
Expand Down
26 changes: 15 additions & 11 deletions auth/jwt/middleware.go
Expand Up @@ -12,9 +12,13 @@ import (
type contextKey string

const (
// JWTTokenContextKey holds the key used to store a JWT Token in the
// context.
JWTTokenContextKey contextKey = "JWTToken"
// JWTContextKey holds the key used to store a JWT in the context.
JWTContextKey contextKey = "JWTToken"

// JWTTokenContextKey is an alias for JWTContextKey.
//
// Deprecated: prefer JWTContextKey.
JWTTokenContextKey = JWTContextKey

// JWTClaimsContextKey holds the key used to store the JWT Claims in the
// context.
Expand All @@ -27,13 +31,13 @@ var (
ErrTokenContextMissing = errors.New("token up for parsing was not passed through the context")

// ErrTokenInvalid denotes a token was not able to be validated.
ErrTokenInvalid = errors.New("JWT Token was invalid")
ErrTokenInvalid = errors.New("JWT was invalid")

// ErrTokenExpired denotes a token's expire header (exp) has since passed.
ErrTokenExpired = errors.New("JWT Token is expired")
ErrTokenExpired = errors.New("JWT is expired")

// ErrTokenMalformed denotes a token was not formatted as a JWT token.
ErrTokenMalformed = errors.New("JWT Token is malformed")
// ErrTokenMalformed denotes a token was not formatted as a JWT.
ErrTokenMalformed = errors.New("JWT is malformed")

// ErrTokenNotActive denotes a token's not before header (nbf) is in the
// future.
Expand All @@ -44,7 +48,7 @@ var (
ErrUnexpectedSigningMethod = errors.New("unexpected signing method")
)

// NewSigner creates a new JWT token generating middleware, specifying key ID,
// NewSigner creates a new JWT generating middleware, specifying key ID,
// signing string, signing method and the claims you would like it to contain.
// Tokens are signed with a Key ID header (kid) which is useful for determining
// the key to use for parsing. Particularly useful for clients.
Expand All @@ -59,7 +63,7 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, JWTTokenContextKey, tokenString)
ctx = context.WithValue(ctx, JWTContextKey, tokenString)

return next(ctx, request)
}
Expand All @@ -82,15 +86,15 @@ func StandardClaimsFactory() jwt.Claims {
return &jwt.StandardClaims{}
}

// NewParser creates a new JWT token parsing middleware, specifying a
// NewParser creates a new JWT parsing middleware, specifying a
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
tokenString, ok := ctx.Value(JWTTokenContextKey).(string)
tokenString, ok := ctx.Value(JWTContextKey).(string)
if !ok {
return nil, ErrTokenContextMissing
}
Expand Down
24 changes: 12 additions & 12 deletions auth/jwt/middleware_test.go
Expand Up @@ -44,13 +44,13 @@ func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string
t.Fatalf("Signer returned error: %s", err)
}

token, ok := ctx.(context.Context).Value(JWTTokenContextKey).(string)
token, ok := ctx.(context.Context).Value(JWTContextKey).(string)
if !ok {
t.Fatal("Token did not exist in context")
}

if token != expectedKey {
t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token)
t.Fatalf("JWTs did not match: expecting %s got %s", expectedKey, token)
}
}

Expand Down Expand Up @@ -87,15 +87,15 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Token is passed into the parser
ctx := context.WithValue(context.Background(), JWTTokenContextKey, invalidKey)
ctx := context.WithValue(context.Background(), JWTContextKey, invalidKey)
_, err = parser(ctx, struct{}{})
if err == nil {
t.Error("Parser should have returned an error")
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
t.Error("Parser should have returned an error")
Expand All @@ -111,14 +111,14 @@ func TestJWTParser(t *testing.T) {
}

badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
t.Error("Parser should have returned an error")
}

// Correct token is passed into the parser
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
ctx1, err := parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
Expand All @@ -135,7 +135,7 @@ func TestJWTParser(t *testing.T) {

// Test for malformed token error response
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, malformedKey)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenMalformed, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
Expand All @@ -148,7 +148,7 @@ func TestJWTParser(t *testing.T) {
if err != nil {
t.Fatalf("Unable to Sign Token: %+v", err)
}
ctx = context.WithValue(context.Background(), JWTTokenContextKey, token)
ctx = context.WithValue(context.Background(), JWTContextKey, token)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenExpired, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
Expand All @@ -161,15 +161,15 @@ func TestJWTParser(t *testing.T) {
if err != nil {
t.Fatalf("Unable to Sign Token: %+v", err)
}
ctx = context.WithValue(context.Background(), JWTTokenContextKey, token)
ctx = context.WithValue(context.Background(), JWTContextKey, token)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenNotActive, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
}

// test valid standard claims token
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
Expand All @@ -184,7 +184,7 @@ func TestJWTParser(t *testing.T) {

// test valid customized claims token
parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
Expand All @@ -205,7 +205,7 @@ func TestIssue562(t *testing.T) {
var (
kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
key = JWTTokenContextKey
key = JWTContextKey
val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
ctx = context.WithValue(context.Background(), key, val)
)
Expand Down
8 changes: 4 additions & 4 deletions auth/jwt/transport.go
Expand Up @@ -26,15 +26,15 @@ func HTTPToContext() http.RequestFunc {
return ctx
}

return context.WithValue(ctx, JWTTokenContextKey, token)
return context.WithValue(ctx, JWTContextKey, token)
}
}

// ContextToHTTP moves a JWT from context to request header. Particularly
// useful for clients.
func ContextToHTTP() http.RequestFunc {
return func(ctx context.Context, r *stdhttp.Request) context.Context {
token, ok := ctx.Value(JWTTokenContextKey).(string)
token, ok := ctx.Value(JWTContextKey).(string)
if ok {
r.Header.Add("Authorization", generateAuthHeaderFromToken(token))
}
Expand All @@ -54,7 +54,7 @@ func GRPCToContext() grpc.ServerRequestFunc {

token, ok := extractTokenFromAuthHeader(authHeader[0])
if ok {
ctx = context.WithValue(ctx, JWTTokenContextKey, token)
ctx = context.WithValue(ctx, JWTContextKey, token)
}

return ctx
Expand All @@ -65,7 +65,7 @@ func GRPCToContext() grpc.ServerRequestFunc {
// useful for clients.
func ContextToGRPC() grpc.ClientRequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
token, ok := ctx.Value(JWTTokenContextKey).(string)
token, ok := ctx.Value(JWTContextKey).(string)
if ok {
// capital "Key" is illegal in HTTP/2.
(*md)["authorization"] = []string{generateAuthHeaderFromToken(token)}
Expand Down
38 changes: 19 additions & 19 deletions auth/jwt/transport_test.go
Expand Up @@ -15,7 +15,7 @@ func TestHTTPToContext(t *testing.T) {
// When the header doesn't exist
ctx := reqFunc(context.Background(), &http.Request{})

if ctx.Value(JWTTokenContextKey) != nil {
if ctx.Value(JWTContextKey) != nil {
t.Error("Context shouldn't contain the encoded JWT")
}

Expand All @@ -24,15 +24,15 @@ func TestHTTPToContext(t *testing.T) {
header.Set("Authorization", "no expected auth header format value")
ctx = reqFunc(context.Background(), &http.Request{Header: header})

if ctx.Value(JWTTokenContextKey) != nil {
if ctx.Value(JWTContextKey) != nil {
t.Error("Context shouldn't contain the encoded JWT")
}

// Authorization header is correct
header.Set("Authorization", generateAuthHeaderFromToken(signedKey))
ctx = reqFunc(context.Background(), &http.Request{Header: header})

token := ctx.Value(JWTTokenContextKey).(string)
token := ctx.Value(JWTContextKey).(string)
if token != signedKey {
t.Errorf("Context doesn't contain the expected encoded token value; expected: %s, got: %s", signedKey, token)
}
Expand All @@ -41,7 +41,7 @@ func TestHTTPToContext(t *testing.T) {
func TestContextToHTTP(t *testing.T) {
reqFunc := ContextToHTTP()

// No JWT Token is passed in the context
// No JWT is passed in the context
ctx := context.Background()
r := http.Request{}
reqFunc(ctx, &r)
Expand All @@ -51,16 +51,16 @@ func TestContextToHTTP(t *testing.T) {
t.Error("authorization key should not exist in metadata")
}

// Correct JWT Token is passed in the context
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
// Correct JWT is passed in the context
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
r = http.Request{Header: http.Header{}}
reqFunc(ctx, &r)

token = r.Header.Get("Authorization")
expected := generateAuthHeaderFromToken(signedKey)

if token != expected {
t.Errorf("Authorization header does not contain the expected JWT token; expected %s, got %s", expected, token)
t.Errorf("Authorization header does not contain the expected JWT; expected %s, got %s", expected, token)
}
}

Expand All @@ -70,36 +70,36 @@ func TestGRPCToContext(t *testing.T) {

// No Authorization header is passed
ctx := reqFunc(context.Background(), md)
token := ctx.Value(JWTTokenContextKey)
token := ctx.Value(JWTContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
t.Error("Context should not contain a JWT")
}

// Invalid Authorization header is passed
md["authorization"] = []string{fmt.Sprintf("%s", signedKey)}
ctx = reqFunc(context.Background(), md)
token = ctx.Value(JWTTokenContextKey)
token = ctx.Value(JWTContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
t.Error("Context should not contain a JWT")
}

// Authorization header is correct
md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)}
ctx = reqFunc(context.Background(), md)
token, ok := ctx.Value(JWTTokenContextKey).(string)
token, ok := ctx.Value(JWTContextKey).(string)
if !ok {
t.Fatal("JWT Token not passed to context correctly")
t.Fatal("JWT not passed to context correctly")
}

if token != signedKey {
t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token)
t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token)
}
}

func TestContextToGRPC(t *testing.T) {
reqFunc := ContextToGRPC()

// No JWT Token is passed in the context
// No JWT is passed in the context
ctx := context.Background()
md := metadata.MD{}
reqFunc(ctx, &md)
Expand All @@ -109,17 +109,17 @@ func TestContextToGRPC(t *testing.T) {
t.Error("authorization key should not exist in metadata")
}

// Correct JWT Token is passed in the context
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
// Correct JWT is passed in the context
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
md = metadata.MD{}
reqFunc(ctx, &md)

token, ok := md["authorization"]
if !ok {
t.Fatal("JWT Token not passed to metadata correctly")
t.Fatal("JWT not passed to metadata correctly")
}

if token[0] != generateAuthHeaderFromToken(signedKey) {
t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token[0])
t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token[0])
}
}

0 comments on commit a5c7103

Please sign in to comment.