Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VAULT-5003] Use net/http client in Sys().RaftSnapshotRestore #14269

Merged
merged 12 commits into from Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
154 changes: 137 additions & 17 deletions api/client.go
Expand Up @@ -53,6 +53,14 @@ const (
HeaderIndex = "X-Vault-Index"
HeaderForward = "X-Vault-Forward"
HeaderInconsistent = "X-Vault-Inconsistent"
TLSErrorString = "This error usually means that the server is running with TLS disabled\n" +
"but the client is configured to use TLS. Please either enable TLS\n" +
"on the server or run the client with -address set to an address\n" +
"that uses the http protocol:\n\n" +
" vault <command> -address http://<address>\n\n" +
"You can also set the VAULT_ADDR environment variable:\n\n\n" +
" VAULT_ADDR=http://<address> vault <command>\n\n" +
"where <address> is replaced by the actual address to the server."
VinnyHC marked this conversation as resolved.
Show resolved Hide resolved
)

// Deprecated values
Expand Down Expand Up @@ -1127,12 +1135,9 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
limiter.Wait(ctx)
}

// Sanity check the token before potentially erroring from the API
idx := strings.IndexFunc(token, func(c rune) bool {
return !unicode.IsPrint(c)
})
if idx != -1 {
return nil, fmt.Errorf("configured Vault token contains non-printable characters and cannot be used")
// check the token before potentially erroring from the API
if err := validateToken(token); err != nil {
return nil, err
}

redirectCount := 0
Expand Down Expand Up @@ -1192,17 +1197,7 @@ START:
}
if err != nil {
if strings.Contains(err.Error(), "tls: oversized") {
err = errwrap.Wrapf(
"{{err}}\n\n"+
"This error usually means that the server is running with TLS disabled\n"+
"but the client is configured to use TLS. Please either enable TLS\n"+
"on the server or run the client with -address set to an address\n"+
"that uses the http protocol:\n\n"+
" vault <command> -address http://<address>\n\n"+
"You can also set the VAULT_ADDR environment variable:\n\n\n"+
" VAULT_ADDR=http://<address> vault <command>\n\n"+
"where <address> is replaced by the actual address to the server.",
err)
err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err)
}
return result, err
}
Expand Down Expand Up @@ -1249,6 +1244,120 @@ START:
return result, nil
}

// httpRequestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is
// useful when making calls where a net/http client is desirable. A single redirect (status code 301, 302,
// or 307) will be followed but all retry and timeout logic is the responsibility of the caller as is
// closing the Response body.
func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Response, error) {
req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body)
VinnyHC marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}

c.modifyLock.RLock()
token := c.token

c.config.modifyLock.RLock()
limiter := c.config.Limiter
httpClient := c.config.HttpClient
outputCurlString := c.config.OutputCurlString
if c.headers != nil {
for header, vals := range c.headers {
for _, val := range vals {
req.Header.Add(header, val)
}
}
}
c.config.modifyLock.RUnlock()
c.modifyLock.RUnlock()

// OutputCurlString logic relies on the request type to be retryable.Request as
if outputCurlString {
return nil, fmt.Errorf("output-curl-string is not implemented for this request")
}

req.URL.User = r.URL.User
req.URL.Scheme = r.URL.Scheme
req.URL.Host = r.URL.Host
req.Host = r.URL.Host

if len(r.ClientToken) != 0 {
req.Header.Set(consts.AuthHeaderName, r.ClientToken)
}

if len(r.WrapTTL) != 0 {
req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
}

if len(r.MFAHeaderVals) != 0 {
for _, mfaHeaderVal := range r.MFAHeaderVals {
req.Header.Add("X-Vault-MFA", mfaHeaderVal)
}
}

if r.PolicyOverride {
req.Header.Set("X-Vault-Policy-Override", "true")
}

if limiter != nil {
limiter.Wait(ctx)
}

// check the token before potentially erroring from the API
if err := validateToken(token); err != nil {
return nil, err
}

var result *Response

resp, err := httpClient.Do(req)

if resp != nil {
result = &Response{Response: resp}
}

if err != nil {
if strings.Contains(err.Error(), "tls: oversized") {
err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err)
}
return result, err
}

// Check for a redirect, only allowing for a single redirect
if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 {
// Parse the updated location
respLoc, err := resp.Location()
if err != nil {
VinnyHC marked this conversation as resolved.
Show resolved Hide resolved
return result, fmt.Errorf("redirect failed: %s", err)
}

// Ensure a protocol downgrade doesn't happen
if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
return result, fmt.Errorf("redirect would cause protocol downgrade")
}

// Update the request
req.URL = respLoc

// Reset the request body if any
if err := r.ResetJSONBody(); err != nil {
return result, fmt.Errorf("redirect failed: %s", err)
}

// Retry the request
resp, err = httpClient.Do(req)
if err != nil {
return result, fmt.Errorf("redirect failed: %s", err)
}
}

if err := result.Error(); err != nil {
return nil, err
}

return result, nil
}

type (
RequestCallback func(*Request)
ResponseCallback func(*Response)
Expand Down Expand Up @@ -1466,3 +1575,14 @@ func (w *replicationStateStore) states() []string {
copy(c, w.store)
return c
}

// validateToken will check for non-printable characters to prevent a call that will fail at the api
func validateToken(t string) error {
idx := strings.IndexFunc(t, func(c rune) bool {
return !unicode.IsPrint(c)
})
if idx != -1 {
return fmt.Errorf("configured Vault token contains non-printable characters and cannot be used")
}
return nil
}
106 changes: 24 additions & 82 deletions api/sys_raft.go
Expand Up @@ -6,15 +6,13 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"sync"
"time"

"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/mitchellh/mapstructure"
)

