diff --git a/.changelog/aaba2642a8f64293a3ad7dd5bc0e9ef7.json b/.changelog/aaba2642a8f64293a3ad7dd5bc0e9ef7.json new file mode 100644 index 00000000000..6d4ded44d4e --- /dev/null +++ b/.changelog/aaba2642a8f64293a3ad7dd5bc0e9ef7.json @@ -0,0 +1,8 @@ +{ + "id": "aaba2642-a8f6-4293-a3ad-7dd5bc0e9ef7", + "type": "feature", + "description": "Updated `BuildAuthToken` to validate the provided endpoint contains a port.", + "modules": [ + "feature/rds/auth" + ] +} \ No newline at end of file diff --git a/feature/rds/auth/connect.go b/feature/rds/auth/connect.go index 5e07c626ef1..9a1406e7ed3 100644 --- a/feature/rds/auth/connect.go +++ b/feature/rds/auth/connect.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strconv" "strings" "time" @@ -44,6 +45,11 @@ type BuildAuthTokenOptions struct{} // See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html // for more information on using IAM database authentication with RDS. func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) { + _, port := validateURL(endpoint) + if port == "" { + return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid") + } + o := BuildAuthTokenOptions{} for _, fn := range optFns { @@ -94,3 +100,26 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds return url, nil } + +func validateURL(hostPort string) (host, port string) { + colon := strings.LastIndexByte(hostPort, ':') + if colon != -1 { + host, port = hostPort[:colon], hostPort[colon+1:] + } + if !validatePort(port) { + port = "" + return + } + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return +} + +func validatePort(port string) bool { + if _, err := strconv.Atoi(port); err == nil { + return true + } + return false +} diff --git a/feature/rds/auth/connect_test.go b/feature/rds/auth/connect_test.go index 9c99a103732..7ccb2332f86 100644 --- a/feature/rds/auth/connect_test.go +++ b/feature/rds/auth/connect_test.go @@ -3,6 +3,7 @@ package auth_test import ( "context" "regexp" + "strings" "testing" "github.com/aws/aws-sdk-go-v2/aws" @@ -15,27 +16,51 @@ func TestBuildAuthToken(t *testing.T) { region string user string expectedRegex string + expectedError string }{ { - "https://prod-instance.us-east-1.rds.amazonaws.com:3306", - "us-west-2", - "mysqlUser", - `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`, + endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306", + region: "us-west-2", + user: "mysqlUser", + expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`, }, { - "prod-instance.us-east-1.rds.amazonaws.com:3306", - "us-west-2", - "mysqlUser", - `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`, + endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306", + region: "us-west-2", + user: "mysqlUser", + expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`, + }, + { + endpoint: "prod-instance.us-east-1.rds.amazonaws.com", + region: "us-west-2", + user: "mysqlUser", + expectedError: "port", + }, + { + endpoint: "prod-instance.us-east-1.rds.amazonaws.com:kakasdkasd", + region: "us-west-2", + user: "mysqlUser", + expectedError: "port", }, } for _, c := range cases { creds := &staticCredentials{AccessKey: "AKID", SecretKey: "SECRET", Session: "SESSION"} url, err := auth.BuildAuthToken(context.Background(), c.endpoint, c.region, c.user, creds) - if err != nil { - t.Errorf("expect no error, got %v", err) + if len(c.expectedError) > 0 { + if err != nil { + if !strings.Contains(err.Error(), c.expectedError) { + t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err) + } else { + continue + } + } else { + t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err) + } + } else if err != nil { + t.Fatalf("expect no err, got: %v", err) } + if re, a := regexp.MustCompile(c.expectedRegex), url; !re.MatchString(a) { t.Errorf("expect %s to match %s", re, a) }