diff --git a/go.mod b/go.mod index 4fcbf1002ec4e..a3d59228e92ab 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/gballet/go-libpcsclite v0.0.0-20190607065134-2772fd86a8ff github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-stack/stack v1.8.0 + github.com/golang-jwt/jwt/v4 v4.2.0 // indirect github.com/golang/protobuf v1.4.3 github.com/golang/snappy v0.0.4 github.com/google/gofuzz v1.1.1-0.20200604201612-c04b05f3adfa diff --git a/go.sum b/go.sum index 87f9d5d50a84c..c7e972b17632f 100644 --- a/go.sum +++ b/go.sum @@ -167,6 +167,8 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= +github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/graphql/service.go b/graphql/service.go index bcb0a4990d646..29d98ad746836 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -74,7 +74,7 @@ func newHandler(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) return err } h := handler{Schema: s} - handler := node.NewHTTPHandlerStack(h, cors, vhosts) + handler := node.NewHTTPHandlerStack(h, cors, vhosts, nil) stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{}) stack.RegisterHandler("GraphQL", "/graphql", handler) diff --git a/node/jwt_handler.go b/node/jwt_handler.go new file mode 100644 index 0000000000000..28d5b87c60bcc --- /dev/null +++ b/node/jwt_handler.go @@ -0,0 +1,78 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +type jwtHandler struct { + keyFunc func(token *jwt.Token) (interface{}, error) + next http.Handler +} + +// newJWTHandler creates a http.Handler with jwt authentication support. +func newJWTHandler(secret []byte, next http.Handler) http.Handler { + return &jwtHandler{ + keyFunc: func(token *jwt.Token) (interface{}, error) { + return secret, nil + }, + next: next, + } +} + +// ServeHTTP implements http.Handler +func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) { + var ( + strToken string + claims jwt.RegisteredClaims + ) + if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { + strToken = strings.TrimPrefix(auth, "Bearer ") + } + if len(strToken) == 0 { + http.Error(out, "missing token", http.StatusForbidden) + return + } + // We explicitly set only HS256 allowed, and also disables the + // claim-check: the RegisteredClaims internally requires 'iat' to + // be no later than 'now', but we allow for a bit of drift. + token, err := jwt.ParseWithClaims(strToken, &claims, handler.keyFunc, + jwt.WithValidMethods([]string{"HS256"}), + jwt.WithoutClaimsValidation()) + + switch { + case err != nil: + http.Error(out, err.Error(), http.StatusForbidden) + case !token.Valid: + http.Error(out, "invalid token", http.StatusForbidden) + case !claims.VerifyExpiresAt(time.Now(), false): // optional + http.Error(out, "token is expired", http.StatusForbidden) + case claims.IssuedAt == nil: + http.Error(out, "missing issued-at", http.StatusForbidden) + case time.Since(claims.IssuedAt.Time) > 5*time.Second: + http.Error(out, "stale token", http.StatusForbidden) + case time.Until(claims.IssuedAt.Time) > 5*time.Second: + http.Error(out, "future token", http.StatusForbidden) + default: + handler.next.ServeHTTP(out, r) + } +} diff --git a/node/node.go b/node/node.go index ceab1c90902b4..428c7394dcf6f 100644 --- a/node/node.go +++ b/node/node.go @@ -349,7 +349,7 @@ func (n *Node) startRPC() error { return err } } - + jwtSecret := []byte("secret") // Configure HTTP. if n.config.HTTPHost != "" { config := httpConfig{ @@ -357,6 +357,7 @@ func (n *Node) startRPC() error { Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, prefix: n.config.HTTPPathPrefix, + jwtSecret: jwtSecret, } if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil { return err @@ -370,9 +371,10 @@ func (n *Node) startRPC() error { if n.config.WSHost != "" { server := n.wsServerForPort(n.config.WSPort) config := wsConfig{ - Modules: n.config.WSModules, - Origins: n.config.WSOrigins, - prefix: n.config.WSPathPrefix, + Modules: n.config.WSModules, + Origins: n.config.WSOrigins, + prefix: n.config.WSPathPrefix, + jwtSecret: jwtSecret, } if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil { return err diff --git a/node/node_test.go b/node/node_test.go index 25cfa9d38d78f..84f61f0c44c4e 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -577,13 +577,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { } } for _, path := range test.wantWS { - err := wsRequest(t, wsBase+path, "") + err := wsRequest(t, wsBase+path) if err != nil { t.Errorf("Error: %s: WebSocket connection failed: %v", path, err) } } for _, path := range test.wantNoWS { - err := wsRequest(t, wsBase+path, "") + err := wsRequest(t, wsBase+path) if err == nil { t.Errorf("Error: %s: WebSocket connection succeeded for path in wantNoWS", path) } diff --git a/node/rpcstack.go b/node/rpcstack.go index 2c55a070b2291..fd07814ab58ae 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -40,13 +40,15 @@ type httpConfig struct { CorsAllowedOrigins []string Vhosts []string prefix string // path prefix on which to mount http handler + jwtSecret []byte // optional JWT secret } // wsConfig is the JSON-RPC/Websocket configuration type wsConfig struct { - Origins []string - Modules []string - prefix string // path prefix on which to mount ws handler + Origins []string + Modules []string + prefix string // path prefix on which to mount ws handler + jwtSecret []byte // optional JWT secret } type rpcHandler struct { @@ -285,7 +287,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { } h.httpConfig = config h.httpHandler.Store(&rpcHandler{ - Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts), + Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts, config.jwtSecret), server: srv, }) return nil @@ -309,7 +311,6 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { if h.wsAllowed() { return fmt.Errorf("JSON-RPC over WebSocket is already enabled") } - // Create RPC server and handler. srv := rpc.NewServer() if err := RegisterApis(apis, config.Modules, srv, false); err != nil { @@ -317,7 +318,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } h.wsConfig = config h.wsHandler.Store(&rpcHandler{ - Handler: srv.WebsocketHandler(config.Origins), + Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret), server: srv, }) return nil @@ -362,13 +363,24 @@ func isWebsocket(r *http.Request) bool { } // NewHTTPHandlerStack returns wrapped http-related handlers -func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler { +func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSecret []byte) http.Handler { // Wrap the CORS-handler within a host-handler handler := newCorsHandler(srv, cors) handler = newVHostHandler(vhosts, handler) + if len(jwtSecret) != 0 { + handler = newJWTHandler(jwtSecret, handler) + } return newGzipHandler(handler) } +// NewWSHandlerStack returns a wrapped ws-related handler. +func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler { + if len(jwtSecret) != 0 { + return newJWTHandler(jwtSecret, srv) + } + return srv +} + func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler { // disable CORS support if user has not specified a custom CORS configuration if len(allowedOrigins) == 0 { diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index f92f0ba39693b..60fcab5a90019 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -24,10 +24,12 @@ import ( "strconv" "strings" "testing" + "time" "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rpc" + "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) @@ -146,12 +148,12 @@ func TestWebsocketOrigins(t *testing.T) { srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) url := fmt.Sprintf("ws://%v", srv.listenAddr()) for _, origin := range tc.expOk { - if err := wsRequest(t, url, origin); err != nil { + if err := wsRequest(t, url, "Origin", origin); err != nil { t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err) } } for _, origin := range tc.expFail { - if err := wsRequest(t, url, origin); err == nil { + if err := wsRequest(t, url, "Origin", origin); err == nil { t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin) } } @@ -243,13 +245,18 @@ func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsCon } // wsRequest attempts to open a WebSocket connection to the given URL. -func wsRequest(t *testing.T, url, browserOrigin string) error { +func wsRequest(t *testing.T, url string, extraHeaders ...string) error { t.Helper() - t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin) + //t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin) headers := make(http.Header) - if browserOrigin != "" { - headers.Set("Origin", browserOrigin) + // Apply extra headers. + if len(extraHeaders)%2 != 0 { + panic("odd extraHeaders length") + } + for i := 0; i < len(extraHeaders); i += 2 { + key, value := extraHeaders[i], extraHeaders[i+1] + headers.Set(key, value) } conn, _, err := websocket.DefaultDialer.Dial(url, headers) if conn != nil { @@ -291,3 +298,79 @@ func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response } return resp } + +type testClaim map[string]interface{} + +func (testClaim) Valid() error { + return nil +} + +func TestJWT(t *testing.T) { + var secret = []byte("secret") + issueToken := func(secret []byte, method jwt.SigningMethod, input map[string]interface{}) string { + if method == nil { + method = jwt.SigningMethodHS256 + } + ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) + return ss + } + expOk := []string{ + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "exp": time.Now().Unix() + 2, + })), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "bar": "baz", + })), + } + expFail := []string{ + // future + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 6})), + // stale + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 6})), + // wrong algo + fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})), + // expired + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})), + // missing mandatory iat + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})), + // wrong secret + fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})), + // Various malformed syntax + fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer \t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + } + srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, + true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) + wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) + htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) + + for i, token := range expOk { + if err := wsRequest(t, wsUrl, "Authorization", token); err != nil { + t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { + t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) + } + } + for i, token := range expFail { + if err := wsRequest(t, wsUrl, "Authorization", token); err == nil { + t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 403 { + t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) + } + } + srv.stop() +} diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 8659f798e4a06..f74b7fd08bb4f 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -76,7 +76,7 @@ func TestWebsocketOriginCheck(t *testing.T) { // Connections without origin header should work. client, err = DialWebsocket(context.Background(), wsURL, "") if err != nil { - t.Fatal("error for empty origin") + t.Fatalf("error for empty origin: %v", err) } client.Close() }