-
Notifications
You must be signed in to change notification settings - Fork 2
/
grpcerr_generated_test.go.tmpl
124 lines (103 loc) · 2.26 KB
/
grpcerr_generated_test.go.tmpl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Code generated by "github.com/hedhyw/semerr"; DO NOT EDIT.
package grpcerr_test
import (
"errors"
"fmt"
"testing"
"github.com/hedhyw/semerr/pkg/v1/grpcerr"
"github.com/hedhyw/semerr/pkg/v1/semerr"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestCode checks that grpcerr.Code returns the expected gRPC status code for
// every error in the table. Each error is checked as given and again after it
// has been wrapped twice with fmt.Errorf's %w verb.
func TestCode(t *testing.T) {
	t.Parallel()

	const err = semerr.Error("some error")

	testCases := []struct {
		Err  error
		Code codes.Code
	}{
		{
			Err:  nil,
			Code: codes.OK,
		},
		{
			Err:  err,
			Code: codes.Unknown,
		},
		{
			// An error built by the grpc status package is expected to keep its own code.
			Err:  status.Error(codes.AlreadyExists, "already found"),
			Code: codes.AlreadyExists,
		},
		{{- range $errorDef := . }}
		{
			Err:  semerr.New{{ $errorDef.Name }}(err),
			Code: {{ $errorDef.GRPCStatus }},
		},
		{{- end }}
	}

	for _, tc := range testCases {
		tc := tc

		// Name subtests by the error's dynamic type and the expected code.
		// fmt.Sprint(tc.Err) is not distinctive enough here: every generated
		// case wraps the same "some error", so messages alone may repeat.
		t.Run(fmt.Sprintf("%T_%v", tc.Err, tc.Code), func(t *testing.T) {
			t.Parallel()

			err := tc.Err

			if gotCode := grpcerr.Code(err); gotCode != tc.Code {
				t.Fatal("exp", tc.Code, "got", gotCode)
			}

			if err == nil {
				// fmt.Errorf with a nil %w operand still returns a non-nil
				// error, so the nil case is only checked unwrapped.
				return
			}

			wrapped := fmt.Errorf("wrapped: 2: %w", fmt.Errorf("wrapped: 1: %w", err))
			if gotCode := grpcerr.Code(wrapped); gotCode != tc.Code {
				t.Fatalf("after wrapping: exp %v got %v", tc.Code, gotCode)
			}
		})
	}
}
// TestWrap checks that grpcerr.Wrap turns an error plus a gRPC status code
// into the semerr error type generated for that code. For a code with no
// mapping, the test expects the original error back unchanged.
func TestWrap(t *testing.T) {
	t.Parallel()

	const err = semerr.Error("some error")

	testCases := []struct {
		Code  codes.Code
		Check func(err error) bool
	}{
		{
			// 100 is outside the range of codes defined by gRPC (0-16), so
			// the error is expected to be passed through as-is.
			Check: func(actualErr error) bool {
				return err == actualErr
			},
			Code: 100,
		},
		{{- range $errorDef := . }}
		{{- if $errorDef.Reverse }}
		{
			Check: func(err error) bool {
				return errors.As(err, &semerr.{{ $errorDef.Name }}{})
			},
			Code: {{ $errorDef.GRPCStatus }},
		},
		{{- end }}
		{{- end }}
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(fmt.Sprint(tc.Code), func(t *testing.T) {
			t.Parallel()

			got := grpcerr.Wrap(err, tc.Code)
			if !tc.Check(got) {
				// Report the input as well as the unexpected result; the type
				// alone does not say which case or message was involved.
				t.Fatalf("Wrap(%q, %v) returned unexpected error %T: %v", err, tc.Code, got, got)
			}
		})
	}
}
// TestJoin checks grpcerr.Code on an errors.Join of three errors: a plain one,
// a bad-request one, and a not-found one listed after it. The expected result
// is InvalidArgument, the code of the bad-request error, even though a
// not-found error is also present in the join.
func TestJoin(t *testing.T) {
	t.Parallel()

	const err = semerr.Error("some error")

	joined := errors.Join(
		fmt.Errorf("regular: %w", err),
		fmt.Errorf("bad request: %w", semerr.NewBadRequestError(err)),
		semerr.NewNotFoundError(fmt.Errorf("not found: %w", err)),
	)

	const wantCode = codes.InvalidArgument

	if gotCode := grpcerr.Code(joined); gotCode != wantCode {
		t.Fatal("exp", wantCode, "got", gotCode)
	}
}