diff --git a/picker_wrapper.go b/picker_wrapper.go index e8367cb8993..d4e4c296854 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -105,15 +105,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. var errStr string if lastPickErr != nil { errStr = "latest balancer error: " + lastPickErr.Error() - } else { - errStr = ctx.Err().Error() - } - switch ctx.Err() { - case context.DeadlineExceeded: - return nil, nil, status.Error(codes.DeadlineExceeded, errStr) - case context.Canceled: - return nil, nil, status.Error(codes.Canceled, errStr) } + return nil, nil, status.FromContextError(ctx.Err(), errStr).Err() case <-ch: } continue diff --git a/status/status.go b/status/status.go index 7a0e324f745..be49c87ee9a 100644 --- a/status/status.go +++ b/status/status.go @@ -121,17 +121,23 @@ func Code(err error) codes.Code { // FromContextError converts a context error or wrapped context error into a // Status. It returns a Status with codes.OK if err is nil, or a Status with // codes.Unknown if err is non-nil and not a context error. -func FromContextError(err error) *Status { +// +// If msg != "" and err != nil, msg is used as the Status' message. If msg == "", +// err.Error() is used. +func FromContextError(err error, msg string) *Status { if err == nil { return nil } + if msg == "" { + msg = err.Error() + } if errors.Is(err, context.DeadlineExceeded) { - return New(codes.DeadlineExceeded, err.Error()) + return New(codes.DeadlineExceeded, msg) } if errors.Is(err, context.Canceled) { - return New(codes.Canceled, err.Error()) + return New(codes.Canceled, msg) } - return New(codes.Unknown, err.Error()) + return New(codes.Unknown, msg) } // MustFromContextError is like FromContextError, except that it expects err to @@ -140,5 +146,5 @@ func MustFromContextError(err error) error { if err == nil { return status.New(codes.Internal, "Expected non-nil context error").Err() } - return FromContextError(err).Err() + return FromContextError(err, "").Err() } diff --git a/status/status_test.go b/status/status_test.go index 420fb6b8102..91114774d86 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -358,17 +358,20 @@ func mustMarshalAny(msg proto.Message) *apb.Any { func (s) TestFromContextError(t *testing.T) { testCases := []struct { in error + msg string want *Status }{ {in: nil, want: New(codes.OK, "")}, + {in: nil, msg: "ignored", want: New(codes.OK, "")}, {in: context.DeadlineExceeded, want: New(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, {in: context.Canceled, want: New(codes.Canceled, context.Canceled.Error())}, {in: errors.New("other"), want: New(codes.Unknown, "other")}, + {in: errors.New("other"), msg: "my msg", want: New(codes.Unknown, "my msg")}, {in: fmt.Errorf("wrapped: %w", context.DeadlineExceeded), want: New(codes.DeadlineExceeded, "wrapped: "+context.DeadlineExceeded.Error())}, {in: fmt.Errorf("wrapped: %w", context.Canceled), want: New(codes.Canceled, "wrapped: "+context.Canceled.Error())}, } for _, tc := range testCases { - got := FromContextError(tc.in) + got := FromContextError(tc.in, tc.msg) if got.Code() != tc.want.Code() || got.Message() != tc.want.Message() { t.Errorf("FromContextError(%v) = %v; want %v", tc.in, got, tc.want) }