Skip to content

Commit

Permalink
plugins/rest: Add response header timeout for REST client
Browse files Browse the repository at this point in the history
This commit adds a new configurable timeout to the Services
config to set the amount of time to wait for the server's
response headers. With this change, the client will no longer
wait indefinitely for the HTTP request to complete.

Signed-off-by: Ashutosh Narkar <anarkar4387@gmail.com>
(cherry picked from commit b48aba8)
  • Loading branch information
ashutosh-narkar authored and tsandall committed Aug 20, 2020
1 parent 1d15c5e commit a790b65
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/content/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The file can be either JSON or YAML format.
services:
acmecorp:
url: https://example.com/control-plane-api/v1
response_header_timeout_seconds: 5
credentials:
bearer:
token: "bGFza2RqZmxha3NkamZsa2Fqc2Rsa2ZqYWtsc2RqZmtramRmYWxkc2tm"
Expand Down Expand Up @@ -236,6 +237,7 @@ multiple services.
| --- | --- | --- | --- |
| `services[_].name` | `string` | Yes | Unique name for the service. Referred to by plugins. |
| `services[_].url` | `string` | Yes | Base URL to contact the service with. |
| `services[_].response_header_timeout_seconds` | `int64` | No (default: 10) | Amount of time to wait for a server's response headers after fully writing the request. This time does not include the time to read the response body. |
| `services[_].headers` | `object` | No | HTTP headers to include in requests to the service. |
| `services[_].allow_insecure_tls` | `bool` | No | Allow insecure TLS. |

Expand Down
26 changes: 20 additions & 6 deletions plugins/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (
"github.com/open-policy-agent/opa/util"
)

const (
defaultResponseHeaderTimeoutSeconds = int64(10)
)

