Skip to content

Commit

Permalink
Merge pull request #124304 from DarrylWong/backport24.1-123961
Browse files Browse the repository at this point in the history
release-24.1: roachprod, roachtest: use same cluster name sanitization
  • Loading branch information
DarrylWong committed May 17, 2024
2 parents 6f836aa + eab42ef commit ef92f86
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 26 deletions.
2 changes: 2 additions & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ ALL_TESTS = [
"//pkg/roachprod/vm/gce:gce_test",
"//pkg/roachprod/vm/local:local_test",
"//pkg/roachprod/vm:vm_test",
"//pkg/roachprod:roachprod_test",
"//pkg/rpc/nodedialer:nodedialer_test",
"//pkg/rpc:rpc_test",
"//pkg/scheduledjobs/schedulebase:schedulebase_test",
Expand Down Expand Up @@ -1584,6 +1585,7 @@ GO_TARGETS = [
"//pkg/roachprod/vm:vm",
"//pkg/roachprod/vm:vm_test",
"//pkg/roachprod:roachprod",
"//pkg/roachprod:roachprod_test",
"//pkg/rpc/nodedialer:nodedialer",
"//pkg/rpc/nodedialer:nodedialer_test",
"//pkg/rpc/rpcpb:rpcpb",
Expand Down
1 change: 1 addition & 0 deletions pkg/cmd/roachtest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ go_test(
"//pkg/cmd/roachtest/spec",
"//pkg/cmd/roachtest/test",
"//pkg/internal/team",
"//pkg/roachprod",
"//pkg/roachprod/errors",
"//pkg/roachprod/logger",
"//pkg/roachprod/vm",
Expand Down
18 changes: 9 additions & 9 deletions pkg/cmd/roachtest/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,15 +532,8 @@ func (r *clusterRegistry) destroyAllClusters(ctx context.Context, l *logger.Logg
}
}

// Patterns used by makeGCEClusterName. They are compiled once at package
// initialization rather than on every call.
var (
	// gceDisallowedCharsRE matches runs of characters other than lowercase
	// ASCII letters, digits, and hyphens.
	gceDisallowedCharsRE = regexp.MustCompile(`[^-a-z0-9]+`)
	// gceRepeatedHyphensRE matches runs of one or more hyphens.
	gceRepeatedHyphensRE = regexp.MustCompile(`-+`)
)

// makeGCEClusterName lowercases name, replaces every run of characters
// outside [-a-z0-9] with a single hyphen, and then collapses consecutive
// hyphens into one. Leading and trailing hyphens are left in place.
func makeGCEClusterName(name string) string {
	name = strings.ToLower(name)
	name = gceDisallowedCharsRE.ReplaceAllString(name, "-")
	return gceRepeatedHyphensRE.ReplaceAllString(name, "-")
}

func makeClusterName(name string) string {
return makeGCEClusterName(name)
return vm.DNSSafeName(name)
}

// MachineTypeToCPUs returns a CPU count for GCE, AWS, and Azure machine types.
Expand Down Expand Up @@ -846,6 +839,10 @@ func (f *clusterFactory) clusterMock(cfg clusterConfig) *clusterImpl {
}
}

// create is a hook that lets tests inject their own cluster creation
// implementation, e.g. unit tests that must not actually access a provider.
// It defaults to roachprod.Create.
var create = roachprod.Create

// newCluster creates a new roachprod cluster.
//
// setStatus is called with status messages indicating the stage of cluster
Expand Down Expand Up @@ -956,7 +953,7 @@ func (f *clusterFactory) newCluster(

l.PrintfCtx(ctx, "Attempting cluster creation (attempt #%d/%d)", i, maxAttempts)
createVMOpts.ClusterName = c.name
err = roachprod.Create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer)
err = create(ctx, l, cfg.username, cfg.spec.NodeCount, createVMOpts, providerOptsContainer)
if err == nil {
if err := f.r.registerCluster(c); err != nil {
return nil, nil, err
Expand All @@ -972,6 +969,9 @@ func (f *clusterFactory) newCluster(
// or a destroy from the previous iteration failed.
return nil, nil, err
}
if errors.HasType(err, (*roachprod.MalformedClusterNameError)(nil)) {
return nil, nil, err
}

l.PrintfCtx(ctx, "cluster creation failed, cleaning up in case it was partially created: %s", err)
c.Destroy(ctx, closeLogger, l)
Expand Down
47 changes: 47 additions & 0 deletions pkg/cmd/roachtest/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestflags"
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec"
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test"
"github.com/cockroachdb/cockroach/pkg/roachprod"
"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
"github.com/cockroachdb/cockroach/pkg/testutils"
Expand Down Expand Up @@ -433,3 +434,49 @@ func TestExitCode(t *testing.T) {
err := runExitCodeTest(t, errors.New("boom"))
require.True(t, errors.Is(err, errTestsFailed))
}

func TestNewCluster(t *testing.T) {
ctx := context.Background()
factory := &clusterFactory{sem: make(chan struct{}, 1)}
cfg := clusterConfig{spec: spec.MakeClusterSpec(1)}
setStatus := func(string) {}

defer func() {
create = roachprod.Create
}()

var createCallsCounter int

testCases := []struct {
name string
createMock func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error)
expectedCreateCalls int
}{
{
"Malformed Cluster Name Error",
func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) {
createCallsCounter++
return &roachprod.MalformedClusterNameError{}
},
1, /* expectedCreateCalls */
},
{
"Cluster Already Exists Error",
func(ctx context.Context, l *logger.Logger, username string, numNodes int, createVMOpts vm.CreateOpts, providerOptsContainer vm.ProviderOptionsContainer) (retErr error) {
createCallsCounter++
return &roachprod.ClusterAlreadyExistsError{}
},
1, /* expectedCreateCalls */
},
}

for _, c := range testCases {
t.Run(c.name, func(t *testing.T) {
createCallsCounter = 0
create = c.createMock
_, _, err := factory.newCluster(ctx, cfg, setStatus, true)
require.Error(t, err)
require.Equal(t, c.expectedCreateCalls, createCallsCounter)
})
}
}
13 changes: 12 additions & 1 deletion pkg/roachprod/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "roachprod",
Expand Down Expand Up @@ -35,3 +35,14 @@ go_library(
"@org_golang_x_sys//unix",
],
)

