Skip to content

Commit

Permalink
feature : responses whit context
Browse files Browse the repository at this point in the history
  • Loading branch information
heyehang committed Nov 25, 2022
1 parent 97a8b3a commit 5e9cf7b
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 8 deletions.
6 changes: 3 additions & 3 deletions gateway/server.go
Expand Up @@ -122,7 +122,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
return func(w http.ResponseWriter, r *http.Request) {
parser, err := internal.NewRequestParser(r, resolver)
if err != nil {
httpx.Error(w, err)
httpx.Error(r.Context(), w, err)
return
}

Expand All @@ -134,12 +134,12 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
handler := internal.NewEventHandler(w, resolver)
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header),
handler, parser.Next); err != nil {
httpx.Error(w, err)
httpx.Error(r.Context(), w, err)
}

st := handler.Status
if st.Code() != codes.OK {
httpx.Error(w, st.Err())
httpx.Error(r.Context(), w, st.Err())
}
}
}
Expand Down
74 changes: 72 additions & 2 deletions rest/httpx/responses.go
@@ -1,6 +1,7 @@
package httpx

import (
"context"
"encoding/json"
"net/http"
"sync"
Expand All @@ -11,8 +12,9 @@ import (
)

var (
errorHandler func(error) (int, interface{})
lock sync.RWMutex
errorHandler func(error) (int, interface{})
lock sync.RWMutex
errorHandlerCtx func(context.Context, error) (int, interface{})
)

// Error writes err into w.
Expand Down Expand Up @@ -87,3 +89,71 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) {
logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
}

// Error writes err into w.
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
lock.RLock()
handlerCtx := errorHandlerCtx
lock.RUnlock()

if handlerCtx == nil {
if len(fns) > 0 {
fns[0](w, err)
} else if errcode.IsGrpcError(err) {
// don't unwrap error and get status.Message(),
// it hides the rpc error headers.
http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
} else {
http.Error(w, err.Error(), http.StatusBadRequest)
}

return
}

code, body := handlerCtx(ctx, err)
if body == nil {
w.WriteHeader(code)
return
}

e, ok := body.(error)
if ok {
http.Error(w, e.Error(), code)
} else {
WriteJsonCtx(ctx, w, code, body)
}
}

// OkJson writes v into w with 200 OK.
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) {
WriteJsonCtx(ctx, w, http.StatusOK, v)
}

// WriteJson writes v as json string into w with code.
func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) {
bs, err := json.Marshal(v)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.Header().Set(ContentType, header.JsonContentType)
w.WriteHeader(code)

if n, err := w.Write(bs); err != nil {
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
// so it's ignored here.
if err != http.ErrHandlerTimeout {
logx.WithContext(ctx).Errorf("write response failed, error: %s", err)
}
} else if n < len(bs) {
logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
}

// SetErrorHandler sets the error handler, which is called on calling Error.
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) {
lock.Lock()
defer lock.Unlock()
errorHandlerCtx = handlerCtx
}
113 changes: 113 additions & 0 deletions rest/httpx/responses_test.go
@@ -1,6 +1,7 @@
package httpx

import (
"context"
"errors"
"net/http"
"strings"
Expand Down Expand Up @@ -214,3 +215,115 @@ func (w *tracedResponseWriter) WriteHeader(code int) {
w.wroteHeader = true
w.code = code
}

func TestErrorCtx(t *testing.T) {
const (
body = "foo"
wrappedBody = `"foo"`
)

tests := []struct {
name string
input string
errorHandlerCtx func(context.Context, error) (int, interface{})
expectHasBody bool
expectBody string
expectCode int
}{
{
name: "default error handler",
input: body,
expectHasBody: true,
expectBody: body,
expectCode: http.StatusBadRequest,
},
{
name: "customized error handler return string",
input: body,
errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
return http.StatusForbidden, err.Error()
},
expectHasBody: true,
expectBody: wrappedBody,
expectCode: http.StatusForbidden,
},
{
name: "customized error handler return error",
input: body,
errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
return http.StatusForbidden, err
},
expectHasBody: true,
expectBody: body,
expectCode: http.StatusForbidden,
},
{
name: "customized error handler return nil",
input: body,
errorHandlerCtx: func(context.Context, error) (int, interface{}) {
return http.StatusForbidden, nil
},
expectHasBody: false,
expectBody: "",
expectCode: http.StatusForbidden,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
if test.errorHandlerCtx != nil {
lock.RLock()
prev := errorHandlerCtx
lock.RUnlock()
SetErrorHandlerCtx(test.errorHandlerCtx)
defer func() {
lock.Lock()
test.errorHandlerCtx = prev
lock.Unlock()
}()
}
ErrorCtx(context.Background(), &w, errors.New(test.input))
assert.Equal(t, test.expectCode, w.code)
assert.Equal(t, test.expectHasBody, w.hasBody)
assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
})
}

//The current handler is a global event,Set default values to avoid impacting subsequent unit tests
SetErrorHandlerCtx(nil)
}

func TestErrorWithGrpcErrorCtx(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
ErrorCtx(context.Background(), &w, status.Error(codes.Unavailable, "foo"))
assert.Equal(t, http.StatusServiceUnavailable, w.code)
assert.True(t, w.hasBody)
assert.True(t, strings.Contains(w.builder.String(), "foo"))
}

func TestErrorWithHandlerCtx(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
ErrorCtx(context.Background(), &w, errors.New("foo"), func(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), 499)
})
assert.Equal(t, 499, w.code)
assert.True(t, w.hasBody)
assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
}

func TestWriteJsonCtxMarshalFailed(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
WriteJsonCtx(context.Background(), &w, http.StatusOK, map[string]interface{}{
"Data": complex(0, 0),
})
assert.Equal(t, http.StatusInternalServerError, w.code)
}
6 changes: 3 additions & 3 deletions tools/goctl/api/gogen/handler.tpl
Expand Up @@ -11,16 +11,16 @@ func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
{{if .HasRequest}}var req types.{{.RequestType}}
if err := httpx.Parse(r, &req); err != nil {
httpx.Error(w, err)
httpx.Error(r.Context(), w, err)
return
}

{{end}}l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx)
{{if .HasResp}}resp, {{end}}err := l.{{.Call}}({{if .HasRequest}}&req{{end}})
if err != nil {
httpx.Error(w, err)
httpx.Error(r.Context(), w, err)
} else {
{{if .HasResp}}httpx.OkJson(w, resp){{else}}httpx.Ok(w){{end}}
{{if .HasResp}}httpx.OkJson(r.Context(), w, resp){{else}}httpx.Ok(w){{end}}
}
}
}

0 comments on commit 5e9cf7b

Please sign in to comment.