Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typo in jwt package #1070

Merged
merged 5 commits into from Jun 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions auth/jwt/README.md
Expand Up @@ -7,7 +7,7 @@ through [JSON Web Tokens](https://jwt.io/).

NewParser takes a key function and an expected signing method and returns an
`endpoint.Middleware`. The middleware will parse a token passed into the
context via the `jwt.JWTTokenContextKey`. If the token is valid, any claims
context via the `jwt.JWTContextKey`. If the token is valid, any claims
will be added to the context via the `jwt.JWTClaimsContextKey`.

```go
Expand All @@ -30,7 +30,7 @@ func main() {

NewSigner takes a JWT key ID header, the signing key, signing method, and a
claims object. It returns an `endpoint.Middleware`. The middleware will build
the token string and add it to the context via the `jwt.JWTTokenContextKey`.
the token string and add it to the context via the `jwt.JWTContextKey`.

```go
import (
Expand Down
26 changes: 15 additions & 11 deletions auth/jwt/middleware.go
Expand Up @@ -12,9 +12,13 @@ import (
type contextKey string

const (
// JWTTokenContextKey holds the key used to store a JWT Token in the
// context.
JWTTokenContextKey contextKey = "JWTToken"
// JWTContextKey holds the key used to store a JWT in the context.
JWTContextKey contextKey = "JWTToken"

// JWTTokenContextKey is an alias for JWTContextKey.
//
// Deprecated: prefer JWTContextKey.
JWTTokenContextKey = JWTContextKey

// JWTClaimsContextKey holds the key used to store the JWT Claims in the
// context.
Expand All @@ -27,13 +31,13 @@ var (
ErrTokenContextMissing = errors.New("token up for parsing was not passed through the context")

// ErrTokenInvalid denotes a token was not able to be validated.
ErrTokenInvalid = errors.New("JWT Token was invalid")
ErrTokenInvalid = errors.New("JWT was invalid")

// ErrTokenExpired denotes a token's expire header (exp) has since passed.
ErrTokenExpired = errors.New("JWT Token is expired")
ErrTokenExpired = errors.New("JWT is expired")

// ErrTokenMalformed denotes a token was not formatted as a JWT token.
ErrTokenMalformed = errors.New("JWT Token is malformed")
// ErrTokenMalformed denotes a token was not formatted as a JWT.
ErrTokenMalformed = errors.New("JWT is malformed")

// ErrTokenNotActive denotes a token's not before header (nbf) is in the
// future.
Expand All @@ -44,7 +48,7 @@ var (
ErrUnexpectedSigningMethod = errors.New("unexpected signing method")
)

// NewSigner creates a new JWT token generating middleware, specifying key ID,
// NewSigner creates a new JWT generating middleware, specifying key ID,
// signing string, signing method and the claims you would like it to contain.
// Tokens are signed with a Key ID header (kid) which is useful for determining
// the key to use for parsing. Particularly useful for clients.
Expand All @@ -59,7 +63,7 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, JWTTokenContextKey, tokenString)
ctx = context.WithValue(ctx, JWTContextKey, tokenString)

return next(ctx, request)
}
Expand All @@ -82,15 +86,15 @@ func StandardClaimsFactory() jwt.Claims {
return &jwt.StandardClaims{}
}

// NewParser creates a new JWT token parsing middleware, specifying a
// NewParser creates a new JWT parsing middleware, specifying a
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
tokenString, ok := ctx.Value(JWTTokenContextKey).(string)
tokenString, ok := ctx.Value(JWTContextKey).(string)
if !ok {
return nil, ErrTokenContextMissing
}
Expand Down
24 changes: 12 additions & 12 deletions auth/jwt/middleware_test.go
Expand Up @@ -44,13 +44,13 @@ func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string
t.Fatalf("Signer returned error: %s", err)
}

token, ok := ctx.(context.Context).Value(JWTTokenContextKey).(string)
token, ok := ctx.(context.Context).Value(JWTContextKey).(string)
if !ok {
t.Fatal("Token did not exist in context")
}

if token != expectedKey {
t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token)
t.Fatalf("JWTs did not match: expecting %s got %s", expectedKey, token)
}
}

