Skip to content

Commit

Permalink
node: fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
holiman committed Feb 8, 2022
1 parent fa0f9b0 commit 500ec28
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 38 deletions.
50 changes: 29 additions & 21 deletions node/jwt_handler.go
Expand Up @@ -19,35 +19,19 @@ package node
import (
"net/http"

"errors"
"fmt"
"github.com/ethereum/go-ethereum/log"
"github.com/golang-jwt/jwt/v4"
"strings"
"time"
)

// customClaim implements claims.Claim.
type customClaim struct {
// the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6
IssuedAt int64 `json:"iat,omitempty"`
}

// Valid implements claims.Claim, and checks that the iat is present and valid.
func (c customClaim) Valid() error {
if time.Now().Unix()-5 < c.IssuedAt {
return errors.New("token issuance (iat) is too old")
}
if time.Now().Unix()+5 > c.IssuedAt {
return errors.New("token issuance (iat) is too far in the future")
}
return nil
}

type jwtHandler struct {
keyFunc func(token *jwt.Token) (interface{}, error)
next http.Handler
}

// MakeJWTValidator creates a validator for jwt tokens.
// newJWTHandler creates a http.Handler with jwt authentication support.
func newJWTHandler(secret []byte, next http.Handler) http.Handler {
return &jwtHandler{
keyFunc: func(token *jwt.Token) (interface{}, error) {
Expand All @@ -57,6 +41,30 @@ func newJWTHandler(secret []byte, next http.Handler) http.Handler {
}
}

// customClaim is basically a standard RegisteredClaim, but we override the
// Valid method to be more lax in allowing some time skew.
type customClaim jwt.RegisteredClaims

// Valid implements jwt.Claim. This method only validates the (optional) expiry-time.
func (c customClaim) Valid() error {
now := jwt.TimeFunc()
rc := jwt.RegisteredClaims(c)
if !rc.VerifyExpiresAt(now, false) { // optional
return fmt.Errorf("token is expired")
}
if c.IssuedAt == nil {
return fmt.Errorf("missing issued-at")
}
if time.Since(c.IssuedAt.Time) > 5*time.Second {
return fmt.Errorf("stale token")
}
if time.Until(c.IssuedAt.Time) > 5*time.Second {
return fmt.Errorf("future token")
}
return nil
}

// ServeHTTP implements http.Handler
func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) {
var token string
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
Expand All @@ -67,13 +75,13 @@ func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) {
return
}
var claims customClaim
t, err := jwt.ParseWithClaims(token, claims, handler.keyFunc, jwt.WithValidMethods([]string{"HS256"}))
t, err := jwt.ParseWithClaims(token, &claims, handler.keyFunc, jwt.WithValidMethods([]string{"HS256"}))
if err != nil {
log.Info("Token parsing failed", "err", err)
http.Error(out, err.Error(), http.StatusForbidden)
return
}
if !t.Valid {
// This should not happen, but better safe than sorry if the implementation changes.
http.Error(out, "invalid token", http.StatusForbidden)
return
}
Expand Down
1 change: 1 addition & 0 deletions node/rpcstack.go
Expand Up @@ -373,6 +373,7 @@ func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSe
return newGzipHandler(handler)
}

// NewWSHandlerStack returns a wrapped ws-related handler.
func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler {
if len(jwtSecret) != 0 {
return newJWTHandler(jwtSecret, srv)
Expand Down
74 changes: 57 additions & 17 deletions node/rpcstack_test.go
Expand Up @@ -304,39 +304,79 @@ type tokenTest struct {
expFail []string
}

func TestJWT(t *testing.T) {
type testClaim map[string]interface{}

func (testClaim) Valid() error {
return nil
}

makeToken := func() string {
mySigningKey := []byte("secret")
// Create the Claims
claims := &jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
func TestJWT(t *testing.T) {
var secret = []byte("secret")
issueToken := func(secret []byte, method jwt.SigningMethod, input map[string]interface{}) string {
if method == nil {
method = jwt.SigningMethodHS256
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, _ := token.SignedString(mySigningKey)
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
return ss
}
tests := []originTest{
tests := []tokenTest{
{
//expFail: []string{"Bearer ", "Bearer: abc", "Baxonk hello there"},
expOk: []string{
fmt.Sprintf("Bearer %v", makeToken()),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"exp": time.Now().Unix() + 2,
})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"bar": "baz",
})),
},
expFail: []string{
// future
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 6})),
// stale
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 6})),
// wrong algo
fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})),
// expired
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})),
// missing mandatory iat
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})),
// wrong secret
fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})),
// Various malformed syntax
fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer:\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
},
},
}

for _, tc := range tests {
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")})
url := fmt.Sprintf("ws://%v", srv.listenAddr())
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
for i, token := range tc.expOk {
if err := wsRequest(t, url, "Authorization", token); err != nil {
t.Errorf("test %d, token '%v': expected ok, got %v", i, token, err)
if err := wsRequest(t, wsUrl, "Authorization", token); err != nil {
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
}
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 {
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
}
}
for i, token := range tc.expFail {
if err := wsRequest(t, url, "Authorization", token); err == nil {
t.Errorf("tc %d, token '%v': expected not to allow, got ok", i, token)
if err := wsRequest(t, wsUrl, "Authorization", token); err == nil {
t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token)
}
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 403 {
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
}
}
srv.stop()
Expand Down

0 comments on commit 500ec28

Please sign in to comment.