diff --git a/client_test.go b/client_test.go index 90092f4e3d..13da521033 100644 --- a/client_test.go +++ b/client_test.go @@ -1566,6 +1566,56 @@ func TestClientFollowRedirects(t *testing.T) { ReleaseResponse(resp) } + for i := 0; i < 10; i++ { + req := AcquireRequest() + resp := AcquireResponse() + + req.SetRequestURI("http://xxx/foo") + + req.SetTimeout(time.Second) + err := c.DoRedirects(req, resp, 16) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if statusCode := resp.StatusCode(); statusCode != StatusOK { + t.Fatalf("unexpected status code: %d", statusCode) + } + + if body := string(resp.Body()); body != "/bar" { + t.Fatalf("unexpected response %q. Expecting %q", body, "/bar") + } + + ReleaseRequest(req) + ReleaseResponse(resp) + } + + for i := 0; i < 10; i++ { + req := AcquireRequest() + resp := AcquireResponse() + + req.SetRequestURI("http://xxx/foo") + + testConn, _ := net.Dial("tcp", ln.Addr().String()) + timeoutConn := &Client{ + Dial: func(addr string) (net.Conn, error) { + return &readTimeoutConn{Conn: testConn, t: time.Second}, nil + }, + } + + req.SetTimeout(time.Millisecond) + err := timeoutConn.DoRedirects(req, resp, 16) + if err == nil { + t.Errorf("expecting error") + } + if err != ErrTimeout { + t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout) + } + + ReleaseRequest(req) + ReleaseResponse(resp) + } + req := AcquireRequest() resp := AcquireResponse() @@ -1613,6 +1663,7 @@ func TestClientDoTimeoutSuccess(t *testing.T) { defer s.Stop() testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) + testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) } func TestClientDoTimeoutSuccessConcurrent(t *testing.T) { @@ -1627,6 +1678,7 @@ func TestClientDoTimeoutSuccessConcurrent(t *testing.T) { go func() { defer wg.Done() testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) + testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) }() } wg.Wait() @@ -1687,6 +1739,7 @@ func TestClientDoTimeoutError(t *testing.T) { } testClientDoTimeoutError(t, c, 100) + testClientRequestSetTimeoutError(t, c, 100) } func TestClientDoTimeoutErrorConcurrent(t *testing.T) { @@ -1748,6 +1801,22 @@ func testClientGetTimeoutError(t *testing.T, c *Client, n int) { } } +func testClientRequestSetTimeoutError(t *testing.T, c *Client, n int) { + var req Request + var resp Response + req.SetRequestURI("http://foobar.com/baz") + for i := 0; i < n; i++ { + req.SetTimeout(time.Millisecond) + err := c.Do(&req, &resp) + if err == nil { + t.Errorf("expecting error") + } + if err != ErrTimeout { + t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout) + } + } +} + type readTimeoutConn struct { net.Conn t time.Duration @@ -2398,6 +2467,30 @@ func testClientDoTimeoutSuccess(t *testing.T, c *Client, addr string, n int) { } } +func testClientRequestSetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) { + var req Request + var resp Response + + for i := 0; i < n; i++ { + uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i) + req.SetRequestURI(uri) + req.SetTimeout(time.Second) + if err := c.Do(&req, &resp); err != nil { + t.Errorf("unexpected error: %v", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + resultURI := string(resp.Body()) + if strings.HasPrefix(uri, "https") { + resultURI = uri[:5] + resultURI[4:] + } + if resultURI != uri { + t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri) + } + } +} + func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) { var buf []byte for i := 0; i < n; i++ { diff --git a/http.go b/http.go index 4ef3744214..4b9e113a46 100644 --- a/http.go +++ b/http.go @@ -2290,3 +2290,11 @@ func round2(n int) int { return int(x + 1) } + +// SetTimeout sets timeout for the request. +// +// req.SetTimeout(t); c.Do(&req, &resp) is equivalent to +// c.DoTimeout(&req, &resp, t) +func (req *Request) SetTimeout(t time.Duration) { + req.timeout = t +}