diff --git a/builtin/credential/aws/backend.go b/builtin/credential/aws/backend.go index acc68dc04e6d6..a4bd75830a0fd 100644 --- a/builtin/credential/aws/backend.go +++ b/builtin/credential/aws/backend.go @@ -3,6 +3,7 @@ package awsauth import ( "context" "fmt" + "net/textproto" "strings" "sync" "time" @@ -17,6 +18,8 @@ import ( cache "github.com/patrickmn/go-cache" ) +var defaultAllowedSTSRequestHeaders = []string{"Authorization", "Content-Length", "Content-Type", "User-Agent", "X-Amz-Date", textproto.CanonicalMIMEHeaderKey(iamServerIdHeader)} + func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { b, err := Backend(conf) if err != nil { diff --git a/builtin/credential/aws/path_config_client.go b/builtin/credential/aws/path_config_client.go index 228488e302208..0fdc61dfac564 100644 --- a/builtin/credential/aws/path_config_client.go +++ b/builtin/credential/aws/path_config_client.go @@ -2,6 +2,10 @@ package awsauth import ( "context" + "errors" + "github.com/hashicorp/vault/sdk/helper/strutil" + "net/http" + "net/textproto" "github.com/aws/aws-sdk-go/aws" "github.com/hashicorp/vault/sdk/framework" @@ -53,6 +57,11 @@ func (b *backend) pathConfigClient() *framework.Path { Default: "", Description: "Value to require in the X-Vault-AWS-IAM-Server-ID request header", }, + "allowed_sts_header_values": { + Type: framework.TypeStringSlice, + Default: nil, + Description: "List of headers that are allowed to be in AWS STS request headers", + }, "max_retries": { Type: framework.TypeInt, Default: aws.UseServiceDefaultRetries, @@ -257,6 +266,24 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical configEntry.IAMServerIdHeaderValue = data.Get("iam_server_id_header_value").(string) } + aHeadersValStr, ok := data.GetOk("allowed_sts_header_values") + if ok { + aHeadersValSl := aHeadersValStr.([]string) + for i, v := range aHeadersValSl { + aHeadersValSl[i] = textproto.CanonicalMIMEHeaderKey(v) + } + if !strutil.EquivalentSlices(configEntry.AllowedSTSHeaderValues, aHeadersValSl) { + // NOT setting changedCreds here, since this isn't really cached + configEntry.AllowedSTSHeaderValues = aHeadersValSl + changedOtherConfig = true + } + } else if req.Operation == logical.CreateOperation { + ah, ok := data.GetOk("allowed_sts_header_values") + if ok { + configEntry.AllowedSTSHeaderValues = ah.([]string) + } + } + maxRetriesInt, ok := data.GetOk("max_retries") if ok { configEntry.MaxRetries = maxRetriesInt.(int) @@ -293,14 +320,28 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical // Struct to hold 'aws_access_key' and 'aws_secret_key' that are required to // interact with the AWS EC2 API. type clientConfig struct { - AccessKey string `json:"access_key"` - SecretKey string `json:"secret_key"` - Endpoint string `json:"endpoint"` - IAMEndpoint string `json:"iam_endpoint"` - STSEndpoint string `json:"sts_endpoint"` - STSRegion string `json:"sts_region"` - IAMServerIdHeaderValue string `json:"iam_server_id_header_value"` - MaxRetries int `json:"max_retries"` + AccessKey string `json:"access_key"` + SecretKey string `json:"secret_key"` + Endpoint string `json:"endpoint"` + IAMEndpoint string `json:"iam_endpoint"` + STSEndpoint string `json:"sts_endpoint"` + STSRegion string `json:"sts_region"` + IAMServerIdHeaderValue string `json:"iam_server_id_header_value"` + AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"` + MaxRetries int `json:"max_retries"` +} + +func (c *clientConfig) validateAllowedSTSHeaderValues(headers http.Header) error { + allowList := c.AllowedSTSHeaderValues + if c.AllowedSTSHeaderValues == nil { + allowList = defaultAllowedSTSRequestHeaders + } + for k := range headers { + if !strutil.StrListContains(allowList, textproto.CanonicalMIMEHeaderKey(k)) { + return errors.New("invalid request header: " + k) + } + } + return nil } const pathConfigClientHelpSyn = ` diff --git a/builtin/credential/aws/path_login.go b/builtin/credential/aws/path_login.go index 5bbb74c832e10..df99c4a35260a 100644 --- a/builtin/credential/aws/path_login.go +++ b/builtin/credential/aws/path_login.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "encoding/pem" "encoding/xml" + "errors" "fmt" "io/ioutil" "net/http" @@ -43,6 +44,11 @@ const ( retryWaitMax = 30 * time.Second ) +var ( + errRequestBodyNotValid = errors.New("iam request body is invalid") + errInvalidGetCallerIdentityResponse = errors.New("body of GetCallerIdentity is invalid") +) + func (b *backend) pathLogin() *framework.Path { return &framework.Path{ Pattern: "login$", @@ -1179,7 +1185,10 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, if err != nil { return logical.ErrorResponse("error parsing iam_request_url"), nil } - + if parsedUrl.RawQuery != "" { + // Should be no query parameters + return logical.ErrorResponse(logical.ErrInvalidRequest.Error()), nil + } // TODO: There are two potentially valid cases we're not yet supporting that would // necessitate this check being changed. First, if we support GET requests. // Second if we support presigned POST requests @@ -1192,6 +1201,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, return logical.ErrorResponse("failed to base64 decode iam_request_body"), nil } body := string(bodyRaw) + if err = validateLoginIamRequestBody(body); err != nil { + return logical.ErrorResponse(err.Error()), nil + } headers := data.Get("iam_request_headers").(http.Header) if len(headers) == 0 { @@ -1213,6 +1225,9 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, return logical.ErrorResponse(fmt.Sprintf("error validating %s header: %v", iamServerIdHeader, err)), nil } } + if err = config.validateAllowedSTSHeaderValues(headers); err != nil { + return logical.ErrorResponse(err.Error()), nil + } if config.STSEndpoint != "" { endpoint = config.STSEndpoint } @@ -1394,6 +1409,29 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, }, nil } +// Validate that the iam_request_body passed is valid for the STS request +func validateLoginIamRequestBody(body string) error { + qs, err := url.ParseQuery(body) + if err != nil { + return err + } + for k, v := range qs { + switch k { + case "Action": + if len(v) != 1 || v[0] != "GetCallerIdentity" { + return errRequestBodyNotValid + } + case "Version": + // Will assume for now that future versions don't change + // the semantics + default: + // Not expecting any other values + return errRequestBodyNotValid + } + } + return nil +} + // These two methods (hasValuesFor*) return two bools // The first is a hasAll, that is, does the request have all the values // necessary for this auth method @@ -1559,8 +1597,12 @@ func ensureHeaderIsSigned(signedHeaders, headerToSign string) error { } func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, error) { - decoder := xml.NewDecoder(strings.NewReader(response)) result := GetCallerIdentityResponse{} + response = strings.TrimSpace(response) + if !strings.HasPrefix(response, "", + }, + ExpectErr: errors.New("invalid request header: X-Mallory-Header"), + }, { Name: "JSON-complete", Header: `{ @@ -543,7 +556,8 @@ func setupIAMTestServer() *httptest.Server { 7f4fc40c-853a-11e6-8848-8d035d01eb87 -` + +` auth := r.Header.Get("Authorization") parts := strings.Split(auth, ",") @@ -566,6 +580,7 @@ func setupIAMTestServer() *httptest.Server { if matchingCount != len(expectedAuthParts) { responseString = "missing auth parts" } + w.Header().Add("Content-Type", "text/xml") fmt.Fprintln(w, responseString) })) } diff --git a/website/pages/api-docs/auth/aws/index.mdx b/website/pages/api-docs/auth/aws/index.mdx index d5cbcf74b62a7..3cd500cf5e87e 100644 --- a/website/pages/api-docs/auth/aws/index.mdx +++ b/website/pages/api-docs/auth/aws/index.mdx @@ -66,7 +66,11 @@ capabilities, the credentials are fetched automatically. signed headers validated by AWS. This is to protect against different types of replay attacks, for example a signed request sent to a dev server being resent to a production server. Consider setting this to the Vault server's DNS name. - +- `allowed_sts_header_values` `([]string: nil)` The list of allowed request headers + when providing the iam_request_headers for an IAM based login call. If not + provided (recommended), defaults to the set of headers AWS STS expects for a + GetCallerIdentity call. + ### Sample Payload ```json