Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update BuildAuthToken to validate endpoint contains a port #1837

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/aaba2642a8f64293a3ad7dd5bc0e9ef7.json
@@ -0,0 +1,8 @@
{
"id": "aaba2642-a8f6-4293-a3ad-7dd5bc0e9ef7",
"type": "feature",
"description": "Updated `BuildAuthToken` to validate the provided endpoint contains a port.",
"modules": [
"feature/rds/auth"
]
}
29 changes: 29 additions & 0 deletions feature/rds/auth/connect.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -44,6 +45,11 @@ type BuildAuthTokenOptions struct{}
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
// for more information on using IAM database authentication with RDS.
func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
_, port := validateURL(endpoint)
if port == "" {
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
}

o := BuildAuthTokenOptions{}

for _, fn := range optFns {
Expand Down Expand Up @@ -94,3 +100,26 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds

return url, nil
}

func validateURL(hostPort string) (host, port string) {
colon := strings.LastIndexByte(hostPort, ':')
if colon != -1 {
host, port = hostPort[:colon], hostPort[colon+1:]
}
if !validatePort(port) {
port = ""
return
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}

return
}

func validatePort(port string) bool {
if _, err := strconv.Atoi(port); err == nil {
return true
}
return false
}
45 changes: 35 additions & 10 deletions feature/rds/auth/connect_test.go
Expand Up @@ -3,6 +3,7 @@ package auth_test
import (
"context"
"regexp"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -15,27 +16,51 @@ func TestBuildAuthToken(t *testing.T) {
region string
user string
expectedRegex string
expectedError string
}{
{
"https://prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
"prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:kakasdkasd",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
}

for _, c := range cases {
creds := &staticCredentials{AccessKey: "AKID", SecretKey: "SECRET", Session: "SESSION"}
url, err := auth.BuildAuthToken(context.Background(), c.endpoint, c.region, c.user, creds)
if err != nil {
t.Errorf("expect no error, got %v", err)
if len(c.expectedError) > 0 {
if err != nil {
if !strings.Contains(err.Error(), c.expectedError) {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
} else {
continue
}
} else {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
}
} else if err != nil {
t.Fatalf("expect no err, got: %v", err)
}

if re, a := regexp.MustCompile(c.expectedRegex), url; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
Expand Down