Skip to content

Commit

Permalink
rpc, node: refactor request validation and add jwt validation
Browse files Browse the repository at this point in the history
  • Loading branch information
holiman committed Feb 8, 2022
1 parent 2d20fed commit fa0f9b0
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 21 deletions.
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -27,6 +27,7 @@ require (
github.com/gballet/go-libpcsclite v0.0.0-20190607065134-2772fd86a8ff
github.com/go-ole/go-ole v1.2.1 // indirect
github.com/go-stack/stack v1.8.0
github.com/golang-jwt/jwt/v4 v4.2.0 // indirect
github.com/golang/protobuf v1.4.3
github.com/golang/snappy v0.0.4
github.com/google/gofuzz v1.1.1-0.20200604201612-c04b05f3adfa
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -167,6 +167,8 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU=
github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
Expand Down
2 changes: 1 addition & 1 deletion graphql/service.go
Expand Up @@ -74,7 +74,7 @@ func newHandler(stack *node.Node, backend ethapi.Backend, cors, vhosts []string)
return err
}
h := handler{Schema: s}
handler := node.NewHTTPHandlerStack(h, cors, vhosts)
handler := node.NewHTTPHandlerStack(h, cors, vhosts, nil)

stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{})
stack.RegisterHandler("GraphQL", "/graphql", handler)
Expand Down
81 changes: 81 additions & 0 deletions node/jwt_handler.go
@@ -0,0 +1,81 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package node

import (
"net/http"

"errors"
"github.com/golang-jwt/jwt/v4"
"strings"
"time"
)

// customClaim implements claims.Claim.
type customClaim struct {
// the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6
IssuedAt int64 `json:"iat,omitempty"`
}

// Valid implements claims.Claim, and checks that the iat is present and valid.
func (c customClaim) Valid() error {
if time.Now().Unix()-5 < c.IssuedAt {
return errors.New("token issuance (iat) is too old")
}
if time.Now().Unix()+5 > c.IssuedAt {
return errors.New("token issuance (iat) is too far in the future")
}
return nil
}

type jwtHandler struct {
keyFunc func(token *jwt.Token) (interface{}, error)
next http.Handler
}

// MakeJWTValidator creates a validator for jwt tokens.
func newJWTHandler(secret []byte, next http.Handler) http.Handler {
return &jwtHandler{
keyFunc: func(token *jwt.Token) (interface{}, error) {
return secret, nil
},
next: next,
}
}

func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) {
var token string
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
token = strings.TrimPrefix(auth, "Bearer ")
}
if len(token) == 0 {
http.Error(out, "missing token", http.StatusForbidden)
return
}
var claims customClaim
t, err := jwt.ParseWithClaims(token, claims, handler.keyFunc, jwt.WithValidMethods([]string{"HS256"}))
if err != nil {
http.Error(out, err.Error(), http.StatusForbidden)
return
}
if !t.Valid {
// This should not happen, but better safe than sorry if the implementation changes.
http.Error(out, "invalid token", http.StatusForbidden)
return
}
handler.next.ServeHTTP(out, r)
}
10 changes: 6 additions & 4 deletions node/node.go
Expand Up @@ -349,14 +349,15 @@ func (n *Node) startRPC() error {
return err
}
}

jwtSecret := []byte("secret")
// Configure HTTP.
if n.config.HTTPHost != "" {
config := httpConfig{
CorsAllowedOrigins: n.config.HTTPCors,
Vhosts: n.config.HTTPVirtualHosts,
Modules: n.config.HTTPModules,
prefix: n.config.HTTPPathPrefix,
jwtSecret: jwtSecret,
}
if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil {
return err
Expand All @@ -370,9 +371,10 @@ func (n *Node) startRPC() error {
if n.config.WSHost != "" {
server := n.wsServerForPort(n.config.WSPort)
config := wsConfig{
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
jwtSecret: jwtSecret,
}
if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions node/node_test.go
Expand Up @@ -577,13 +577,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) {
}
}
for _, path := range test.wantWS {
err := wsRequest(t, wsBase+path, "")
err := wsRequest(t, wsBase+path)
if err != nil {
t.Errorf("Error: %s: WebSocket connection failed: %v", path, err)
}
}
for _, path := range test.wantNoWS {
err := wsRequest(t, wsBase+path, "")
err := wsRequest(t, wsBase+path)
if err == nil {
t.Errorf("Error: %s: WebSocket connection succeeded for path in wantNoWS", path)
}
Expand Down
25 changes: 18 additions & 7 deletions node/rpcstack.go
Expand Up @@ -40,13 +40,15 @@ type httpConfig struct {
CorsAllowedOrigins []string
Vhosts []string
prefix string // path prefix on which to mount http handler
jwtSecret []byte // optional JWT secret
}

// wsConfig is the JSON-RPC/Websocket configuration
type wsConfig struct {
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
jwtSecret []byte // optional JWT secret
}

type rpcHandler struct {
Expand Down Expand Up @@ -285,7 +287,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
}
h.httpConfig = config
h.httpHandler.Store(&rpcHandler{
Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts),
Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts, config.jwtSecret),
server: srv,
})
return nil
Expand All @@ -309,15 +311,14 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
if h.wsAllowed() {
return fmt.Errorf("JSON-RPC over WebSocket is already enabled")
}

// Create RPC server and handler.
srv := rpc.NewServer()
if err := RegisterApis(apis, config.Modules, srv, false); err != nil {
return err
}
h.wsConfig = config
h.wsHandler.Store(&rpcHandler{
Handler: srv.WebsocketHandler(config.Origins),
Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret),
server: srv,
})
return nil
Expand Down Expand Up @@ -362,13 +363,23 @@ func isWebsocket(r *http.Request) bool {
}