// An HTTPAuthPlugin represents a mechanism to construct and configure HTTP authentication for a REST service
type HTTPAuthPlugin interface {
// implementations can assume NewClient will be called before Prepare
Expand All @@ -31,11 +35,12 @@ type HTTPAuthPlugin interface {

// Config represents configuration for a REST client.
type Config struct {
Name string `json:"name"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
AllowInsureTLS bool `json:"allow_insecure_tls,omitempty"`
Credentials struct {
Name string `json:"name"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
AllowInsureTLS bool `json:"allow_insecure_tls,omitempty"`
ResponseHeaderTimeoutSeconds *int64 `json:"response_header_timeout_seconds,omitempty"`
Credentials struct {
Bearer *bearerAuthPlugin `json:"bearer,omitempty"`
ClientTLS *clientTLSAuthPlugin `json:"client_tls,omitempty"`
S3Signing *awsSigningAuthPlugin `json:"s3_signing,omitempty"`
Expand Down Expand Up @@ -108,6 +113,12 @@ func New(config []byte, opts ...func(*Client)) (Client, error) {

parsedConfig.URL = strings.TrimRight(parsedConfig.URL, "/")

if parsedConfig.ResponseHeaderTimeoutSeconds == nil {
timeout := new(int64)
*timeout = defaultResponseHeaderTimeoutSeconds
parsedConfig.ResponseHeaderTimeoutSeconds = timeout
}

client := Client{
config: parsedConfig,
}
Expand Down Expand Up @@ -205,7 +216,10 @@ func (c Client) Do(ctx context.Context, method, path string) (*http.Response, er
req.Header.Add(key, value)
}

req = req.WithContext(ctx)
hCtx, cancel := context.WithCancel(ctx)
defer cancel()

req = req.WithContext(hCtx)

err = c.config.authPrepare(req)
if err != nil {
Expand Down
11 changes: 6 additions & 5 deletions plugins/rest/rest_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func defaultTLSConfig(c Config) (*tls.Config, error) {
}

// defaultRoundTripperClient is a reasonable set of defaults for HTTP auth plugins
func defaultRoundTripperClient(t *tls.Config) *http.Client {
func defaultRoundTripperClient(t *tls.Config, timeout int64) *http.Client {
// Ensure we use a http.Transport with proper settings: the zero values are not
// a good choice, as they cause leaking connections:
// https://github.com/golang/go/issues/19620
Expand All @@ -58,6 +58,7 @@ func defaultRoundTripperClient(t *tls.Config) *http.Client {
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: time.Duration(timeout) * time.Second,
TLSClientConfig: t,
}

Expand All @@ -75,7 +76,7 @@ func (ap *defaultAuthPlugin) NewClient(c Config) (*http.Client, error) {
if err != nil {
return nil, err
}
return defaultRoundTripperClient(t), nil
return defaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil
}

func (ap *defaultAuthPlugin) Prepare(req *http.Request) error {
Expand Down Expand Up @@ -103,7 +104,7 @@ func (ap *bearerAuthPlugin) NewClient(c Config) (*http.Client, error) {
ap.Scheme = "Bearer"
}

return defaultRoundTripperClient(t), nil
return defaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil
}

func (ap *bearerAuthPlugin) Prepare(req *http.Request) error {
Expand Down Expand Up @@ -192,7 +193,7 @@ func (ap *clientTLSAuthPlugin) NewClient(c Config) (*http.Client, error) {
}

tlsConfig.Certificates = []tls.Certificate{cert}
client := defaultRoundTripperClient(tlsConfig)
client := defaultRoundTripperClient(tlsConfig, *c.ResponseHeaderTimeoutSeconds)
return client, nil
}

Expand Down Expand Up @@ -227,7 +228,7 @@ func (ap *awsSigningAuthPlugin) NewClient(c Config) (*http.Client, error) {
return nil, errors.New("at least aws_region must be specified for AWS metadata credential service")
}
}
return defaultRoundTripperClient(t), nil
return defaultRoundTripperClient(t, *c.ResponseHeaderTimeoutSeconds), nil
}

func (ap *awsSigningAuthPlugin) Prepare(req *http.Request) error {
Expand Down
79 changes: 79 additions & 0 deletions plugins/rest/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,81 @@ func TestNew(t *testing.T) {
if err != nil && !tc.wantErr {
t.Fatalf("Unexpected error: %v", err)
}

if *client.config.ResponseHeaderTimeoutSeconds != defaultResponseHeaderTimeoutSeconds {
t.Fatalf("Expected default response header timeout but got %v seconds", *client.config.ResponseHeaderTimeoutSeconds)
}

results = append(results, client)
}

if results[3].config.Credentials.Bearer.Scheme != "Acmecorp-Token" {
t.Fatalf("Expected custom token but got: %v", results[3].config.Credentials.Bearer.Scheme)
}
}

func TestNewWithResponseHeaderTimeout(t *testing.T) {
input := `{
"name": "foo",
"url": "http://localhost",
"response_header_timeout_seconds": 20
}`

client, err := New([]byte(input))
if err != nil {
t.Fatal("Unexpected error")
}

if *client.config.ResponseHeaderTimeoutSeconds != 20 {
t.Fatalf("Expected response header timeout %v seconds but got %v seconds", 20, *client.config.ResponseHeaderTimeoutSeconds)
}
}

func TestDoWithResponseHeaderTimeout(t *testing.T) {
ctx := context.Background()

tests := map[string]struct {
d time.Duration
responseHeaderTimeout string
wantErr bool
errMsg string
}{
"response_headers_timeout_not_met": {1, "2", false, ""},
"response_headers_timeout_met": {2, "1", true, "net/http: timeout awaiting response headers"},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {

baseURL, teardown := getTestServerWithTimeout(tc.d)
defer teardown()

config := fmt.Sprintf(`{
"name": "foo",
"url": %q,
"response_header_timeout_seconds": %v,
}`, baseURL, tc.responseHeaderTimeout)
client, err := New([]byte(config))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

_, err = client.Do(ctx, "GET", "/v1/test")
if tc.wantErr {
if err == nil {
t.Fatal("Expected error but got nil")
}

if !strings.Contains(err.Error(), tc.errMsg) {
t.Fatalf("Expected error %v but got %v", tc.errMsg, err.Error())
}
} else {
if err != nil {
t.Fatalf("Unexpected error %v", err)
}
}
})
}
}

func TestValidUrl(t *testing.T) {
Expand Down Expand Up @@ -612,3 +680,14 @@ func createCert(template, parent *x509.Certificate, pub interface{}, parentPriv
certPEM = pem.EncodeToMemory(&b)
return
}

func getTestServerWithTimeout(d time.Duration) (baseURL string, teardownFn func()) {
mux := http.NewServeMux()
ts := httptest.NewServer(mux)

mux.HandleFunc("/v1/test", func(w http.ResponseWriter, req *http.Request) {
time.Sleep(d * time.Second)
w.WriteHeader(http.StatusOK)
})
return ts.URL, ts.Close
}

0 comments on commit a790b65

Please sign in to comment.