diff --git a/api/client.go b/api/client.go index 6a804091a2194..de6b58bfb3e5b 100644 --- a/api/client.go +++ b/api/client.go @@ -1249,6 +1249,83 @@ START: return result, nil } +// requestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is +// useful when making requests where the req or resp body is likely larger than available memory +func (c *Client) requestWithContext(ctx context.Context, r *Request) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body) + if err != nil { + return nil, err + } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if r.Headers != nil { + for header, vals := range r.Headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + var result *Response + + resp, err := c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return nil, err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return nil, fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Retry the request + resp, err = c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + } + + result = &Response{Response: resp} + if err := result.Error(); err != nil { + return nil, err + } + + return result, nil +} + type ( RequestCallback func(*Request) ResponseCallback func(*Response) diff --git a/api/sys_raft.go b/api/sys_raft.go index 89c49d8082f80..8a39961457e63 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "io/ioutil" "net/http" @@ -14,7 +13,6 @@ import ( "time" "github.com/hashicorp/go-secure-stdlib/parseutil" - "github.com/hashicorp/vault/sdk/helper/consts" "github.com/mitchellh/mapstructure" ) @@ -138,10 +136,11 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") r.URL.RawQuery = r.Params.Encode() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - resp, err := c.requestWithContext(ctx, r, nil) + resp, err := c.c.requestWithContext(context.Background(), r) + if err != nil { + return err + } + defer resp.Body.Close() // Make sure that the last file in the archive, SHA256SUMS.sealed, is present // and non-empty. This is to catch cases where the snapshot failed midstream, @@ -209,14 +208,13 @@ func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { } r := c.c.NewRequest(http.MethodPost, path) + r.Body = snapReader - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - _, err := c.requestWithContext(ctx, r, snapReader) + resp, err := c.c.requestWithContext(context.Background(), r) if err != nil { return err } + defer resp.Body.Close() return nil } @@ -315,86 +313,3 @@ func (c *Sys) PutRaftAutopilotConfiguration(opts *AutopilotConfig) error { return nil } - -// requestWithContext avoids the use of the go-retryable library and is useful when -// making requests where the req or resp body is likely larger than available memory -func (c *Sys) requestWithContext(ctx context.Context, r *Request, body io.Reader) (*Response, error) { - req, err := http.NewRequest(r.Method, r.URL.RequestURI(), body) - if err != nil { - return nil, err - } - - req.URL.User = r.URL.User - req.URL.Scheme = r.URL.Scheme - req.URL.Host = r.URL.Host - req.Host = r.URL.Host - - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - - if len(r.ClientToken) != 0 { - req.Header.Set(consts.AuthHeaderName, r.ClientToken) - } - - if len(r.WrapTTL) != 0 { - req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) - } - - if len(r.MFAHeaderVals) != 0 { - for _, mfaHeaderVal := range r.MFAHeaderVals { - req.Header.Add("X-Vault-MFA", mfaHeaderVal) - } - } - - if r.PolicyOverride { - req.Header.Set("X-Vault-Policy-Override", "true") - } - - var result *Response - - resp, err := c.c.config.HttpClient.Do(req) - if err != nil { - return nil, err - } - - if resp == nil { - return nil, err - } - - defer resp.Body.Close() - - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { - // Parse the updated location - respLoc, err := resp.Location() - if err != nil { - return nil, err - } - - // Ensure a protocol downgrade doesn't happen - if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return nil, fmt.Errorf("redirect would cause protocol downgrade") - } - - // Update the request - req.URL = respLoc - - // Retry the request - resp, err = c.c.config.HttpClient.Do(req) - if err != nil { - return nil, err - } - } - - result = &Response{Response: resp} - if err := result.Error(); err != nil { - return nil, err - } - - return result, nil -}