From 2a93c56e084c480f83fd642027a77dff6c3b15f5 Mon Sep 17 00:00:00 2001 From: "Mark S. Lewis" Date: Fri, 12 Nov 2021 18:40:15 +0000 Subject: [PATCH] Support wrapped errors in status.FromContextError Return an appropriate Status from status.FromContext error if either the supplied error or an error in its chain is one of the context sentinel error values. --- status/status.go | 19 ++++++++++--------- status/status_test.go | 2 ++ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/status/status.go b/status/status.go index af2cffe985c..6d163b6e384 100644 --- a/status/status.go +++ b/status/status.go @@ -29,6 +29,7 @@ package status import ( "context" + "errors" "fmt" spb "google.golang.org/genproto/googleapis/rpc/status" @@ -117,18 +118,18 @@ func Code(err error) codes.Code { return codes.Unknown } -// FromContextError converts a context error into a Status. It returns a -// Status with codes.OK if err is nil, or a Status with codes.Unknown if err is -// non-nil and not a context error. +// FromContextError converts a context error or wrapped context error into a +// Status. It returns a Status with codes.OK if err is nil, or a Status with +// codes.Unknown if err is non-nil and not a context error. func FromContextError(err error) *Status { - switch err { - case nil: + if err == nil { return nil - case context.DeadlineExceeded: + } + if errors.Is(err, context.DeadlineExceeded) { return New(codes.DeadlineExceeded, err.Error()) - case context.Canceled: + } + if errors.Is(err, context.Canceled) { return New(codes.Canceled, err.Error()) - default: - return New(codes.Unknown, err.Error()) } + return New(codes.Unknown, err.Error()) } diff --git a/status/status_test.go b/status/status_test.go index 839a3c390ed..420fb6b8102 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -364,6 +364,8 @@ func (s) TestFromContextError(t *testing.T) { {in: context.DeadlineExceeded, want: New(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, {in: context.Canceled, want: New(codes.Canceled, context.Canceled.Error())}, {in: errors.New("other"), want: New(codes.Unknown, "other")}, + {in: fmt.Errorf("wrapped: %w", context.DeadlineExceeded), want: New(codes.DeadlineExceeded, "wrapped: "+context.DeadlineExceeded.Error())}, + {in: fmt.Errorf("wrapped: %w", context.Canceled), want: New(codes.Canceled, "wrapped: "+context.Canceled.Error())}, } for _, tc := range testCases { got := FromContextError(tc.in)