Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc, node: refactor request validation and add jwt validation #24358

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -27,6 +27,7 @@ require (
github.com/gballet/go-libpcsclite v0.0.0-20190607065134-2772fd86a8ff
github.com/go-ole/go-ole v1.2.1 // indirect
github.com/go-stack/stack v1.8.0
github.com/golang-jwt/jwt/v4 v4.2.0 // indirect
github.com/golang/protobuf v1.4.3
github.com/golang/snappy v0.0.4
github.com/google/gofuzz v1.1.1-0.20200604201612-c04b05f3adfa
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -167,6 +167,8 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU=
github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
Expand Down
2 changes: 1 addition & 1 deletion graphql/service.go
Expand Up @@ -74,7 +74,7 @@ func newHandler(stack *node.Node, backend ethapi.Backend, cors, vhosts []string)
return err
}
h := handler{Schema: s}
handler := node.NewHTTPHandlerStack(h, cors, vhosts)
handler := node.NewHTTPHandlerStack(h, cors, vhosts, nil)

stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{})
stack.RegisterHandler("GraphQL", "/graphql", handler)
Expand Down
78 changes: 78 additions & 0 deletions node/jwt_handler.go
@@ -0,0 +1,78 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package node

import (
"net/http"
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
)

type jwtHandler struct {
keyFunc func(token *jwt.Token) (interface{}, error)
next http.Handler
}

// newJWTHandler creates a http.Handler with jwt authentication support.
func newJWTHandler(secret []byte, next http.Handler) http.Handler {
return &jwtHandler{
keyFunc: func(token *jwt.Token) (interface{}, error) {
return secret, nil
},
next: next,
}
}

// ServeHTTP implements http.Handler
func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) {
var (
strToken string
claims jwt.RegisteredClaims
)
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
strToken = strings.TrimPrefix(auth, "Bearer ")
}
if len(strToken) == 0 {
http.Error(out, "missing token", http.StatusForbidden)
return
}
// We explicitly set only HS256 allowed, and also disables the
// claim-check: the RegisteredClaims internally requires 'iat' to
// be no later than 'now', but we allow for a bit of drift.
token, err := jwt.ParseWithClaims(strToken, &claims, handler.keyFunc,
jwt.WithValidMethods([]string{"HS256"}),
jwt.WithoutClaimsValidation())

switch {
case err != nil:
http.Error(out, err.Error(), http.StatusForbidden)
case !token.Valid:
http.Error(out, "invalid token", http.StatusForbidden)
case !claims.VerifyExpiresAt(time.Now(), false): // optional
http.Error(out, "token is expired", http.StatusForbidden)
case claims.IssuedAt == nil:
http.Error(out, "missing issued-at", http.StatusForbidden)
case time.Since(claims.IssuedAt.Time) > 5*time.Second:
http.Error(out, "stale token", http.StatusForbidden)
case time.Until(claims.IssuedAt.Time) > 5*time.Second:
http.Error(out, "future token", http.StatusForbidden)
default:
handler.next.ServeHTTP(out, r)
}
}
10 changes: 6 additions & 4 deletions node/node.go
Expand Up @@ -349,14 +349,15 @@ func (n *Node) startRPC() error {
return err
}
}

jwtSecret := []byte("secret")
// Configure HTTP.
if n.config.HTTPHost != "" {
config := httpConfig{
CorsAllowedOrigins: n.config.HTTPCors,
Vhosts: n.config.HTTPVirtualHosts,
Modules: n.config.HTTPModules,
prefix: n.config.HTTPPathPrefix,
jwtSecret: jwtSecret,
}
if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil {
return err
Expand All @@ -370,9 +371,10 @@ func (n *Node) startRPC() error {
if n.config.WSHost != "" {
server := n.wsServerForPort(n.config.WSPort)
config := wsConfig{
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
jwtSecret: jwtSecret,
}
if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions node/node_test.go
Expand Up @@ -577,13 +577,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
}
}
for _, path := range test.wantWS {
err := wsRequest(t, wsBase+path, "")
err := wsRequest(t, wsBase+path)
if err != nil {
t.Errorf("Error: %s: WebSocket connection failed: %v", path, err)
}
}
for _, path := range test.wantNoWS {
err := wsRequest(t, wsBase+path, "")
err := wsRequest(t, wsBase+path)
if err == nil {
t.Errorf("Error: %s: WebSocket connection succeeded for path in wantNoWS", path)
}
Expand Down
26 changes: 19 additions & 7 deletions node/rpcstack.go
Expand Up @@ -40,13 +40,15 @@ type httpConfig struct {
CorsAllowedOrigins []string
Vhosts []string
prefix string // path prefix on which to mount http handler
jwtSecret []byte // optional JWT secret
}

// wsConfig is the JSON-RPC/Websocket configuration
type wsConfig struct {
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
jwtSecret []byte // optional JWT secret
}

type rpcHandler struct {
Expand Down Expand Up @@ -285,7 +287,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
}
h.httpConfig = config
h.httpHandler.Store(&rpcHandler{
Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts),
Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts, config.jwtSecret),
server: srv,
})
return nil
Expand All @@ -309,15 +311,14 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
if h.wsAllowed() {
return fmt.Errorf("JSON-RPC over WebSocket is already enabled")
}

// Create RPC server and handler.
srv := rpc.NewServer()
if err := RegisterApis(apis, config.Modules, srv, false); err != nil {
return err
}
h.wsConfig = config
h.wsHandler.Store(&rpcHandler{
Handler: srv.WebsocketHandler(config.Origins),
Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret),
server: srv,
})
return nil
Expand Down Expand Up @@ -362,13 +363,24 @@ func isWebsocket(r *http.Request) bool {
}