// NewHTTPHandlerStack returns wrapped http-related handlers
func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler {
func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSecret []byte) http.Handler {
// Wrap the CORS-handler within a host-handler
handler := newCorsHandler(srv, cors)
handler = newVHostHandler(vhosts, handler)
if len(jwtSecret) != 0 {
handler = newJWTHandler(jwtSecret, handler)
}
return newGzipHandler(handler)
}

func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler {
if len(jwtSecret) != 0 {
return newJWTHandler(jwtSecret, srv)
}
return srv
}

func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
// disable CORS support if user has not specified a custom CORS configuration
if len(allowedOrigins) == 0 {
Expand Down
63 changes: 57 additions & 6 deletions node/rpcstack_test.go
Expand Up @@ -24,10 +24,12 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rpc"
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -146,12 +148,12 @@ func TestWebsocketOrigins(t *testing.T) {
srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)})
url := fmt.Sprintf("ws://%v", srv.listenAddr())
for _, origin := range tc.expOk {
if err := wsRequest(t, url, origin); err != nil {
if err := wsRequest(t, url, "Origin", origin); err != nil {
t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err)
}
}
for _, origin := range tc.expFail {
if err := wsRequest(t, url, origin); err == nil {
if err := wsRequest(t, url, "Origin", origin); err == nil {
t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin)
}
}
Expand Down Expand Up @@ -243,13 +245,18 @@ func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsCon
}

// wsRequest attempts to open a WebSocket connection to the given URL.
func wsRequest(t *testing.T, url, browserOrigin string) error {
func wsRequest(t *testing.T, url string, extraHeaders ...string) error {
t.Helper()
t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin)
//t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin)

headers := make(http.Header)
if browserOrigin != "" {
headers.Set("Origin", browserOrigin)
// Apply extra headers.
if len(extraHeaders)%2 != 0 {
panic("odd extraHeaders length")
}
for i := 0; i < len(extraHeaders); i += 2 {
key, value := extraHeaders[i], extraHeaders[i+1]
headers.Set(key, value)
}
conn, _, err := websocket.DefaultDialer.Dial(url, headers)
if conn != nil {
Expand Down Expand Up @@ -291,3 +298,47 @@ func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response
}
return resp
}

type tokenTest struct {
expOk []string
expFail []string
}

func TestJWT(t *testing.T) {

makeToken := func() string {
mySigningKey := []byte("secret")
// Create the Claims
claims := &jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, _ := token.SignedString(mySigningKey)
return ss
}
tests := []originTest{
{
//expFail: []string{"Bearer ", "Bearer: abc", "Baxonk hello there"},
expOk: []string{
fmt.Sprintf("Bearer %v", makeToken()),
},
},
}

for _, tc := range tests {
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")})
url := fmt.Sprintf("ws://%v", srv.listenAddr())
for i, token := range tc.expOk {
if err := wsRequest(t, url, "Authorization", token); err != nil {
t.Errorf("test %d, token '%v': expected ok, got %v", i, token, err)
}
}
for i, token := range tc.expFail {
if err := wsRequest(t, url, "Authorization", token); err == nil {
t.Errorf("tc %d, token '%v': expected not to allow, got ok", i, token)
}
}
srv.stop()
}
}
2 changes: 1 addition & 1 deletion rpc/websocket_test.go
Expand Up @@ -76,7 +76,7 @@ func TestWebsocketOriginCheck(t *testing.T) {
// Connections without origin header should work.
client, err = DialWebsocket(context.Background(), wsURL, "")
if err != nil {
t.Fatal("error for empty origin")
t.Fatalf("error for empty origin: %v", err)
}
client.Close()
}
Expand Down

0 comments on commit fa0f9b0

Please sign in to comment.