Expand Down Expand Up @@ -87,15 +87,15 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Token is passed into the parser
ctx := context.WithValue(context.Background(), JWTTokenContextKey, invalidKey)
ctx := context.WithValue(context.Background(), JWTContextKey, invalidKey)
_, err = parser(ctx, struct{}{})
if err == nil {
t.Error("Parser should have returned an error")
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
t.Error("Parser should have returned an error")
Expand All @@ -111,14 +111,14 @@ func TestJWTParser(t *testing.T) {
}

badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
t.Error("Parser should have returned an error")
}

// Correct token is passed into the parser
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
ctx1, err := parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
Expand All @@ -135,7 +135,7 @@ func TestJWTParser(t *testing.T) {

// Test for malformed token error response
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, malformedKey)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenMalformed, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
Expand All @@ -148,7 +148,7 @@ func TestJWTParser(t *testing.T) {
if err != nil {
t.Fatalf("Unable to Sign Token: %+v", err)
}
ctx = context.WithValue(context.Background(), JWTTokenContextKey, token)
ctx = context.WithValue(context.Background(), JWTContextKey, token)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenExpired, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
Expand All @@ -161,15 +161,15 @@ func TestJWTParser(t *testing.T) {
if err != nil {
t.Fatalf("Unable to Sign Token: %+v", err)
}
ctx = context.WithValue(context.Background(), JWTTokenContextKey, token)
ctx = context.WithValue(context.Background(), JWTContextKey, token)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenNotActive, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
}

// test valid standard claims token
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
Expand All @@ -184,7 +184,7 @@ func TestJWTParser(t *testing.T) {

// test valid customized claims token
parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx = context.WithValue(context.Background(), JWTContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
Expand All @@ -205,7 +205,7 @@ func TestIssue562(t *testing.T) {
var (
kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
key = JWTTokenContextKey
key = JWTContextKey
val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
ctx = context.WithValue(context.Background(), key, val)
)
Expand Down
8 changes: 4 additions & 4 deletions auth/jwt/transport.go
Expand Up @@ -26,15 +26,15 @@ func HTTPToContext() http.RequestFunc {
return ctx
}

return context.WithValue(ctx, JWTTokenContextKey, token)
return context.WithValue(ctx, JWTContextKey, token)
}
}

// ContextToHTTP moves a JWT from context to request header. Particularly
// useful for clients.
func ContextToHTTP() http.RequestFunc {
return func(ctx context.Context, r *stdhttp.Request) context.Context {
token, ok := ctx.Value(JWTTokenContextKey).(string)
token, ok := ctx.Value(JWTContextKey).(string)
if ok {
r.Header.Add("Authorization", generateAuthHeaderFromToken(token))
}
Expand All @@ -54,7 +54,7 @@ func GRPCToContext() grpc.ServerRequestFunc {

token, ok := extractTokenFromAuthHeader(authHeader[0])
if ok {
ctx = context.WithValue(ctx, JWTTokenContextKey, token)
ctx = context.WithValue(ctx, JWTContextKey, token)
}

return ctx
Expand All @@ -65,7 +65,7 @@ func GRPCToContext() grpc.ServerRequestFunc {
// useful for clients.
func ContextToGRPC() grpc.ClientRequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
token, ok := ctx.Value(JWTTokenContextKey).(string)
token, ok := ctx.Value(JWTContextKey).(string)
if ok {
// capital "Key" is illegal in HTTP/2.
(*md)["authorization"] = []string{generateAuthHeaderFromToken(token)}
Expand Down
38 changes: 19 additions & 19 deletions auth/jwt/transport_test.go
Expand Up @@ -15,7 +15,7 @@ func TestHTTPToContext(t *testing.T) {
// When the header doesn't exist
ctx := reqFunc(context.Background(), &http.Request{})

if ctx.Value(JWTTokenContextKey) != nil {
if ctx.Value(JWTContextKey) != nil {
t.Error("Context shouldn't contain the encoded JWT")
}

Expand All @@ -24,15 +24,15 @@ func TestHTTPToContext(t *testing.T) {
header.Set("Authorization", "no expected auth header format value")
ctx = reqFunc(context.Background(), &http.Request{Header: header})

if ctx.Value(JWTTokenContextKey) != nil {
if ctx.Value(JWTContextKey) != nil {
t.Error("Context shouldn't contain the encoded JWT")
}

// Authorization header is correct
header.Set("Authorization", generateAuthHeaderFromToken(signedKey))
ctx = reqFunc(context.Background(), &http.Request{Header: header})

token := ctx.Value(JWTTokenContextKey).(string)
token := ctx.Value(JWTContextKey).(string)
if token != signedKey {
t.Errorf("Context doesn't contain the expected encoded token value; expected: %s, got: %s", signedKey, token)
}
Expand All @@ -41,7 +41,7 @@ func TestHTTPToContext(t *testing.T) {
func TestContextToHTTP(t *testing.T) {
reqFunc := ContextToHTTP()

// No JWT Token is passed in the context
// No JWT is passed in the context
ctx := context.Background()
r := http.Request{}
reqFunc(ctx, &r)
Expand All @@ -51,16 +51,16 @@ func TestContextToHTTP(t *testing.T) {
t.Error("authorization key should not exist in metadata")
}

// Correct JWT Token is passed in the context
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
// Correct JWT is passed in the context
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
r = http.Request{Header: http.Header{}}
reqFunc(ctx, &r)

token = r.Header.Get("Authorization")
expected := generateAuthHeaderFromToken(signedKey)

if token != expected {
t.Errorf("Authorization header does not contain the expected JWT token; expected %s, got %s", expected, token)
t.Errorf("Authorization header does not contain the expected JWT; expected %s, got %s", expected, token)
}
}

Expand All @@ -70,36 +70,36 @@ func TestGRPCToContext(t *testing.T) {

// No Authorization header is passed
ctx := reqFunc(context.Background(), md)
token := ctx.Value(JWTTokenContextKey)
token := ctx.Value(JWTContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
t.Error("Context should not contain a JWT")
}

// Invalid Authorization header is passed
md["authorization"] = []string{fmt.Sprintf("%s", signedKey)}
ctx = reqFunc(context.Background(), md)
token = ctx.Value(JWTTokenContextKey)
token = ctx.Value(JWTContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
t.Error("Context should not contain a JWT")
}

// Authorization header is correct
md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)}
ctx = reqFunc(context.Background(), md)
token, ok := ctx.Value(JWTTokenContextKey).(string)
token, ok := ctx.Value(JWTContextKey).(string)
if !ok {
t.Fatal("JWT Token not passed to context correctly")
t.Fatal("JWT not passed to context correctly")
}

if token != signedKey {
t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token)
t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token)
}
}

func TestContextToGRPC(t *testing.T) {
reqFunc := ContextToGRPC()

// No JWT Token is passed in the context
// No JWT is passed in the context
ctx := context.Background()
md := metadata.MD{}
reqFunc(ctx, &md)
Expand All @@ -109,17 +109,17 @@ func TestContextToGRPC(t *testing.T) {
t.Error("authorization key should not exist in metadata")
}

// Correct JWT Token is passed in the context
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
// Correct JWT is passed in the context
ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
md = metadata.MD{}
reqFunc(ctx, &md)

token, ok := md["authorization"]
if !ok {
t.Fatal("JWT Token not passed to metadata correctly")
t.Fatal("JWT not passed to metadata correctly")
}

if token[0] != generateAuthHeaderFromToken(signedKey) {
t.Errorf("JWT tokens did not match: expecting %s got %s", signedKey, token[0])
t.Errorf("JWTs did not match: expecting %s got %s", signedKey, token[0])
}
}