// NewHTTPHandlerStack returns wrapped http-related handlers
func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler {
func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSecret []byte) http.Handler {
// Wrap the CORS-handler within a host-handler
handler := newCorsHandler(srv, cors)
handler = newVHostHandler(vhosts, handler)
if len(jwtSecret) != 0 {
handler = newJWTHandler(jwtSecret, handler)
}
return newGzipHandler(handler)
}

// NewWSHandlerStack returns a wrapped ws-related handler.
func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler {
if len(jwtSecret) != 0 {
return newJWTHandler(jwtSecret, srv)
}
return srv
}

func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
// disable CORS support if user has not specified a custom CORS configuration
if len(allowedOrigins) == 0 {
Expand Down
103 changes: 97 additions & 6 deletions node/rpcstack_test.go
Expand Up @@ -24,10 +24,12 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rpc"
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -146,12 +148,12 @@ func TestWebsocketOrigins(t *testing.T) {
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)})
url := fmt.Sprintf("ws://%v", srv.listenAddr())
for _, origin := range tc.expOk {
if err := wsRequest(t, url, origin); err != nil {
if err := wsRequest(t, url, "Origin", origin); err != nil {
t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err)
}
}
for _, origin := range tc.expFail {
if err := wsRequest(t, url, origin); err == nil {
if err := wsRequest(t, url, "Origin", origin); err == nil {
t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin)
}
}
Expand Down Expand Up @@ -243,13 +245,18 @@ func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsCon
}

// wsRequest attempts to open a WebSocket connection to the given URL.
func wsRequest(t *testing.T, url, browserOrigin string) error {
func wsRequest(t *testing.T, url string, extraHeaders ...string) error {
t.Helper()
t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin)
//t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin)

headers := make(http.Header)
if browserOrigin != "" {
headers.Set("Origin", browserOrigin)
// Apply extra headers.
if len(extraHeaders)%2 != 0 {
panic("odd extraHeaders length")
}
for i := 0; i < len(extraHeaders); i += 2 {
key, value := extraHeaders[i], extraHeaders[i+1]
headers.Set(key, value)
}
conn, _, err := websocket.DefaultDialer.Dial(url, headers)
if conn != nil {
Expand Down Expand Up @@ -291,3 +298,87 @@ func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response
}
return resp
}

type tokenTest struct {
expOk []string
expFail []string
}

type testClaim map[string]interface{}

func (testClaim) Valid() error {
return nil
}

func TestJWT(t *testing.T) {
var secret = []byte("secret")
issueToken := func(secret []byte, method jwt.SigningMethod, input map[string]interface{}) string {
if method == nil {
method = jwt.SigningMethodHS256
}
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
return ss
}
tests := []tokenTest{
{
expOk: []string{
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"exp": time.Now().Unix() + 2,
})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"bar": "baz",
})),
},
expFail: []string{
// future
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 6})),
// stale
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 6})),
// wrong algo
fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})),
// expired
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})),
// missing mandatory iat
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})),
// wrong secret
fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})),
// Various malformed syntax
fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer:\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fmt.Sprintf("Bearer:\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),

},
},
}
for _, tc := range tests {
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")})
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
for i, token := range tc.expOk {
if err := wsRequest(t, wsUrl, "Authorization", token); err != nil {
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
}
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 {
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
}
}
for i, token := range tc.expFail {
if err := wsRequest(t, wsUrl, "Authorization", token); err == nil {
t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token)
}
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 403 {
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
}
}
srv.stop()
}
}
2 changes: 1 addition & 1 deletion rpc/websocket_test.go
Expand Up @@ -76,7 +76,7 @@ func TestWebsocketOriginCheck(t *testing.T) {
// Connections without origin header should work.
client, err = DialWebsocket(context.Background(), wsURL, "")
if err != nil {
t.Fatal("error for empty origin")
t.Fatalf("error for empty origin: %v", err)
}
client.Close()
}
Expand Down