From 6c9094b5d3333361dafca2a93e3c922e7fd29050 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Thu, 24 Feb 2022 23:47:57 +0000 Subject: [PATCH 01/13] backport of commit 8023e4c2fa72d6bcc80682cb5e5e35f8b692e4e1 --- api/sys_raft.go | 81 +++++++++++++++++++++++++++++++++++++++++---- changelog/14269.txt | 3 ++ 2 files changed, 78 insertions(+), 6 deletions(-) create mode 100644 changelog/14269.txt diff --git a/api/sys_raft.go b/api/sys_raft.go index faa62eb3e0f93..c0ee949b7846e 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -272,24 +272,93 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { } // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that -// snapshot, returning the cluster to the state defined by it. +// snapshot, returning the cluster to the state defined by it. This avoids the use of +// RawRequestWithContext which copies the body (leading to possible OOMs) for retrying func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { path := "/v1/sys/storage/raft/snapshot" if force { path = "/v1/sys/storage/raft/snapshot-force" } - r := c.c.NewRequest("POST", path) - r.Body = snapReader + r := c.c.NewRequest(http.MethodPost, path) + r.URL.RawQuery = r.Params.Encode() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - resp, err := c.c.RawRequestWithContext(ctx, r) + req, err := http.NewRequest(http.MethodPost, r.URL.RequestURI(), snapReader) if err != nil { return err } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if r.Headers != nil { + for header, vals := range r.Headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + var result *Response + resp, err := c.c.config.HttpClient.Do(req) defer resp.Body.Close() + if err != nil { + return err + } + + if resp == nil { + return nil + } + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Retry the request + resp, err = c.c.config.HttpClient.Do(req) + if err != nil { + return err + } + } + + result = &Response{Response: resp} + if err := result.Error(); err != nil { + return err + } + return nil } diff --git a/changelog/14269.txt b/changelog/14269.txt new file mode 100644 index 0000000000000..529b7c6264299 --- /dev/null +++ b/changelog/14269.txt @@ -0,0 +1,3 @@ +```release-note:bug + api/sys/raft: Update RaftSnapshotRestore to use net/http client allowing bodies larger than allocated memory to be streamed +``` From 5073f050044a19104734182e86ab8703c111262f Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Fri, 25 Feb 2022 08:54:39 -0800 Subject: [PATCH 02/13] dry --- api/sys_raft.go | 238 ++++++++++++++++++------------------------------ 1 file changed, 90 insertions(+), 148 deletions(-) diff --git a/api/sys_raft.go b/api/sys_raft.go index c0ee949b7846e..11f83edcb60df 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -138,81 +138,10 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") r.URL.RawQuery = r.Params.Encode() - req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil) - if err != nil { - return err - } - - req.URL.User = r.URL.User - req.URL.Scheme = r.URL.Scheme - req.URL.Host = r.URL.Host - req.Host = r.URL.Host - - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - - if len(r.ClientToken) != 0 { - req.Header.Set(consts.AuthHeaderName, r.ClientToken) - } - - if len(r.WrapTTL) != 0 { - req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) - } - - if len(r.MFAHeaderVals) != 0 { - for _, mfaHeaderVal := range r.MFAHeaderVals { - req.Header.Add("X-Vault-MFA", mfaHeaderVal) - } - } - - if r.PolicyOverride { - req.Header.Set("X-Vault-Policy-Override", "true") - } - - // Avoiding the use of RawRequestWithContext which reads the response body - // to determine if the body contains error message. - var result *Response - resp, err := c.c.config.HttpClient.Do(req) - if err != nil { - return err - } - - if resp == nil { - return nil - } - - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { - // Parse the updated location - respLoc, err := resp.Location() - if err != nil { - return err - } - - // Ensure a protocol downgrade doesn't happen - if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return fmt.Errorf("redirect would cause protocol downgrade") - } - - // Update the request - req.URL = respLoc + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() - // Retry the request - resp, err = c.c.config.HttpClient.Do(req) - if err != nil { - return err - } - } - - result = &Response{Response: resp} - if err := result.Error(); err != nil { - return err - } + resp, err := c.requestWithContext(ctx, r, nil) // Make sure that the last file in the archive, SHA256SUMS.sealed, is present // and non-empty. This is to catch cases where the snapshot failed midstream, @@ -272,8 +201,7 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { } // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that -// snapshot, returning the cluster to the state defined by it. This avoids the use of -// RawRequestWithContext which copies the body (leading to possible OOMs) for retrying +// snapshot, returning the cluster to the state defined by it. func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { path := "/v1/sys/storage/raft/snapshot" if force { @@ -281,84 +209,15 @@ func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { } r := c.c.NewRequest(http.MethodPost, path) - r.URL.RawQuery = r.Params.Encode() - req, err := http.NewRequest(http.MethodPost, r.URL.RequestURI(), snapReader) - if err != nil { - return err - } - - req.URL.User = r.URL.User - req.URL.Scheme = r.URL.Scheme - req.URL.Host = r.URL.Host - req.Host = r.URL.Host - - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - - if len(r.ClientToken) != 0 { - req.Header.Set(consts.AuthHeaderName, r.ClientToken) - } - - if len(r.WrapTTL) != 0 { - req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) - } - - if len(r.MFAHeaderVals) != 0 { - for _, mfaHeaderVal := range r.MFAHeaderVals { - req.Header.Add("X-Vault-MFA", mfaHeaderVal) - } - } - - if r.PolicyOverride { - req.Header.Set("X-Vault-Policy-Override", "true") - } - - var result *Response - resp, err := c.c.config.HttpClient.Do(req) - defer resp.Body.Close() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + _, err := c.requestWithContext(ctx, r, snapReader) if err != nil { return err } - if resp == nil { - return nil - } - - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { - // Parse the updated location - respLoc, err := resp.Location() - if err != nil { - return err - } - - // Ensure a protocol downgrade doesn't happen - if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return fmt.Errorf("redirect would cause protocol downgrade") - } - - // Update the request - req.URL = respLoc - - // Retry the request - resp, err = c.c.config.HttpClient.Do(req) - if err != nil { - return err - } - } - - result = &Response{Response: resp} - if err := result.Error(); err != nil { - return err - } - return nil } @@ -437,3 +296,86 @@ func (c *Sys) RaftAutopilotConfiguration() (*AutopilotConfig, error) { return &result, err } + +// requestWithContext avoids the use of the go-retryable library and is useful when +// making requests where the req or resp body is likely larger than available memory +func (c *Sys) requestWithContext(ctx context.Context, r *Request, body io.Reader) (*Response, error) { + req, err := http.NewRequest(r.Method, r.URL.RequestURI(), body) + if err != nil { + return nil, err + } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if r.Headers != nil { + for header, vals := range r.Headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + var result *Response + + resp, err := c.c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + + if resp == nil { + return nil, err + } + + defer resp.Body.Close() + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return nil, err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return nil, fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Retry the request + resp, err = c.c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + } + + result = &Response{Response: resp} + if err := result.Error(); err != nil { + return nil, err + } + + return result, nil +} From d680605483e2ce025befd9fed38ee7c57e57c22f Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Fri, 25 Feb 2022 09:56:46 -0800 Subject: [PATCH 03/13] move to client --- api/client.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/api/client.go b/api/client.go index 1c890e01d4227..637ea72be0a8f 100644 --- a/api/client.go +++ b/api/client.go @@ -1068,6 +1068,83 @@ START: return result, nil } +// requestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is +// useful when making requests where the req or resp body is likely larger than available memory +func (c *Client) requestWithContext(ctx context.Context, r *Request) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body) + if err != nil { + return nil, err + } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if r.Headers != nil { + for header, vals := range r.Headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + var result *Response + + resp, err := c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return nil, err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return nil, fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Retry the request + resp, err = c.config.HttpClient.Do(req) + if err != nil { + return nil, err + } + } + + result = &Response{Response: resp} + if err := result.Error(); err != nil { + return nil, err + } + + return result, nil +} + type ( RequestCallback func(*Request) ResponseCallback func(*Response) From bea374ad7678691a4ebf7b31d8a07b0744e02156 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Fri, 25 Feb 2022 12:21:33 -0800 Subject: [PATCH 04/13] update tests --- vault/external_tests/raft/raft_test.go | 33 ++++---------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go index 0ab2a3032202c..0a7ee3fa8060a 100644 --- a/vault/external_tests/raft/raft_test.go +++ b/vault/external_tests/raft/raft_test.go @@ -489,28 +489,10 @@ func TestRaft_SnapshotAPI(t *testing.T) { } } - transport := cleanhttp.DefaultPooledTransport() - transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone() - if err := http2.ConfigureTransport(transport); err != nil { - t.Fatal(err) - } - client := &http.Client{ - Transport: transport, - } - // Take a snapshot - req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot") - httpReq, err := req.ToHTTP() - if err != nil { - t.Fatal(err) - } - resp, err := client.Do(httpReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - snap, err := ioutil.ReadAll(resp.Body) + buf := new(bytes.Buffer) + err := leaderClient.Sys().RaftSnapshot(buf) + snap, err := io.ReadAll(buf) if err != nil { t.Fatal(err) } @@ -527,15 +509,8 @@ func TestRaft_SnapshotAPI(t *testing.T) { t.Fatal(err) } } - // Restore snapshot - req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot") - req.Body = bytes.NewBuffer(snap) - httpReq, err = req.ToHTTP() - if err != nil { - t.Fatal(err) - } - resp, err = client.Do(httpReq) + err = leaderClient.Sys().RaftSnapshotRestore(bytes.NewReader(snap), false) if err != nil { t.Fatal(err) } From 092136ccebbe12e3feb417443483ad02a7b12032 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Tue, 1 Mar 2022 11:14:19 -0800 Subject: [PATCH 05/13] more robust client --- api/client.go | 96 ++++++++++++++++++++++++++++++++++++------------- api/sys_raft.go | 17 ++++----- 2 files changed, 78 insertions(+), 35 deletions(-) diff --git a/api/client.go b/api/client.go index 637ea72be0a8f..49339f0d1e159 100644 --- a/api/client.go +++ b/api/client.go @@ -42,6 +42,15 @@ const ( EnvVaultToken = "VAULT_TOKEN" EnvVaultMFA = "VAULT_MFA" EnvRateLimit = "VAULT_RATE_LIMIT" + HeaderIndex = "X-Vault-Index" + TLSErrorString = "This error usually means that the server is running with TLS disabled\n" + + "but the client is configured to use TLS. Please either enable TLS\n" + + "on the server or run the client with -address set to an address\n" + + "that uses the http protocol:\n\n" + + " vault -address http://
\n\n" + + "You can also set the VAULT_ADDR environment variable:\n\n\n" + + " VAULT_ADDR=http://
vault \n\n" + + "where
is replaced by the actual address to the server." ) // Deprecated values @@ -955,11 +964,8 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon } // Sanity check the token before potentially erroring from the API - idx := strings.IndexFunc(token, func(c rune) bool { - return !unicode.IsPrint(c) - }) - if idx != -1 { - return nil, fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") + if err := tokenSanityCheck(token); err != nil { + return nil, err } redirectCount := 0 @@ -1015,17 +1021,7 @@ START: } if err != nil { if strings.Contains(err.Error(), "tls: oversized") { - err = errwrap.Wrapf( - "{{err}}\n\n"+ - "This error usually means that the server is running with TLS disabled\n"+ - "but the client is configured to use TLS. Please either enable TLS\n"+ - "on the server or run the client with -address set to an address\n"+ - "that uses the http protocol:\n\n"+ - " vault -address http://
\n\n"+ - "You can also set the VAULT_ADDR environment variable:\n\n\n"+ - " VAULT_ADDR=http://
vault \n\n"+ - "where
is replaced by the actual address to the server.", - err) + err = fmt.Errorf("%s\n\n%s", err, TLSErrorString) } return result, err } @@ -1068,14 +1064,32 @@ START: return result, nil } -// requestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is -// useful when making requests where the req or resp body is likely larger than available memory -func (c *Client) requestWithContext(ctx context.Context, r *Request) (*Response, error) { +// httpRequestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is +// useful when making calls where a net/http client is desirable. A single redirect (status code 301, 302, +// or 307) will be followed but all retry and timeout logic is the responsibility of the caller as is +// closing the Response body. +func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Response, error) { req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body) if err != nil { return nil, err } + c.modifyLock.RLock() + token := c.token + + c.config.modifyLock.RLock() + limiter := c.config.Limiter + httpClient := c.config.HttpClient + outputCurlString := c.config.OutputCurlString + c.config.modifyLock.RUnlock() + + c.modifyLock.RUnlock() + + // OutputCurlString logic relies on the request type to be retryable.Request as + if outputCurlString { + return nil, fmt.Errorf("output-curl-string is not implemented for this request") + } + req.URL.User = r.URL.User req.URL.Scheme = r.URL.Scheme req.URL.Host = r.URL.Host @@ -1107,11 +1121,28 @@ func (c *Client) requestWithContext(ctx context.Context, r *Request) (*Response, req.Header.Set("X-Vault-Policy-Override", "true") } + if limiter != nil { + limiter.Wait(ctx) + } + + // Sanity check the token before potentially erroring from the API + if err := tokenSanityCheck(token); err != nil { + return nil, err + } + var result *Response - resp, err := c.config.HttpClient.Do(req) + resp, err := httpClient.Do(req) + + if resp != nil { + result = &Response{Response: resp} + } + if err != nil { - return nil, err + if strings.Contains(err.Error(), "tls: oversized") { + err = fmt.Errorf("%s\n\n%s", err, TLSErrorString) + } + return result, err } // Check for a redirect, only allowing for a single redirect @@ -1119,25 +1150,29 @@ func (c *Client) requestWithContext(ctx context.Context, r *Request) (*Response, // Parse the updated location respLoc, err := resp.Location() if err != nil { - return nil, err + return result, fmt.Errorf("failed to follow redirect") } // Ensure a protocol downgrade doesn't happen if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return nil, fmt.Errorf("redirect would cause protocol downgrade") + return result, fmt.Errorf("redirect would cause protocol downgrade") } // Update the request req.URL = respLoc + // Reset the request body if any + if err := r.ResetJSONBody(); err != nil { + return result, err + } + // Retry the request resp, err = c.config.HttpClient.Do(req) if err != nil { - return nil, err + return result, err } } - result = &Response{Response: resp} if err := result.Error(); err != nil { return nil, err } @@ -1226,3 +1261,14 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo } return false, nil } + +// tokenSanityCheck will check for non-printable characters to prevent a call that will fail at the api +func tokenSanityCheck(t string) error { + idx := strings.IndexFunc(t, func(c rune) bool { + return !unicode.IsPrint(c) + }) + if idx != -1 { + return fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") + } + return nil +} diff --git a/api/sys_raft.go b/api/sys_raft.go index 11f83edcb60df..e5d32c59d4c04 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -137,11 +137,11 @@ func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) { func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") r.URL.RawQuery = r.Params.Encode() - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - resp, err := c.requestWithContext(ctx, r, nil) + resp, err := c.c.httpRequestWithContext(context.Background(), r) + if err != nil { + return err + } + defer resp.Body.Close() // Make sure that the last file in the archive, SHA256SUMS.sealed, is present // and non-empty. This is to catch cases where the snapshot failed midstream, @@ -210,14 +210,11 @@ func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { r := c.c.NewRequest(http.MethodPost, path) - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - _, err := c.requestWithContext(ctx, r, snapReader) + resp, err := c.c.httpRequestWithContext(context.Background(), r) if err != nil { return err } - + defer resp.Body.Close() return nil } From 191fdcaea80e9525ddaf9c22b512853eac6dec5b Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Tue, 1 Mar 2022 11:28:26 -0800 Subject: [PATCH 06/13] ctx friendly --- api/client.go | 12 ++++++------ api/sys_raft.go | 30 ++++++++++++++++++++++++------ 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/api/client.go b/api/client.go index 49339f0d1e159..6bde7fbaf2e05 100644 --- a/api/client.go +++ b/api/client.go @@ -963,8 +963,8 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon limiter.Wait(ctx) } - // Sanity check the token before potentially erroring from the API - if err := tokenSanityCheck(token); err != nil { + // check the token before potentially erroring from the API + if err := tokenCheck(token); err != nil { return nil, err } @@ -1125,8 +1125,8 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo limiter.Wait(ctx) } - // Sanity check the token before potentially erroring from the API - if err := tokenSanityCheck(token); err != nil { + // check the token before potentially erroring from the API + if err := tokenCheck(token); err != nil { return nil, err } @@ -1262,8 +1262,8 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo return false, nil } -// tokenSanityCheck will check for non-printable characters to prevent a call that will fail at the api -func tokenSanityCheck(t string) error { +// tokenCheck will check for non-printable characters to prevent a call that will fail at the api +func tokenCheck(t string) error { idx := strings.IndexFunc(t, func(c rune) bool { return !unicode.IsPrint(c) }) diff --git a/api/sys_raft.go b/api/sys_raft.go index e5d32c59d4c04..329c5d70eb0f3 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -132,12 +132,21 @@ func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) { return &result, err } -// RaftSnapshot invokes the API that takes the snapshot of the raft cluster and -// writes it to the supplied io.Writer. +// RaftSnapshot is a thin wrapper around RaftSnapshotWithContext func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + return c.RaftSnapshotWithContext(ctx, snapWriter) +} + +// RaftSnapshotWithContext invokes the API that takes the snapshot of the raft cluster and +// writes it to the supplied io.Writer. +func (c *Sys) RaftSnapshotWithContext(ctx context.Context, snapWriter io.Writer) error { r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") r.URL.RawQuery = r.Params.Encode() - resp, err := c.c.httpRequestWithContext(context.Background(), r) + + resp, err := c.c.httpRequestWithContext(ctx, r) if err != nil { return err } @@ -200,17 +209,26 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { return nil } -// RaftSnapshotRestore reads the snapshot from the io.Reader and installs that -// snapshot, returning the cluster to the state defined by it. +// RaftSnapshotRestore is a thin wrapper around RaftSnapshotRestoreWithContext func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + return c.RaftSnapshotRestoreWithContext(ctx, snapReader, force) +} + +// RaftSnapshotRestoreWithContext reads the snapshot from the io.Reader and installs that +// snapshot, returning the cluster to the state defined by it. +func (c *Sys) RaftSnapshotRestoreWithContext(ctx context.Context, snapReader io.Reader, force bool) error { path := "/v1/sys/storage/raft/snapshot" if force { path = "/v1/sys/storage/raft/snapshot-force" } r := c.c.NewRequest(http.MethodPost, path) + r.Body = snapReader - resp, err := c.c.httpRequestWithContext(context.Background(), r) + resp, err := c.c.httpRequestWithContext(ctx, r) if err != nil { return err } From 8f452d9b664650dca877893d6ba8494035898230 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Tue, 1 Mar 2022 12:13:34 -0800 Subject: [PATCH 07/13] clean up --- api/client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/client.go b/api/client.go index 6bde7fbaf2e05..0d9c882155cc8 100644 --- a/api/client.go +++ b/api/client.go @@ -1150,7 +1150,7 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo // Parse the updated location respLoc, err := resp.Location() if err != nil { - return result, fmt.Errorf("failed to follow redirect") + return result, fmt.Errorf("redirect failed: %s", err) } // Ensure a protocol downgrade doesn't happen @@ -1163,13 +1163,13 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo // Reset the request body if any if err := r.ResetJSONBody(); err != nil { - return result, err + return result, fmt.Errorf("redirect failed: %s", err) } // Retry the request - resp, err = c.config.HttpClient.Do(req) + resp, err = httpClient.Do(req) if err != nil { - return result, err + return result, fmt.Errorf("redirect failed: %s", err) } } From 54f62f7bd1ff43461f4ce9554c890cd3145848ba Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Wed, 2 Mar 2022 10:47:47 -0800 Subject: [PATCH 08/13] PR feedback --- api/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/client.go b/api/client.go index 0d9c882155cc8..f5ed304015ab3 100644 --- a/api/client.go +++ b/api/client.go @@ -964,7 +964,7 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon } // check the token before potentially erroring from the API - if err := tokenCheck(token); err != nil { + if err := validateToken(token); err != nil { return nil, err } @@ -1126,7 +1126,7 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo } // check the token before potentially erroring from the API - if err := tokenCheck(token); err != nil { + if err := validateToken(token); err != nil { return nil, err } From 9ff6c4ee1bae73520509c37651af5e3cfa2d89a8 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Wed, 2 Mar 2022 10:55:59 -0800 Subject: [PATCH 09/13] revert errwrap --- api/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/client.go b/api/client.go index f5ed304015ab3..de5c737e32033 100644 --- a/api/client.go +++ b/api/client.go @@ -1021,7 +1021,7 @@ START: } if err != nil { if strings.Contains(err.Error(), "tls: oversized") { - err = fmt.Errorf("%s\n\n%s", err, TLSErrorString) + err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) } return result, err } From 2e235eae2deba6ea4a2e30861b8bf36c41c2ec7c Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Wed, 2 Mar 2022 13:34:19 -0800 Subject: [PATCH 10/13] errwrap errwhere --- api/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/client.go b/api/client.go index de5c737e32033..aceaf34c4da0b 100644 --- a/api/client.go +++ b/api/client.go @@ -1140,7 +1140,7 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo if err != nil { if strings.Contains(err.Error(), "tls: oversized") { - err = fmt.Errorf("%s\n\n%s", err, TLSErrorString) + err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) } return result, err } From 2af33e07bcc5f2d17cfdd2ed324310327d2395ee Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Mon, 14 Mar 2022 09:39:57 -0700 Subject: [PATCH 11/13] headers --- api/client.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/api/client.go b/api/client.go index aceaf34c4da0b..d7f1cae781228 100644 --- a/api/client.go +++ b/api/client.go @@ -1081,8 +1081,14 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo limiter := c.config.Limiter httpClient := c.config.HttpClient outputCurlString := c.config.OutputCurlString + if c.headers != nil { + for header, vals := range c.headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } c.config.modifyLock.RUnlock() - c.modifyLock.RUnlock() // OutputCurlString logic relies on the request type to be retryable.Request as @@ -1095,14 +1101,6 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo req.URL.Host = r.URL.Host req.Host = r.URL.Host - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - if len(r.ClientToken) != 0 { req.Header.Set(consts.AuthHeaderName, r.ClientToken) } From e3d86c212cf62c5e8dbaa99aafc6d72cb54d907a Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Fri, 1 Apr 2022 14:57:35 -0700 Subject: [PATCH 12/13] clean up --- api/client.go | 1 - api/sys_raft.go | 85 ------------------------------------------------- 2 files changed, 86 deletions(-) diff --git a/api/client.go b/api/client.go index d7f1cae781228..c0f42ad3dffa4 100644 --- a/api/client.go +++ b/api/client.go @@ -42,7 +42,6 @@ const ( EnvVaultToken = "VAULT_TOKEN" EnvVaultMFA = "VAULT_MFA" EnvRateLimit = "VAULT_RATE_LIMIT" - HeaderIndex = "X-Vault-Index" TLSErrorString = "This error usually means that the server is running with TLS disabled\n" + "but the client is configured to use TLS. Please either enable TLS\n" + "on the server or run the client with -address set to an address\n" + diff --git a/api/sys_raft.go b/api/sys_raft.go index 329c5d70eb0f3..52e5d9342d20d 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -6,14 +6,12 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "io/ioutil" "net/http" "sync" "time" - "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/parseutil" "github.com/mitchellh/mapstructure" ) @@ -311,86 +309,3 @@ func (c *Sys) RaftAutopilotConfiguration() (*AutopilotConfig, error) { return &result, err } - -// requestWithContext avoids the use of the go-retryable library and is useful when -// making requests where the req or resp body is likely larger than available memory -func (c *Sys) requestWithContext(ctx context.Context, r *Request, body io.Reader) (*Response, error) { - req, err := http.NewRequest(r.Method, r.URL.RequestURI(), body) - if err != nil { - return nil, err - } - - req.URL.User = r.URL.User - req.URL.Scheme = r.URL.Scheme - req.URL.Host = r.URL.Host - req.Host = r.URL.Host - - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - - if len(r.ClientToken) != 0 { - req.Header.Set(consts.AuthHeaderName, r.ClientToken) - } - - if len(r.WrapTTL) != 0 { - req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) - } - - if len(r.MFAHeaderVals) != 0 { - for _, mfaHeaderVal := range r.MFAHeaderVals { - req.Header.Add("X-Vault-MFA", mfaHeaderVal) - } - } - - if r.PolicyOverride { - req.Header.Set("X-Vault-Policy-Override", "true") - } - - var result *Response - - resp, err := c.c.config.HttpClient.Do(req) - if err != nil { - return nil, err - } - - if resp == nil { - return nil, err - } - - defer resp.Body.Close() - - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { - // Parse the updated location - respLoc, err := resp.Location() - if err != nil { - return nil, err - } - - // Ensure a protocol downgrade doesn't happen - if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return nil, fmt.Errorf("redirect would cause protocol downgrade") - } - - // Update the request - req.URL = respLoc - - // Retry the request - resp, err = c.c.config.HttpClient.Do(req) - if err != nil { - return nil, err - } - } - - result = &Response{Response: resp} - if err := result.Error(); err != nil { - return nil, err - } - - return result, nil -} From 1e4cafcb3145cf44029a6295a32e6211fd20b331 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Fri, 1 Apr 2022 15:39:04 -0700 Subject: [PATCH 13/13] clean up --- api/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/client.go b/api/client.go index c0f42ad3dffa4..ca8732cb7de15 100644 --- a/api/client.go +++ b/api/client.go @@ -1259,8 +1259,8 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo return false, nil } -// tokenCheck will check for non-printable characters to prevent a call that will fail at the api -func tokenCheck(t string) error { +// validateToken will check for non-printable characters to prevent a call that will fail at the api +func validateToken(t string) error { idx := strings.IndexFunc(t, func(c rune) bool { return !unicode.IsPrint(c) })