From 3a8aa01a6a38b87c41fff110b7c749d40825ae0c Mon Sep 17 00:00:00 2001 From: Andrei Matei Date: Fri, 7 Jan 2022 17:34:44 -0500 Subject: [PATCH] picker_wrapper: improve handling of context errors pickerWrapper had logic very similar to status.FromContextError() for transforming Context errors to status errors. This patch removes the duplication by delegating to the status library. Besides removing the code duplication, the status library is arguably more robust because it doesn't rely on ctx.Error() to only ever return two types of errors. I believe this patch and the previous one stand on their own, but, FWIW, they're also motivating by me wanting to experiment in the CockroachDB codebase with using a custom implementation of context.Context whose Err() method can return better errors than the stdlib context.Context. These errors would still wrap context.Canceled. Such an implementation would technically break the documentation of context.Context, which seems to exhaustively list the sentinel error that context.Context can return. Still, as https://github.com/grpc/grpc-go/pull/4977 showed, most code should support wrapped context errors. This patch moves from "most code" to "all code" in gRPC. I haven't checked which of the callsites I've touched use contexts that might be inherited from a gRPC client, as opposed to contexts derived inside gRPC from context.Background (which contexts would not be affected by whatever I do outside of gRPC), but unifying all the context error handling code seems like a good idea to me universally. --- picker_wrapper.go | 9 +-------- status/status.go | 16 +++++++++++----- status/status_test.go | 5 ++++- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/picker_wrapper.go b/picker_wrapper.go index e8367cb8993..d4e4c296854 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -105,15 +105,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. var errStr string if lastPickErr != nil { errStr = "latest balancer error: " + lastPickErr.Error() - } else { - errStr = ctx.Err().Error() - } - switch ctx.Err() { - case context.DeadlineExceeded: - return nil, nil, status.Error(codes.DeadlineExceeded, errStr) - case context.Canceled: - return nil, nil, status.Error(codes.Canceled, errStr) } + return nil, nil, status.FromContextError(ctx.Err(), errStr).Err() case <-ch: } continue diff --git a/status/status.go b/status/status.go index 7a0e324f745..be49c87ee9a 100644 --- a/status/status.go +++ b/status/status.go @@ -121,17 +121,23 @@ func Code(err error) codes.Code { // FromContextError converts a context error or wrapped context error into a // Status. It returns a Status with codes.OK if err is nil, or a Status with // codes.Unknown if err is non-nil and not a context error. -func FromContextError(err error) *Status { +// +// If msg != "" and err != nil, msg is used as the Status' message. If msg == "", +// err.Error() is used. +func FromContextError(err error, msg string) *Status { if err == nil { return nil } + if msg == "" { + msg = err.Error() + } if errors.Is(err, context.DeadlineExceeded) { - return New(codes.DeadlineExceeded, err.Error()) + return New(codes.DeadlineExceeded, msg) } if errors.Is(err, context.Canceled) { - return New(codes.Canceled, err.Error()) + return New(codes.Canceled, msg) } - return New(codes.Unknown, err.Error()) + return New(codes.Unknown, msg) } // MustFromContextError is like FromContextError, except that it expects err to @@ -140,5 +146,5 @@ func MustFromContextError(err error) error { if err == nil { return status.New(codes.Internal, "Expected non-nil context error").Err() } - return FromContextError(err).Err() + return FromContextError(err, "").Err() } diff --git a/status/status_test.go b/status/status_test.go index 420fb6b8102..91114774d86 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -358,17 +358,20 @@ func mustMarshalAny(msg proto.Message) *apb.Any { func (s) TestFromContextError(t *testing.T) { testCases := []struct { in error + msg string want *Status }{ {in: nil, want: New(codes.OK, "")}, + {in: nil, msg: "ignored", want: New(codes.OK, "")}, {in: context.DeadlineExceeded, want: New(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, {in: context.Canceled, want: New(codes.Canceled, context.Canceled.Error())}, {in: errors.New("other"), want: New(codes.Unknown, "other")}, + {in: errors.New("other"), msg: "my msg", want: New(codes.Unknown, "my msg")}, {in: fmt.Errorf("wrapped: %w", context.DeadlineExceeded), want: New(codes.DeadlineExceeded, "wrapped: "+context.DeadlineExceeded.Error())}, {in: fmt.Errorf("wrapped: %w", context.Canceled), want: New(codes.Canceled, "wrapped: "+context.Canceled.Error())}, } for _, tc := range testCases { - got := FromContextError(tc.in) + got := FromContextError(tc.in, tc.msg) if got.Code() != tc.want.Code() || got.Message() != tc.want.Message() { t.Errorf("FromContextError(%v) = %v; want %v", tc.in, got, tc.want) }