Skip to content

Commit

Permalink
Add WithRetryPolicyFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
rafiramadhana committed Nov 13, 2023
1 parent c32379b commit 27e6a34
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 6 deletions.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ e.POST("/path").
Expect().
Status(http.StatusOK)

// custom retry policy
// custom built-in retry policy
e.POST("/path").
WithMaxRetries(5).
WithRetryPolicy(httpexpect.RetryAllErrors).
Expand All @@ -404,6 +404,15 @@ e.POST("/path").
WithRetryDelay(time.Second, time.Minute).
Expect().
Status(http.StatusOK)

// custom user-defined retry policy
e.POST("/path").
WithMaxRetries(5).
WithRetryPolicyFunc(func(resp *http.Response, err error) bool {
return resp.StatusCode == http.StatusTeapot
}).
Expect().
Status(http.StatusOK)
```

##### Subdomains and per-request URL
Expand Down
73 changes: 68 additions & 5 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ type Request struct {
redirectPolicy RedirectPolicy
maxRedirects int

retryPolicy RetryPolicy
maxRetries int
minRetryDelay time.Duration
maxRetryDelay time.Duration
sleepFn func(d time.Duration) <-chan time.Time
retryPolicy RetryPolicy
withRetryPolicyCalled bool
maxRetries int
minRetryDelay time.Duration
maxRetryDelay time.Duration
sleepFn func(d time.Duration) <-chan time.Time
retryPolicyFn func(*http.Response, error) bool

timeout time.Duration

Expand Down Expand Up @@ -755,7 +757,64 @@ func (r *Request) WithRetryPolicy(policy RetryPolicy) *Request {
return r
}

if r.retryPolicyFn != nil {
opChain.fail(AssertionFailure{
Type: AssertUsage,
Errors: []error{
fmt.Errorf("expected: " +
"WithRetryPolicyFunc() and WithRetryPolicy() should be mutual exclusive, " +
"WithRetryPolicyFunc() is already called"),
},
})
return r
}

r.retryPolicy = policy
r.withRetryPolicyCalled = true

return r
}

// WithRetryPolicyFunc sets a function to replace built-in policies
// with user-defined policy.
//
// The function expects you to return true to perform a retry. And false to
// not perform a retry.
//
// Example:
//
// req := NewRequestC(config, "POST", "/path")
// req.WithRetryPolicyFunc(func(res *http.Response, err error) bool {
// return resp.StatusCode == http.StatusTeapot
// })
func (r *Request) WithRetryPolicyFunc(fn func(res *http.Response, err error) bool) *Request {

Check failure on line 790 in request.go

View workflow job for this annotation

GitHub Actions / Linters for root

line is 93 characters (lll)
opChain := r.chain.enter("WithRetryPolicyFunc()")
defer opChain.leave()

r.mu.Lock()
defer r.mu.Unlock()

if opChain.failed() {
return r
}

if !r.checkOrder(opChain, "WithRetryPolicyFunc()") {
return r
}

if r.withRetryPolicyCalled {
opChain.fail(AssertionFailure{
Type: AssertUsage,
Errors: []error{
fmt.Errorf("expected: " +
"WithRetryPolicyFunc() and WithRetryPolicy() should be mutual exclusive, " +
"WithRetryPolicy() is already called"),
},
})
return r
}

r.retryPolicyFn = fn

return r
}
Expand Down Expand Up @@ -2332,6 +2391,10 @@ func (r *Request) retryRequest(reqFunc func() (*http.Response, error)) (
}

func (r *Request) shouldRetry(resp *http.Response, err error) bool {
if r.retryPolicyFn != nil {
return r.retryPolicyFn(resp, err)
}

var (
isTemporaryNetworkError bool // Deprecated
isTimeoutError bool
Expand Down
109 changes: 109 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3349,6 +3349,60 @@ func TestRequest_RetriesCancellation(t *testing.T) {
assert.Equal(t, 1, callCount)
}

func TestRequest_WithRetryPolicyFunc(t *testing.T) {
tests := []struct {
name string
fn func(res *http.Response, err error) bool
callCount int
}{
{
name: "should not retry",
fn: func(res *http.Response, err error) bool {
return false
},
callCount: 1,
},
{
name: "should retry",
fn: func(res *http.Response, err error) bool {
return true
},
callCount: 2,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
callCount := 0

client := &mockClient{
resp: http.Response{
StatusCode: http.StatusTeapot,
},
cb: func(req *http.Request) {
callCount++
},
}

cfg := Config{
Client: client,
Reporter: newMockReporter(t),
}

req := NewRequestC(cfg, http.MethodGet, "/url").
WithMaxRetries(1).
WithRetryDelay(0, 0).
WithRetryPolicyFunc(tt.fn)
req.chain.assert(t, success)

resp := req.Expect()
resp.chain.assert(t, success)

assert.Equal(t, tt.callCount, callCount)
})
}
}

func TestRequest_Conflicts(t *testing.T) {
client := &mockClient{}

Expand Down Expand Up @@ -3492,6 +3546,44 @@ func TestRequest_Conflicts(t *testing.T) {
})
}
})

t.Run("retry policy conflict", func(t *testing.T) {
cases := []struct {
name string
fn func(req *Request)
}{
{
"WithRetryPolicyFunc",
func(req *Request) {
req.WithRetryPolicyFunc(func(res *http.Response, err error) bool {
return res.StatusCode == http.StatusTeapot
})
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := NewRequestC(config, "GET", "url")

tc.fn(req)
req.chain.assert(t, success)

req.WithRetryPolicy(RetryAllErrors)
req.chain.assert(t, failure)
})

t.Run(tc.name+" - reversed", func(t *testing.T) {
req := NewRequestC(config, "GET", "url")

req.WithRetryPolicy(RetryAllErrors)
req.chain.assert(t, success)

tc.fn(req)
req.chain.assert(t, failure)
})
}
})
}

func TestRequest_Usage(t *testing.T) {
Expand Down Expand Up @@ -3642,6 +3734,15 @@ func TestRequest_Usage(t *testing.T) {
prepFails: false,
expectFails: true,
},
{
name: "WithRetryPolicyFunc - nil argument",
client: &mockClient{},
prepFunc: func(req *Request) {
req.WithRetryPolicyFunc(nil)
},
prepFails: false,
expectFails: false,
},
}

for _, tc := range cases {
Expand Down Expand Up @@ -3934,6 +4035,14 @@ func TestRequest_Order(t *testing.T) {
req.WithMultipart()
},
},
{
name: "WithRetryPolicyFunc after Expect",
afterFunc: func(req *Request) {
req.WithRetryPolicyFunc(func(res *http.Response, err error) bool {
return res.StatusCode == http.StatusTeapot
})
},
},
}

for _, tc := range cases {
Expand Down

0 comments on commit 27e6a34

Please sign in to comment.