From fc08616ebbce914bbe6bb7bb694e6b58901e4b49 Mon Sep 17 00:00:00 2001 From: VinnyHC Date: Thu, 24 Feb 2022 15:47:57 -0800 Subject: [PATCH] use net/http client when body could be too big for retryablehttp client --- api/client.go | 1 - api/sys_raft.go | 81 +++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/api/client.go b/api/client.go index 6a804091a2194..741dceae0f986 100644 --- a/api/client.go +++ b/api/client.go @@ -20,7 +20,6 @@ import ( "unicode" "github.com/hashicorp/errwrap" - "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/go-rootcerts" "github.com/hashicorp/go-secure-stdlib/parseutil" diff --git a/api/sys_raft.go b/api/sys_raft.go index cbf3a2020038d..685717401dd32 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -272,24 +272,93 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { } // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that -// snapshot, returning the cluster to the state defined by it. +// snapshot, returning the cluster to the state defined by it. This avoids the use of +// RawRequestWithContext which copies the body (leading to possible OOMs) for retrying func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { path := "/v1/sys/storage/raft/snapshot" if force { path = "/v1/sys/storage/raft/snapshot-force" } - r := c.c.NewRequest("POST", path) - r.Body = snapReader + r := c.c.NewRequest(http.MethodPost, path) + r.URL.RawQuery = r.Params.Encode() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - resp, err := c.c.RawRequestWithContext(ctx, r) + req, err := http.NewRequest(http.MethodPost, r.URL.RequestURI(), snapReader) if err != nil { return err } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if r.Headers != nil { + for header, vals := range r.Headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + var result *Response + resp, err := c.c.config.HttpClient.Do(req) defer resp.Body.Close() + if err != nil { + return err + } + + if resp == nil { + return nil + } + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Retry the request + resp, err = c.c.config.HttpClient.Do(req) + if err != nil { + return err + } + } + + result = &Response{Response: resp} + if err := result.Error(); err != nil { + return err + } + return nil }