Skip to content

Commit

Permalink
jsonrpc.DefaultErrorEncoder: add RequestID in error body (#969)
Browse files Browse the repository at this point in the history
* Add RequestID in error body

* Implementing review suggestions
  • Loading branch information
esenac committed Mar 22, 2020
1 parent 6ce524c commit cb67d82
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
14 changes: 13 additions & 1 deletion transport/http/jsonrpc/server.go
Expand Up @@ -11,6 +11,10 @@ import (
httptransport "github.com/go-kit/kit/transport/http"
)

type requestIDKeyType struct{}

var requestIDKey requestIDKeyType

// Server wraps an endpoint and implements http.Handler.
type Server struct {
ecm EndpointCodecMap
Expand Down Expand Up @@ -105,6 +109,8 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

ctx = context.WithValue(ctx, requestIDKey, req.ID)

// Get the endpoint and codecs from the map using the method
// defined in the JSON object
ecm, ok := s.ecm[req.Method]
Expand Down Expand Up @@ -160,7 +166,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// If the error implements ErrorCoder, the provided code will be set on the
// response error.
// If the error implements Headerer, the given headers will be set.
func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
func DefaultErrorEncoder(ctx context.Context, err error, w http.ResponseWriter) {
w.Header().Set("Content-Type", ContentType)
if headerer, ok := err.(httptransport.Headerer); ok {
for k := range headerer.Headers() {
Expand All @@ -177,7 +183,13 @@ func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
}

w.WriteHeader(http.StatusOK)

var requestID *RequestID
if v := ctx.Value(requestIDKey); v != nil {
requestID = v.(*RequestID)
}
_ = json.NewEncoder(w).Encode(Response{
ID: requestID,
JSONRPC: Version,
Error: &e,
})
Expand Down
48 changes: 42 additions & 6 deletions transport/http/jsonrpc/server_test.go
Expand Up @@ -24,11 +24,17 @@ func body(in string) io.Reader {
return strings.NewReader(in)
}

func unmarshalResponse(body []byte) (resp jsonrpc.Response, err error) {
err = json.Unmarshal(body, &resp)
return
}

func expectErrorCode(t *testing.T, want int, body []byte) {
var r jsonrpc.Response
err := json.Unmarshal(body, &r)
t.Helper()

r, err := unmarshalResponse(body)
if err != nil {
t.Fatalf("Cant' decode response. err=%s, body=%s", err, body)
t.Fatalf("Can't decode response: %v (%s)", err, body)
}
if r.Error == nil {
t.Fatalf("Expected error on response. Got none: %s", body)
Expand All @@ -38,6 +44,34 @@ func expectErrorCode(t *testing.T, want int, body []byte) {
}
}

func expectValidRequestID(t *testing.T, want int, body []byte) {
t.Helper()

r, err := unmarshalResponse(body)
if err != nil {
t.Fatalf("Can't decode response: %v (%s)", err, body)
}
have, err := r.ID.Int()
if err != nil {
t.Fatalf("Can't get requestID in response. err=%s, body=%s", err, body)
}
if want != have {
t.Fatalf("Request ID: want %d, have %d (%s)", want, have, body)
}
}

func expectNilRequestID(t *testing.T, body []byte) {
t.Helper()

r, err := unmarshalResponse(body)
if err != nil {
t.Fatalf("Can't decode response: %v (%s)", err, body)
}
if r.ID != nil {
t.Fatalf("Request ID: want nil, have %v", r.ID)
}
}

func nopDecoder(context.Context, json.RawMessage) (interface{}, error) { return struct{}{}, nil }
func nopEncoder(context.Context, interface{}) (json.RawMessage, error) { return []byte("[]"), nil }

Expand Down Expand Up @@ -92,6 +126,7 @@ func TestServerBadEndpoint(t *testing.T) {
}
buf, _ := ioutil.ReadAll(resp.Body)
expectErrorCode(t, jsonrpc.InternalError, buf)
expectValidRequestID(t, 1, buf)
}

func TestServerBadEncode(t *testing.T) {
Expand All @@ -111,6 +146,7 @@ func TestServerBadEncode(t *testing.T) {
}
buf, _ := ioutil.ReadAll(resp.Body)
expectErrorCode(t, jsonrpc.InternalError, buf)
expectValidRequestID(t, 1, buf)
}

func TestServerErrorEncoder(t *testing.T) {
Expand Down Expand Up @@ -162,6 +198,7 @@ func TestCanRejectInvalidJSON(t *testing.T) {
}
buf, _ := ioutil.ReadAll(resp.Body)
expectErrorCode(t, jsonrpc.ParseError, buf)
expectNilRequestID(t, buf)
}

func TestServerUnregisteredMethod(t *testing.T) {
Expand All @@ -186,10 +223,9 @@ func TestServerHappyPath(t *testing.T) {
if want, have := http.StatusOK, resp.StatusCode; want != have {
t.Errorf("want %d, have %d (%s)", want, have, buf)
}
var r jsonrpc.Response
err := json.Unmarshal(buf, &r)
r, err := unmarshalResponse(buf)
if err != nil {
t.Fatalf("Cant' decode response. err=%s, body=%s", err, buf)
t.Fatalf("Can't decode response. err=%s, body=%s", err, buf)
}
if r.JSONRPC != jsonrpc.Version {
t.Fatalf("JSONRPC Version: want=%s, got=%s", jsonrpc.Version, r.JSONRPC)
Expand Down

0 comments on commit cb67d82

Please sign in to comment.