Skip to content

Commit

Permalink
move to client
Browse files Browse the repository at this point in the history
  • Loading branch information
VinnyHC committed Feb 25, 2022
1 parent 43da631 commit 780782e
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 93 deletions.
77 changes: 77 additions & 0 deletions api/client.go
Expand Up @@ -1249,6 +1249,83 @@ START:
return result, nil
}

// requestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is
// useful when making requests where the req or resp body is likely larger than available memory
func (c *Client) requestWithContext(ctx context.Context, r *Request) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body)
if err != nil {
return nil, err
}

req.URL.User = r.URL.User
req.URL.Scheme = r.URL.Scheme
req.URL.Host = r.URL.Host
req.Host = r.URL.Host

if r.Headers != nil {
for header, vals := range r.Headers {
for _, val := range vals {
req.Header.Add(header, val)
}
}
}

if len(r.ClientToken) != 0 {
req.Header.Set(consts.AuthHeaderName, r.ClientToken)
}

if len(r.WrapTTL) != 0 {
req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
}

if len(r.MFAHeaderVals) != 0 {
for _, mfaHeaderVal := range r.MFAHeaderVals {
req.Header.Add("X-Vault-MFA", mfaHeaderVal)
}
}

if r.PolicyOverride {
req.Header.Set("X-Vault-Policy-Override", "true")
}

var result *Response

resp, err := c.config.HttpClient.Do(req)
if err != nil {
return nil, err
}

// Check for a redirect, only allowing for a single redirect
if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 {
// Parse the updated location
respLoc, err := resp.Location()
if err != nil {
return nil, err
}

// Ensure a protocol downgrade doesn't happen
if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
return nil, fmt.Errorf("redirect would cause protocol downgrade")
}

// Update the request
req.URL = respLoc

// Retry the request
resp, err = c.config.HttpClient.Do(req)
if err != nil {
return nil, err
}
}

result = &Response{Response: resp}
if err := result.Error(); err != nil {
return nil, err
}

return result, nil
}

type (
RequestCallback func(*Request)
ResponseCallback func(*Response)
Expand Down
101 changes: 8 additions & 93 deletions api/sys_raft.go
Expand Up @@ -6,15 +6,13 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"sync"
"time"

"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/mitchellh/mapstructure"
)

Expand Down Expand Up @@ -138,10 +136,11 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error {
r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
r.URL.RawQuery = r.Params.Encode()

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()

resp, err := c.requestWithContext(ctx, r, nil)
resp, err := c.c.requestWithContext(context.Background(), r)
if err != nil {
return err
}
defer resp.Body.Close()

// Make sure that the last file in the archive, SHA256SUMS.sealed, is present
// and non-empty. This is to catch cases where the snapshot failed midstream,
Expand Down Expand Up @@ -209,14 +208,13 @@ func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error {
}

r := c.c.NewRequest(http.MethodPost, path)
r.Body = snapReader

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()

_, err := c.requestWithContext(ctx, r, snapReader)
resp, err := c.c.requestWithContext(context.Background(), r)
if err != nil {
return err
}
defer resp.Body.Close()

return nil
}
Expand Down Expand Up @@ -315,86 +313,3 @@ func (c *Sys) PutRaftAutopilotConfiguration(opts *AutopilotConfig) error {

return nil
}

// requestWithContext avoids the use of the go-retryable library and is useful when
// making requests where the req or resp body is likely larger than available memory
func (c *Sys) requestWithContext(ctx context.Context, r *Request, body io.Reader) (*Response, error) {
req, err := http.NewRequest(r.Method, r.URL.RequestURI(), body)
if err != nil {
return nil, err
}

req.URL.User = r.URL.User
req.URL.Scheme = r.URL.Scheme
req.URL.Host = r.URL.Host
req.Host = r.URL.Host

if r.Headers != nil {
for header, vals := range r.Headers {
for _, val := range vals {
req.Header.Add(header, val)
}
}
}

if len(r.ClientToken) != 0 {
req.Header.Set(consts.AuthHeaderName, r.ClientToken)
}

if len(r.WrapTTL) != 0 {
req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
}

if len(r.MFAHeaderVals) != 0 {
for _, mfaHeaderVal := range r.MFAHeaderVals {
req.Header.Add("X-Vault-MFA", mfaHeaderVal)
}
}

if r.PolicyOverride {
req.Header.Set("X-Vault-Policy-Override", "true")
}

var result *Response

resp, err := c.c.config.HttpClient.Do(req)
if err != nil {
return nil, err
}

if resp == nil {
return nil, err
}

defer resp.Body.Close()

// Check for a redirect, only allowing for a single redirect
if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 {
// Parse the updated location
respLoc, err := resp.Location()
if err != nil {
return nil, err
}

// Ensure a protocol downgrade doesn't happen
if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
return nil, fmt.Errorf("redirect would cause protocol downgrade")
}

// Update the request
req.URL = respLoc

// Retry the request
resp, err = c.c.config.HttpClient.Do(req)
if err != nil {
return nil, err
}
}

result = &Response{Response: resp}
if err := result.Error(); err != nil {
return nil, err
}

return result, nil
}

0 comments on commit 780782e

Please sign in to comment.