Skip to content

Commit

Permalink
validated that the right side of the colon has to be an string repres…
Browse files Browse the repository at this point in the history
…entation of an integer
  • Loading branch information
RanVaknin committed Sep 9, 2022
1 parent 5b135f8 commit 394c3f1
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 10 deletions.
29 changes: 29 additions & 0 deletions feature/rds/auth/connect.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -44,6 +45,11 @@ type BuildAuthTokenOptions struct{}
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
// for more information on using IAM database authentication with RDS.
func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
_, port := validateUrl(endpoint)
if port == "" {
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
}

o := BuildAuthTokenOptions{}

for _, fn := range optFns {
Expand Down Expand Up @@ -94,3 +100,26 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds

return url, nil
}

func validateUrl(hostPort string) (host, port string) {
colon := strings.LastIndexByte(hostPort, ':')
if colon != -1 {
host, port = hostPort[:colon], hostPort[colon+1:]
}
if !validatePort(port) {
port = ""
return
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}

return
}

func validatePort(port string) bool {
if _, err := strconv.Atoi(port); err == nil {
return true
}
return false
}
45 changes: 35 additions & 10 deletions feature/rds/auth/connect_test.go
Expand Up @@ -3,6 +3,7 @@ package auth_test
import (
"context"
"regexp"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -15,27 +16,51 @@ func TestBuildAuthToken(t *testing.T) {
region string
user string
expectedRegex string
expectedError string
}{
{
"https://prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
"prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:kakasdkasd",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
}

for _, c := range cases {
creds := &staticCredentials{AccessKey: "AKID", SecretKey: "SECRET", Session: "SESSION"}
url, err := auth.BuildAuthToken(context.Background(), c.endpoint, c.region, c.user, creds)
if err != nil {
t.Errorf("expect no error, got %v", err)
if len(c.expectedError) > 0 {
if err != nil {
if !strings.Contains(err.Error(), c.expectedError) {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
} else {
continue
}
} else {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
}
} else if err != nil {
t.Fatalf("expect no err, got: %v", err)
}

if re, a := regexp.MustCompile(c.expectedRegex), url; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
Expand Down

0 comments on commit 394c3f1

Please sign in to comment.