Skip to content

Commit

Permalink
server: handle context errors returned by service handler (#5156)
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl committed Jan 26, 2022
1 parent e277174 commit 61a6a06
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 11 deletions.
9 changes: 6 additions & 3 deletions interceptor.go
Expand Up @@ -72,9 +72,12 @@ type UnaryServerInfo struct {
}

// UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal
// execution of a unary RPC. If a UnaryHandler returns an error, it should be produced by the
// status package, or else gRPC will use codes.Unknown as the status code and err.Error() as
// the status message of the RPC.
// execution of a unary RPC.
//
// If a UnaryHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
// codes.Unknown as the status code and err.Error() as the status message of the
// RPC.
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)

// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info
Expand Down
11 changes: 7 additions & 4 deletions server.go
Expand Up @@ -1283,9 +1283,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
// Convert appErr if it is not a grpc status error.
appErr = status.Error(codes.Unknown, appErr.Error())
appStatus, _ = status.FromError(appErr)
// Convert non-status application error to a status error with code
// Unknown, but handle context errors specifically.
appStatus = status.FromContextError(appErr)
appErr = appStatus.Err()
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
Expand Down Expand Up @@ -1549,7 +1550,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
appStatus = status.New(codes.Unknown, appErr.Error())
// Convert non-status application error to a status error with code
// Unknown, but handle context errors specifically.
appStatus = status.FromContextError(appErr)
appErr = appStatus.Err()
}
if trInfo != nil {
Expand Down
10 changes: 6 additions & 4 deletions stream.go
Expand Up @@ -46,10 +46,12 @@ import (
)

// StreamHandler defines the handler called by gRPC server to complete the
// execution of a streaming RPC. If a StreamHandler returns an error, it
// should be produced by the status package, or else gRPC will use
// codes.Unknown as the status code and err.Error() as the status message
// of the RPC.
// execution of a streaming RPC.
//
// If a StreamHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
// codes.Unknown as the status code and err.Error() as the status message of the
// RPC.
type StreamHandler func(srv interface{}, stream ServerStream) error

// StreamDesc represents a streaming RPC service's method specification. Used
Expand Down
35 changes: 35 additions & 0 deletions test/server_test.go
Expand Up @@ -32,6 +32,41 @@ import (

type ctxKey string

// TestServerReturningContextError verifies that if a context error is returned
// by the service handler, the status will have the correct status code, not
// Unknown.
func (s) TestServerReturningContextError(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return nil, context.DeadlineExceeded
},
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
return context.DeadlineExceeded
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
t.Fatalf("ss.Client.EmptyCall() got error %v; want <status with Code()=DeadlineExceeded>", err)
}

stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("unexpected error starting the stream: %v", err)
}
_, err = stream.Recv()
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want <status with Code()=DeadlineExceeded>", err)
}

}

func (s) TestChainUnaryServerInterceptor(t *testing.T) {
var (
firstIntKey = ctxKey("firstIntKey")
Expand Down

0 comments on commit 61a6a06

Please sign in to comment.