go_test(
name = "roachprod_test",
srcs = ["roachprod_test.go"],
embed = [":roachprod"],
deps = [
"//pkg/roachprod/logger",
"//pkg/roachprod/vm",
"@com_github_stretchr_testify//assert",
],
)
40 changes: 29 additions & 11 deletions pkg/roachprod/roachprod.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ import (
"golang.org/x/sys/unix"
)

// MalformedClusterNameError is returned when the cluster name passed to Create is invalid.
type MalformedClusterNameError struct {
	// name is the cluster name that was rejected.
	name string
	// reason describes why name was rejected.
	reason string
	// suggestions holds alternative cluster names to offer; it may be empty.
	suggestions []string
}

// Error implements the error interface. The message always states the
// rejected name and the reason. The "Did you mean" clause is appended only
// when there is at least one suggestion, so that an error without
// suggestions does not end in a meaningless "Did you mean one of []".
func (e *MalformedClusterNameError) Error() string {
	msg := fmt.Sprintf("Malformed cluster name %s, %s", e.name, e.reason)
	if len(e.suggestions) == 0 {
		return msg
	}
	return fmt.Sprintf("%s. Did you mean one of %s", msg, e.suggestions)
}

// findActiveAccounts is a hook that lets tests inject their own
// FindActiveAccounts implementation, e.g. unit tests that must not actually
// access a provider. It defaults to vm.FindActiveAccounts.
var findActiveAccounts = vm.FindActiveAccounts

// verifyClusterName ensures that the given name conforms to
// our naming pattern of "<username>-<clustername>". The
// username must match one of the vm.Provider account names
Expand All @@ -65,12 +80,9 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {
return fmt.Errorf("cluster name cannot be blank")
}