Expand Down Expand Up @@ -132,87 +130,25 @@ func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) {
return &result, err
}

// RaftSnapshot invokes the API that takes the snapshot of the raft cluster and
// writes it to the supplied io.Writer.
// RaftSnapshot is a thin wrapper around RaftSnapshotWithContext
func (c *Sys) RaftSnapshot(snapWriter io.Writer) error {
r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
r.URL.RawQuery = r.Params.Encode()

req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil)
if err != nil {
return err
}

req.URL.User = r.URL.User
req.URL.Scheme = r.URL.Scheme
req.URL.Host = r.URL.Host
req.Host = r.URL.Host

if r.Headers != nil {
for header, vals := range r.Headers {
for _, val := range vals {
req.Header.Add(header, val)
}
}
}

if len(r.ClientToken) != 0 {
req.Header.Set(consts.AuthHeaderName, r.ClientToken)
}

if len(r.WrapTTL) != 0 {
req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
}
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()

if len(r.MFAHeaderVals) != 0 {
for _, mfaHeaderVal := range r.MFAHeaderVals {
req.Header.Add("X-Vault-MFA", mfaHeaderVal)
}
}
return c.RaftSnapshotWithContext(ctx, snapWriter)
}

if r.PolicyOverride {
req.Header.Set("X-Vault-Policy-Override", "true")
}
// RaftSnapshotWithContext invokes the API that takes the snapshot of the raft cluster and
// writes it to the supplied io.Writer.
func (c *Sys) RaftSnapshotWithContext(ctx context.Context, snapWriter io.Writer) error {
r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
r.URL.RawQuery = r.Params.Encode()

// Avoiding the use of RawRequestWithContext which reads the response body
// to determine if the body contains error message.
var result *Response
resp, err := c.c.config.HttpClient.Do(req)
resp, err := c.c.httpRequestWithContext(ctx, r)
if err != nil {
return err
}

if resp == nil {
return nil
}

// Check for a redirect, only allowing for a single redirect
if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 {
// Parse the updated location
respLoc, err := resp.Location()
if err != nil {
return err
}

// Ensure a protocol downgrade doesn't happen
if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
return fmt.Errorf("redirect would cause protocol downgrade")
}

// Update the request
req.URL = respLoc

// Retry the request
resp, err = c.c.config.HttpClient.Do(req)
if err != nil {
return err
}
}

result = &Response{Response: resp}
if err := result.Error(); err != nil {
return err
}
defer resp.Body.Close()

// Make sure that the last file in the archive, SHA256SUMS.sealed, is present
// and non-empty. This is to catch cases where the snapshot failed midstream,
Expand Down Expand Up @@ -271,20 +207,26 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error {
return nil
}

// RaftSnapshotRestore reads the snapshot from the io.Reader and installs that
// snapshot, returning the cluster to the state defined by it.
// RaftSnapshotRestore is a thin wrapper around RaftSnapshotRestoreWithContext
func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

return c.RaftSnapshotRestoreWithContext(ctx, snapReader, force)
}

// RaftSnapshotRestoreWithContext reads the snapshot from the io.Reader and installs that
// snapshot, returning the cluster to the state defined by it.
func (c *Sys) RaftSnapshotRestoreWithContext(ctx context.Context, snapReader io.Reader, force bool) error {
path := "/v1/sys/storage/raft/snapshot"
if force {
path = "/v1/sys/storage/raft/snapshot-force"
}
r := c.c.NewRequest("POST", path)

r := c.c.NewRequest(http.MethodPost, path)
r.Body = snapReader

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
resp, err := c.c.httpRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand Down
3 changes: 3 additions & 0 deletions changelog/14269.txt
@@ -0,0 +1,3 @@
```release-note:bug
api/sys/raft: Update RaftSnapshotRestore to use net/http client allowing bodies larger than allocated memory to be streamed
```
33 changes: 4 additions & 29 deletions vault/external_tests/raft/raft_test.go
Expand Up @@ -489,28 +489,10 @@ func TestRaft_SnapshotAPI(t *testing.T) {
}
}

transport := cleanhttp.DefaultPooledTransport()
transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone()
if err := http2.ConfigureTransport(transport); err != nil {
t.Fatal(err)
}
client := &http.Client{
Transport: transport,
}

// Take a snapshot
req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
httpReq, err := req.ToHTTP()
if err != nil {
t.Fatal(err)
}
resp, err := client.Do(httpReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

snap, err := ioutil.ReadAll(resp.Body)
buf := new(bytes.Buffer)
err := leaderClient.Sys().RaftSnapshot(buf)
snap, err := io.ReadAll(buf)
VinnyHC marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatal(err)
}
Expand All @@ -527,15 +509,8 @@ func TestRaft_SnapshotAPI(t *testing.T) {
t.Fatal(err)
}
}

// Restore snapshot
req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot")
req.Body = bytes.NewBuffer(snap)
httpReq, err = req.ToHTTP()
if err != nil {
t.Fatal(err)
}
resp, err = client.Do(httpReq)
err = leaderClient.Sys().RaftSnapshotRestore(bytes.NewReader(snap), false)
if err != nil {
t.Fatal(err)
}
Expand Down