diff --git a/v2/apierror/apierror.go b/v2/apierror/apierror.go index 787a619..aa6be13 100644 --- a/v2/apierror/apierror.go +++ b/v2/apierror/apierror.go @@ -233,30 +233,49 @@ func (a *APIError) Metadata() map[string]string { } -// FromError parses a Status error or a googleapi.Error and builds an APIError. -func FromError(err error) (*APIError, bool) { - if err == nil { - return nil, false - } - - ae := APIError{err: err} +// setDetailsFromError parses a Status error or a googleapi.Error +// and sets status and details or httpErr and details, respectively. +// It returns false if neither Status nor googleapi.Error can be parsed. +func (a *APIError) setDetailsFromError(err error) bool { st, isStatus := status.FromError(err) var herr *googleapi.Error isHTTPErr := errors.As(err, &herr) switch { case isStatus: - ae.status = st - ae.details = parseDetails(st.Details()) + a.status = st + a.details = parseDetails(st.Details()) case isHTTPErr: - ae.httpErr = herr - ae.details = parseHTTPDetails(herr) + a.httpErr = herr + a.details = parseHTTPDetails(herr) default: - return nil, false + return false } + return true +} - return &ae, true +// FromError parses a Status error or a googleapi.Error and builds an +// APIError, wrapping the provided error in the new APIError. It +// returns false if neither Status nor googleapi.Error can be parsed. +func FromError(err error) (*APIError, bool) { + return ParseError(err, true) +} +// ParseError parses a Status error or a googleapi.Error and builds an +// APIError. If wrap is true, it wraps the error in the new APIError. +// It returns false if neither Status nor googleapi.Error can be parsed. +func ParseError(err error, wrap bool) (*APIError, bool) { + if err == nil { + return nil, false + } + ae := APIError{} + if wrap { + ae = APIError{err: err} + } + if !ae.setDetailsFromError(err) { + return nil, false + } + return &ae, true } // parseDetails accepts a slice of interface{} that should be backed by some diff --git a/v2/apierror/apierror_test.go b/v2/apierror/apierror_test.go index 358b918..ddd0419 100644 --- a/v2/apierror/apierror_test.go +++ b/v2/apierror/apierror_test.go @@ -363,21 +363,24 @@ func TestFromError(t *testing.T) { {&APIError{err: lS.Err(), status: lS, details: ErrDetails{LocalizedMessage: lo}}, true}, {&APIError{err: uS.Err(), status: uS, details: ErrDetails{Unknown: u}}, true}, {&APIError{err: hae, httpErr: hae, details: ErrDetails{ErrorInfo: httpErrInfo}}, true}, + {&APIError{err: errors.New("standard error")}, false}, } for _, tc := range tests { got, apiB := FromError(tc.apierr.err) if tc.b != apiB { - t.Errorf("got %v, want %v", apiB, tc.b) + t.Errorf("FromError(%s): got %v, want %v", tc.apierr.err, apiB, tc.b) } - if diff := cmp.Diff(got.details, tc.apierr.details, cmp.Comparer(proto.Equal)); diff != "" { - t.Errorf("got(-), want(+),: \n%s", diff) - } - if diff := cmp.Diff(got.status, tc.apierr.status, cmp.Comparer(proto.Equal), cmp.AllowUnexported(status.Status{})); diff != "" { - t.Errorf("got(-), want(+),: \n%s", diff) - } - if diff := cmp.Diff(got.err, tc.apierr.err, cmpopts.EquateErrors()); diff != "" { - t.Errorf("got(-), want(+),: \n%s", diff) + if tc.b { + if diff := cmp.Diff(got.details, tc.apierr.details, cmp.Comparer(proto.Equal)); diff != "" { + t.Errorf("got(-), want(+),: \n%s", diff) + } + if diff := cmp.Diff(got.status, tc.apierr.status, cmp.Comparer(proto.Equal), cmp.AllowUnexported(status.Status{})); diff != "" { + t.Errorf("got(-), want(+),: \n%s", diff) + } + if diff := cmp.Diff(got.err, tc.apierr.err, cmpopts.EquateErrors()); diff != "" { + t.Errorf("got(-), want(+),: \n%s", diff) + } } } if err, _ := FromError(nil); err != nil { @@ -389,6 +392,59 @@ func TestFromError(t *testing.T) { } } +func TestParseError(t *testing.T) { + httpErrInfo := &errdetails.ErrorInfo{Reason: "just because", Domain: "tests"} + any, err := anypb.New(httpErrInfo) + if err != nil { + t.Fatal(err) + } + e := &jsonerror.Error{Error: &jsonerror.Error_Status{Details: []*anypb.Any{any}}} + data, err := protojson.Marshal(e) + if err != nil { + t.Fatal(err) + } + hae := &googleapi.Error{ + Body: string(data), + } + + se := errors.New("standard error") + + tests := []struct { + source error + apierr *APIError + b bool + }{ + {hae, &APIError{httpErr: hae, details: ErrDetails{ErrorInfo: httpErrInfo}}, true}, + {se, &APIError{err: se}, false}, + } + + for _, tc := range tests { + // ParseError with wrap = true is covered by TestFromError, above. + got, apiB := ParseError(tc.source, false) + if tc.b != apiB { + t.Errorf("ParseError(%s, false): got %v, want %v", tc.apierr, apiB, tc.b) + } + if tc.b { + if diff := cmp.Diff(got.details, tc.apierr.details, cmp.Comparer(proto.Equal)); diff != "" { + t.Errorf("got(-), want(+),: \n%s", diff) + } + if diff := cmp.Diff(got.status, tc.apierr.status, cmp.Comparer(proto.Equal), cmp.AllowUnexported(status.Status{})); diff != "" { + t.Errorf("got(-), want(+),: \n%s", diff) + } + if got.err != nil { + t.Errorf("got %s, want nil", got.err) + } + } + } + if err, _ := ParseError(nil, false); err != nil { + t.Errorf("got %s, want nil", err) + } + + if c, _ := ParseError(context.DeadlineExceeded, false); c != nil { + t.Errorf("got %s, want nil", c) + } +} + func golden(name, got string) (string, error) { g := filepath.Join("testdata", name+".golden") if *update {