alphaNum, err := regexp.Compile(`^[a-zA-Z0-9\-]+$`)
if err != nil {
return err
}
if !alphaNum.MatchString(clusterName) {
return errors.Errorf("cluster name must match %s", alphaNum.String())
sanitizedName := vm.DNSSafeName(clusterName)
if sanitizedName != clusterName {
return &MalformedClusterNameError{name: clusterName, reason: "invalid characters", suggestions: []string{sanitizedName}}
}

if config.IsLocalClusterName(clusterName) {
Expand All @@ -80,17 +92,21 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {
// Use the vm.Provider account names, or --username.
var accounts []string
if len(username) > 0 {
accounts = []string{username}
cleanAccount := vm.DNSSafeName(username)
if cleanAccount != username {
l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, username)
}
accounts = []string{cleanAccount}
} else {
seenAccounts := map[string]bool{}
active, err := vm.FindActiveAccounts(l)
active, err := findActiveAccounts(l)
if err != nil {
return err
}
for _, account := range active {
if !seenAccounts[account] {
seenAccounts[account] = true
cleanAccount := vm.DNSSafeAccount(account)
cleanAccount := vm.DNSSafeName(account)
if cleanAccount != account {
l.Printf("WARN: using `%s' as username instead of `%s'", cleanAccount, account)
}
Expand All @@ -108,26 +124,28 @@ func verifyClusterName(l *logger.Logger, clusterName, username string) error {

// Try to pick out a reasonable cluster name from the input.
var suffix string
var reason string
if i := strings.Index(clusterName, "-"); i != -1 {
// The user specified a username prefix, but it didn't match an active
// account name. For example, assuming the account is "peter", `roachprod
// create joe-perf` should be specified as `roachprod create joe-perf -u
// joe`.
suffix = clusterName[i+1:]
reason = "username prefix does not match an active account name"
} else {
// The user didn't specify a username prefix. For example, assuming the
// account is "peter", `roachprod create perf` should be specified as
// `roachprod create peter-perf`.
suffix = clusterName
reason = "cluster name should start with a username prefix: <username>-<clustername>"
}

// Suggest acceptable cluster names.
var suggestions []string
for _, account := range accounts {
suggestions = append(suggestions, fmt.Sprintf("%s-%s", account, suffix))
}
return fmt.Errorf("malformed cluster name %s, did you mean one of %s",
clusterName, suggestions)
return &MalformedClusterNameError{name: clusterName, reason: reason, suggestions: suggestions}
}

func sortedClusters() []string {
Expand Down
82 changes: 82 additions & 0 deletions pkg/roachprod/roachprod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package roachprod

import (
"io"
"testing"

"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
"github.com/stretchr/testify/assert"
)

// nilLogger returns a logger whose stdout and stderr are both sent to
// io.Discard. It panics if the logger cannot be constructed.
func nilLogger() *logger.Logger {
	discardCfg := logger.Config{Stdout: io.Discard, Stderr: io.Discard}
	discardLogger, err := discardCfg.NewLogger("" /* path */)
	if err != nil {
		panic(err)
	}
	return discardLogger
}

// TestVerifyClusterName runs verifyClusterName against a fixed, stubbed set
// of active accounts and checks only whether an error is returned for each
// combination of cluster name and explicit username.
func TestVerifyClusterName(t *testing.T) {
	// Stub out the account lookup; the real one is restored when the test exits.
	findActiveAccounts = func(l *logger.Logger) (map[string]string, error) {
		stubbed := map[string]string{"1": "user1", "2": "user2", "3": "USER4"}
		return stubbed, nil
	}
	defer func() { findActiveAccounts = vm.FindActiveAccounts }()

	type testCase struct {
		description   string
		clusterName   string
		username      string
		errorExpected bool
	}
	cases := []testCase{
		{description: "username found", clusterName: "user1-clustername", username: "", errorExpected: false},
		{description: "username not found", clusterName: "user3-clustername", username: "", errorExpected: true},
		{description: "specified username", clusterName: "user3-clustername", username: "user3", errorExpected: false},
		{description: "specified username that doesn't match", clusterName: "user1-clustername", username: "fakeuser", errorExpected: true},
		{description: "clustername not sanitized", clusterName: "UserName-clustername", username: "", errorExpected: true},
		{description: "no username", clusterName: "clustername", username: "", errorExpected: true},
		{description: "no clustername", clusterName: "user1", username: "", errorExpected: true},
		{description: "unsanitized found username", clusterName: "user4-clustername", username: "", errorExpected: false},
		{description: "unsanitized specified username", clusterName: "user3-clustername", username: "USER3", errorExpected: false},
	}

	for _, tc := range cases {
		t.Run(tc.description, func(t *testing.T) {
			err := verifyClusterName(nilLogger(), tc.clusterName, tc.username)
			if !tc.errorExpected {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
		})
	}
}
16 changes: 13 additions & 3 deletions pkg/roachprod/vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,21 +697,31 @@ func ExpandZonesFlag(zoneFlag []string) (zones []string, err error) {
return zones, nil
}

// DNSSafeAccount takes a string and returns a cleaned version of the string that can be used in DNS entries.
// DNSSafeName takes a string and returns a cleaned version of the string that can be used in DNS entries.
// Unsafe characters are dropped. No length check is performed.
func DNSSafeAccount(account string) string {
func DNSSafeName(name string) string {
safe := func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= 'A' && r <= 'Z':
return unicode.ToLower(r)
case r >= '0' && r <= '9':
return r
case r == '-':
return r
default:
// Negative value tells strings.Map to drop the rune.
return -1
}
}
return strings.Map(safe, account)
name = strings.Map(safe, name)

// DNS entries cannot start or end with hyphens.
name = strings.Trim(name, "-")

// Consecutive hyphens are allowed in DNS entries, but disallow it for readability.
return regexp.MustCompile(`-+`).ReplaceAllString(name, "-")
}

// SanitizeLabel returns a version of the string that can be used as a label.
Expand Down
10 changes: 8 additions & 2 deletions pkg/roachprod/vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,18 @@ func TestDNSSafeAccount(t *testing.T) {
"dot and underscore", "u.ser_n.a_me", "username",
},
{
"Unicode and other characters", "~/❦u.ser_ऄn.a_meλ", "username",
"leading and trailing hyphens", "--username-clustername-&", "username-clustername",
},
{
"consecutive hyphens", "username---clustername", "username-clustername",
},
{
"Unicode and other characters", "~/❦--u.ser_ऄn.a_meλ", "username",
},
}
for _, c := range cases {
t.Run(c.description, func(t *testing.T) {
assert.EqualValues(t, DNSSafeAccount(c.input), c.expected)
assert.EqualValues(t, c.expected, DNSSafeName(c.input))
})
}
}
Expand Down

0 comments on commit ef92f86

Please sign in to comment.