Skip to content

Commit

Permalink
validated that the right side of the colon has to be an string repres…
Browse files Browse the repository at this point in the history
…entation of an integer
  • Loading branch information
RanVaknin authored and skmcgrail committed Sep 12, 2022
1 parent b011f04 commit 1549137
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 10 deletions.
29 changes: 29 additions & 0 deletions feature/rds/auth/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -44,6 +45,11 @@ type BuildAuthTokenOptions struct{}
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
// for more information on using IAM database authentication with RDS.
func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
_, port := validateUrl(endpoint)
if port == "" {
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
}

o := BuildAuthTokenOptions{}

for _, fn := range optFns {
Expand Down Expand Up @@ -94,3 +100,26 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds

return url, nil
}

func validateUrl(hostPort string) (host, port string) {
colon := strings.LastIndexByte(hostPort, ':')
if colon != -1 {
host, port = hostPort[:colon], hostPort[colon+1:]
}
if !validatePort(port) {
port = ""
return
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}

return
}

func validatePort(port string) bool {
if _, err := strconv.Atoi(port); err == nil {
return true
}
return false
}
45 changes: 35 additions & 10 deletions feature/rds/auth/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth_test
import (
"context"
"regexp"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -15,27 +16,51 @@ func TestBuildAuthToken(t *testing.T) {
region string
user string
expectedRegex string
expectedError string
}{
{
"https://prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
"prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:kakasdkasd",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
}

for _, c := range cases {
creds := &staticCredentials{AccessKey: "AKID", SecretKey: "SECRET", Session: "SESSION"}
url, err := auth.BuildAuthToken(context.Background(), c.endpoint, c.region, c.user, creds)
if err != nil {
t.Errorf("expect no error, got %v", err)
if len(c.expectedError) > 0 {
if err != nil {
if !strings.Contains(err.Error(), c.expectedError) {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
} else {
continue
}
} else {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
}
} else if err != nil {
t.Fatalf("expect no err, got: %v", err)
}

if re, a := regexp.MustCompile(c.expectedRegex), url; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
Expand Down

0 comments on commit 1549137

Please sign in to comment.