Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature : responses whit context #2637

Merged
merged 1 commit into from Dec 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions gateway/server.go
Expand Up @@ -122,7 +122,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
return func(w http.ResponseWriter, r *http.Request) {
parser, err := internal.NewRequestParser(r, resolver)
if err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
return
}

Expand All @@ -134,12 +134,12 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
handler := internal.NewEventHandler(w, resolver)
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header),
handler, parser.Next); err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
}

st := handler.Status
if st.Code() != codes.OK {
httpx.Error(w, st.Err())
httpx.ErrorCtx(r.Context(), w, st.Err())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion rest/handler/timeouthandler.go
Expand Up @@ -99,7 +99,7 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer tw.mu.Unlock()
// there isn't any user-defined middleware before TimoutHandler,
// so we can guarantee that cancelation in biz related code won't come here.
httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
if errors.Is(err, context.Canceled) {
w.WriteHeader(statusClientClosedRequest)
} else {
Expand Down
74 changes: 72 additions & 2 deletions rest/httpx/responses.go
@@ -1,6 +1,7 @@
package httpx

import (
"context"
"encoding/json"
"net/http"
"sync"
Expand All @@ -11,8 +12,9 @@ import (
)

var (
errorHandler func(error) (int, interface{})
lock sync.RWMutex
errorHandler func(error) (int, interface{})
lock sync.RWMutex
errorHandlerCtx func(context.Context, error) (int, interface{})
)

// Error writes err into w.
Expand Down Expand Up @@ -87,3 +89,71 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) {
logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
}

// Error writes err into w.
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
lock.RLock()
handlerCtx := errorHandlerCtx
lock.RUnlock()

if handlerCtx == nil {
if len(fns) > 0 {
fns[0](w, err)
} else if errcode.IsGrpcError(err) {
// don't unwrap error and get status.Message(),
// it hides the rpc error headers.
http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
} else {
http.Error(w, err.Error(), http.StatusBadRequest)
}

return
}

code, body := handlerCtx(ctx, err)
if body == nil {
w.WriteHeader(code)
return
}

e, ok := body.(error)
if ok {
http.Error(w, e.Error(), code)
} else {
WriteJsonCtx(ctx, w, code, body)
}
}

// OkJson writes v into w with 200 OK.
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) {
WriteJsonCtx(ctx, w, http.StatusOK, v)
}

// WriteJson writes v as json string into w with code.
func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) {
bs, err := json.Marshal(v)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.Header().Set(ContentType, header.JsonContentType)
w.WriteHeader(code)

if n, err := w.Write(bs); err != nil {
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
// so it's ignored here.
if err != http.ErrHandlerTimeout {
logx.WithContext(ctx).Errorf("write response failed, error: %s", err)
}
} else if n < len(bs) {
logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
}

// SetErrorHandler sets the error handler, which is called on calling Error.
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) {
lock.Lock()
defer lock.Unlock()
errorHandlerCtx = handlerCtx
}
113 changes: 113 additions & 0 deletions rest/httpx/responses_test.go
@@ -1,6 +1,7 @@
package httpx

import (
"context"
"errors"
"net/http"
"strings"
Expand Down Expand Up @@ -214,3 +215,115 @@ func (w *tracedResponseWriter) WriteHeader(code int) {
w.wroteHeader = true
w.code = code
}

func TestErrorCtx(t *testing.T) {
const (
body = "foo"
wrappedBody = `"foo"`
)

tests := []struct {
name string
input string
errorHandlerCtx func(context.Context, error) (int, interface{})
expectHasBody bool
expectBody string
expectCode int
}{
{
name: "default error handler",
input: body,
expectHasBody: true,
expectBody: body,
expectCode: http.StatusBadRequest,
},
{
name: "customized error handler return string",
input: body,
errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
return http.StatusForbidden, err.Error()
},
expectHasBody: true,
expectBody: wrappedBody,
expectCode: http.StatusForbidden,
},
{
name: "customized error handler return error",
input: body,
errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
return http.StatusForbidden, err
},
expectHasBody: true,
expectBody: body,
expectCode: http.StatusForbidden,
},
{
name: "customized error handler return nil",
input: body,
errorHandlerCtx: func(context.Context, error) (int, interface{}) {
return http.StatusForbidden, nil
},
expectHasBody: false,
expectBody: "",
expectCode: http.StatusForbidden,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
if test.errorHandlerCtx != nil {
lock.RLock()
prev := errorHandlerCtx
lock.RUnlock()
SetErrorHandlerCtx(test.errorHandlerCtx)
defer func() {
lock.Lock()
test.errorHandlerCtx = prev
lock.Unlock()
}()
}
ErrorCtx(context.Background(), &w, errors.New(test.input))
assert.Equal(t, test.expectCode, w.code)
assert.Equal(t, test.expectHasBody, w.hasBody)
assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
})
}

//The current handler is a global event,Set default values to avoid impacting subsequent unit tests
SetErrorHandlerCtx(nil)
}

func TestErrorWithGrpcErrorCtx(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
ErrorCtx(context.Background(), &w, status.Error(codes.Unavailable, "foo"))
assert.Equal(t, http.StatusServiceUnavailable, w.code)
assert.True(t, w.hasBody)
assert.True(t, strings.Contains(w.builder.String(), "foo"))
}

func TestErrorWithHandlerCtx(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
ErrorCtx(context.Background(), &w, errors.New("foo"), func(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), 499)
})
assert.Equal(t, 499, w.code)
assert.True(t, w.hasBody)
assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
}

func TestWriteJsonCtxMarshalFailed(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
WriteJsonCtx(context.Background(), &w, http.StatusOK, map[string]interface{}{
"Data": complex(0, 0),
})
assert.Equal(t, http.StatusInternalServerError, w.code)
}
6 changes: 3 additions & 3 deletions tools/goctl/api/gogen/handler.tpl
Expand Up @@ -11,16 +11,16 @@ func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
{{if .HasRequest}}var req types.{{.RequestType}}
if err := httpx.Parse(r, &req); err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
return
}

{{end}}l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx)
{{if .HasResp}}resp, {{end}}err := l.{{.Call}}({{if .HasRequest}}&req{{end}})
if err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
} else {
{{if .HasResp}}httpx.OkJson(w, resp){{else}}httpx.Ok(w){{end}}
{{if .HasResp}}httpx.OkJsonCtx(r.Context(), w, resp){{else}}httpx.Ok(w){{end}}
}
}
}