Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: handle context errors returned by service handler #5156

Merged
merged 4 commits into from Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions interceptor.go
Expand Up @@ -72,9 +72,12 @@ type UnaryServerInfo struct {
}

// UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal
// execution of a unary RPC. If a UnaryHandler returns an error, it should be produced by the
// status package, or else gRPC will use codes.Unknown as the status code and err.Error() as
// the status message of the RPC.
// execution of a unary RPC.
//
// If a UnaryHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
// codes.Unknown as the status code and err.Error() as the status message of the
// RPC.
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)

// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info
Expand Down
11 changes: 7 additions & 4 deletions server.go
Expand Up @@ -1283,9 +1283,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
// Convert appErr if it is not a grpc status error.
appErr = status.Error(codes.Unknown, appErr.Error())
appStatus, _ = status.FromError(appErr)
// Convert non-status application error to a status error with code
// Unknown, but handle context errors specifically.
appStatus = status.FromContextError(appErr)
appErr = appStatus.Err()
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
Expand Down Expand Up @@ -1549,7 +1550,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
appStatus = status.New(codes.Unknown, appErr.Error())
// Convert non-status application error to a status error with code
// Unknown, but handle context errors specifically.
appStatus = status.FromContextError(appErr)
appErr = appStatus.Err()
}
if trInfo != nil {
Expand Down
10 changes: 6 additions & 4 deletions stream.go
Expand Up @@ -46,10 +46,12 @@ import (
)

// StreamHandler defines the handler called by gRPC server to complete the
// execution of a streaming RPC. If a StreamHandler returns an error, it
// should be produced by the status package, or else gRPC will use
// codes.Unknown as the status code and err.Error() as the status message
// of the RPC.
// execution of a streaming RPC.
//
// If a StreamHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
// codes.Unknown as the status code and err.Error() as the status message of the
// RPC.
type StreamHandler func(srv interface{}, stream ServerStream) error

// StreamDesc represents a streaming RPC service's method specification. Used
Expand Down
35 changes: 35 additions & 0 deletions test/server_test.go
Expand Up @@ -32,6 +32,41 @@ import (

type ctxKey string

// TestServerReturningContextError verifies that if a context error is returned
// by the service handler, the status will have the correct status code, not
// Unknown.
func (s) TestServerReturningContextError(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return nil, context.DeadlineExceeded
},
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
return context.DeadlineExceeded
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
t.Fatalf("ss.Client.EmptyCall() got error %v; want <status with Code()=DeadlineExceeded>", err)
}

stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("unexpected error starting the stream: %v", err)
}
_, err = stream.Recv()
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want <status with Code()=DeadlineExceeded>", err)
}

}

func (s) TestChainUnaryServerInterceptor(t *testing.T) {
var (
firstIntKey = ctxKey("firstIntKey")
Expand Down