From e7c238caea47b14cea517ef237e6808f3a75887c Mon Sep 17 00:00:00 2001 From: Vinny Mannello <94396874+VinnyHC@users.noreply.github.com> Date: Mon, 14 Mar 2022 10:13:33 -0700 Subject: [PATCH] [VAULT-5003] Use net/http client in Sys().RaftSnapshotRestore (#14269) Use net/http client when body could be too big for retryablehttp client --- api/client.go | 154 ++++++++++++++++++++++--- api/sys_raft.go | 106 ++++------------- changelog/14269.txt | 3 + vault/external_tests/raft/raft_test.go | 33 +----- 4 files changed, 168 insertions(+), 128 deletions(-) create mode 100644 changelog/14269.txt diff --git a/api/client.go b/api/client.go index 6a804091a2194..3ce57a1de2e56 100644 --- a/api/client.go +++ b/api/client.go @@ -53,6 +53,14 @@ const ( HeaderIndex = "X-Vault-Index" HeaderForward = "X-Vault-Forward" HeaderInconsistent = "X-Vault-Inconsistent" + TLSErrorString = "This error usually means that the server is running with TLS disabled\n" + + "but the client is configured to use TLS. Please either enable TLS\n" + + "on the server or run the client with -address set to an address\n" + + "that uses the http protocol:\n\n" + + " vault -address http://
\n\n" + + "You can also set the VAULT_ADDR environment variable:\n\n\n" + + " VAULT_ADDR=http://
vault \n\n" + + "where
is replaced by the actual address to the server." ) // Deprecated values @@ -1127,12 +1135,9 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon limiter.Wait(ctx) } - // Sanity check the token before potentially erroring from the API - idx := strings.IndexFunc(token, func(c rune) bool { - return !unicode.IsPrint(c) - }) - if idx != -1 { - return nil, fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") + // check the token before potentially erroring from the API + if err := validateToken(token); err != nil { + return nil, err } redirectCount := 0 @@ -1192,17 +1197,7 @@ START: } if err != nil { if strings.Contains(err.Error(), "tls: oversized") { - err = errwrap.Wrapf( - "{{err}}\n\n"+ - "This error usually means that the server is running with TLS disabled\n"+ - "but the client is configured to use TLS. Please either enable TLS\n"+ - "on the server or run the client with -address set to an address\n"+ - "that uses the http protocol:\n\n"+ - " vault -address http://
\n\n"+ - "You can also set the VAULT_ADDR environment variable:\n\n\n"+ - " VAULT_ADDR=http://
vault \n\n"+ - "where
is replaced by the actual address to the server.", - err) + err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) } return result, err } @@ -1249,6 +1244,120 @@ START: return result, nil } +// httpRequestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is +// useful when making calls where a net/http client is desirable. A single redirect (status code 301, 302, +// or 307) will be followed but all retry and timeout logic is the responsibility of the caller as is +// closing the Response body. +func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body) + if err != nil { + return nil, err + } + + c.modifyLock.RLock() + token := c.token + + c.config.modifyLock.RLock() + limiter := c.config.Limiter + httpClient := c.config.HttpClient + outputCurlString := c.config.OutputCurlString + if c.headers != nil { + for header, vals := range c.headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + c.config.modifyLock.RUnlock() + c.modifyLock.RUnlock() + + // OutputCurlString logic relies on the request type to be retryable.Request as + if outputCurlString { + return nil, fmt.Errorf("output-curl-string is not implemented for this request") + } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + if limiter != nil { + limiter.Wait(ctx) + } + + // check the token before potentially erroring from the API + if err := validateToken(token); err != nil { + return nil, err + } + + var result *Response + + resp, err := httpClient.Do(req) + + if resp != nil { + result = &Response{Response: resp} + } + + if err != nil { + if strings.Contains(err.Error(), "tls: oversized") { + err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) + } + return result, err + } + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return result, fmt.Errorf("redirect failed: %s", err) + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return result, fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Reset the request body if any + if err := r.ResetJSONBody(); err != nil { + return result, fmt.Errorf("redirect failed: %s", err) + } + + // Retry the request + resp, err = httpClient.Do(req) + if err != nil { + return result, fmt.Errorf("redirect failed: %s", err) + } + } + + if err := result.Error(); err != nil { + return nil, err + } + + return result, nil +} + type ( RequestCallback func(*Request) ResponseCallback func(*Response) @@ -1466,3 +1575,14 @@ func (w *replicationStateStore) states() []string { copy(c, w.store) return c } + +// validateToken will check for non-printable characters to prevent a call that will fail at the api +func validateToken(t string) error { + idx := strings.IndexFunc(t, func(c rune) bool { + return !unicode.IsPrint(c) + }) + if idx != -1 { + return fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") + } + return nil +} diff --git a/api/sys_raft.go b/api/sys_raft.go index cbf3a2020038d..7dc10959ac32e 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "io/ioutil" "net/http" @@ -14,7 +13,6 @@ import ( "time" "github.com/hashicorp/go-secure-stdlib/parseutil" - "github.com/hashicorp/vault/sdk/helper/consts" "github.com/mitchellh/mapstructure" ) @@ -132,87 +130,25 @@ func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) { return &result, err } -// RaftSnapshot invokes the API that takes the snapshot of the raft cluster and -// writes it to the supplied io.Writer. +// RaftSnapshot is a thin wrapper around RaftSnapshotWithContext func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { - r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") - r.URL.RawQuery = r.Params.Encode() - - req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil) - if err != nil { - return err - } - - req.URL.User = r.URL.User - req.URL.Scheme = r.URL.Scheme - req.URL.Host = r.URL.Host - req.Host = r.URL.Host - - if r.Headers != nil { - for header, vals := range r.Headers { - for _, val := range vals { - req.Header.Add(header, val) - } - } - } - - if len(r.ClientToken) != 0 { - req.Header.Set(consts.AuthHeaderName, r.ClientToken) - } - - if len(r.WrapTTL) != 0 { - req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) - } + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() - if len(r.MFAHeaderVals) != 0 { - for _, mfaHeaderVal := range r.MFAHeaderVals { - req.Header.Add("X-Vault-MFA", mfaHeaderVal) - } - } + return c.RaftSnapshotWithContext(ctx, snapWriter) +} - if r.PolicyOverride { - req.Header.Set("X-Vault-Policy-Override", "true") - } +// RaftSnapshotWithContext invokes the API that takes the snapshot of the raft cluster and +// writes it to the supplied io.Writer. +func (c *Sys) RaftSnapshotWithContext(ctx context.Context, snapWriter io.Writer) error { + r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") + r.URL.RawQuery = r.Params.Encode() - // Avoiding the use of RawRequestWithContext which reads the response body - // to determine if the body contains error message. - var result *Response - resp, err := c.c.config.HttpClient.Do(req) + resp, err := c.c.httpRequestWithContext(ctx, r) if err != nil { return err } - - if resp == nil { - return nil - } - - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { - // Parse the updated location - respLoc, err := resp.Location() - if err != nil { - return err - } - - // Ensure a protocol downgrade doesn't happen - if req.URL.Scheme == "https" && respLoc.Scheme != "https" { - return fmt.Errorf("redirect would cause protocol downgrade") - } - - // Update the request - req.URL = respLoc - - // Retry the request - resp, err = c.c.config.HttpClient.Do(req) - if err != nil { - return err - } - } - - result = &Response{Response: resp} - if err := result.Error(); err != nil { - return err - } + defer resp.Body.Close() // Make sure that the last file in the archive, SHA256SUMS.sealed, is present // and non-empty. This is to catch cases where the snapshot failed midstream, @@ -271,20 +207,26 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { return nil } -// RaftSnapshotRestore reads the snapshot from the io.Reader and installs that -// snapshot, returning the cluster to the state defined by it. +// RaftSnapshotRestore is a thin wrapper around RaftSnapshotRestoreWithContext func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + return c.RaftSnapshotRestoreWithContext(ctx, snapReader, force) +} + +// RaftSnapshotRestoreWithContext reads the snapshot from the io.Reader and installs that +// snapshot, returning the cluster to the state defined by it. +func (c *Sys) RaftSnapshotRestoreWithContext(ctx context.Context, snapReader io.Reader, force bool) error { path := "/v1/sys/storage/raft/snapshot" if force { path = "/v1/sys/storage/raft/snapshot-force" } - r := c.c.NewRequest("POST", path) + r := c.c.NewRequest(http.MethodPost, path) r.Body = snapReader - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - resp, err := c.c.RawRequestWithContext(ctx, r) + resp, err := c.c.httpRequestWithContext(ctx, r) if err != nil { return err } diff --git a/changelog/14269.txt b/changelog/14269.txt new file mode 100644 index 0000000000000..529b7c6264299 --- /dev/null +++ b/changelog/14269.txt @@ -0,0 +1,3 @@ +```release-note:bug + api/sys/raft: Update RaftSnapshotRestore to use net/http client allowing bodies larger than allocated memory to be streamed +``` diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go index eeba82c7df7d7..947710b3ed1da 100644 --- a/vault/external_tests/raft/raft_test.go +++ b/vault/external_tests/raft/raft_test.go @@ -489,28 +489,10 @@ func TestRaft_SnapshotAPI(t *testing.T) { } } - transport := cleanhttp.DefaultPooledTransport() - transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone() - if err := http2.ConfigureTransport(transport); err != nil { - t.Fatal(err) - } - client := &http.Client{ - Transport: transport, - } - // Take a snapshot - req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot") - httpReq, err := req.ToHTTP() - if err != nil { - t.Fatal(err) - } - resp, err := client.Do(httpReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - snap, err := ioutil.ReadAll(resp.Body) + buf := new(bytes.Buffer) + err := leaderClient.Sys().RaftSnapshot(buf) + snap, err := io.ReadAll(buf) if err != nil { t.Fatal(err) } @@ -527,15 +509,8 @@ func TestRaft_SnapshotAPI(t *testing.T) { t.Fatal(err) } } - // Restore snapshot - req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot") - req.Body = bytes.NewBuffer(snap) - httpReq, err = req.ToHTTP() - if err != nil { - t.Fatal(err) - } - resp, err = client.Do(httpReq) + err = leaderClient.Sys().RaftSnapshotRestore(bytes.NewReader(snap), false) if err != nil { t.Fatal(err) }