Skip to content

Commit

Permalink
Allow non-hardcoded ports for multinode cluster
Browse files Browse the repository at this point in the history
This required the join and init args to be computed on the fly, since we
can no longer infer them until after a node has started.
  • Loading branch information
rafiss committed Mar 16, 2023
1 parent d51b7cd commit 2c9d026
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 32 deletions.
1 change: 1 addition & 0 deletions testserver/tenant.go
Expand Up @@ -254,6 +254,7 @@ func (ts *testServerImpl) NewTenantServer(proxy bool) (TestServer, error) {
tenantURL := ts.pgURL[0].orig
tenantURL.Host = sqlAddr
tenant.pgURL = make([]pgURLChan, 1)
tenant.pgURL[0].started = make(chan struct{})
tenant.pgURL[0].set = make(chan struct{})

tenant.setPGURL(&tenantURL)
Expand Down
44 changes: 20 additions & 24 deletions testserver/testserver.go
Expand Up @@ -123,6 +123,9 @@ type TestServer interface {
}

type pgURLChan struct {
// started will be closed after the start command is executed.
started chan struct{}
// set will be closed once the URL is available after startup.
set chan struct{}
u *url.URL
// The original URL is preserved here if we are using a custom password.
Expand Down Expand Up @@ -419,19 +422,14 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) {
serverArgs.cockroachBinary = customBinaryEnv
}

// For backwards compatibility, in the 3 node case where no args are
// specified, default to ports 26257, 26258, 26259.
if serverArgs.numNodes == 3 && len(serverArgs.listenAddrPorts) == 0 {
serverArgs.listenAddrPorts = []int{26257, 26258, 26259}
} else if serverArgs.numNodes != 1 && len(serverArgs.listenAddrPorts) != serverArgs.numNodes {
if len(serverArgs.listenAddrPorts) == 0 {
serverArgs.listenAddrPorts = make([]int, serverArgs.numNodes)
}
if serverArgs.numNodes != 1 && len(serverArgs.listenAddrPorts) != serverArgs.numNodes {
panic(fmt.Sprintf("need to specify a port for each node using AddListenAddrPortOpt, got %d nodes, need %d ports",
serverArgs.numNodes, len(serverArgs.listenAddrPorts)))
}

if len(serverArgs.listenAddrPorts) == 0 {
serverArgs.listenAddrPorts = []int{0}
}

var err error
if serverArgs.cockroachBinary != "" {
log.Printf("Using custom cockroach binary: %s", serverArgs.cockroachBinary)
Expand Down Expand Up @@ -522,13 +520,6 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) {
}

nodes := make([]nodeInfo, serverArgs.numNodes)
var initArgs []string
joinAddrs := make([]string, 3)
hostPort := serverArgs.listenAddrPorts[0]
for i, port := range serverArgs.listenAddrPorts {
joinAddrs[i] = fmt.Sprintf("localhost:%d", port)
}

if len(serverArgs.httpPorts) == 0 {
serverArgs.httpPorts = make([]int, serverArgs.numNodes)
}
Expand All @@ -551,7 +542,6 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) {
nodes[i].listeningURLFile = filepath.Join(nodeBaseDir, "listen-url")
nodes[i].state = stateNew
if serverArgs.numNodes > 1 {
joinArg := fmt.Sprintf("--join=%s", strings.Join(joinAddrs, ","))
nodes[i].startCmdArgs = []string{
serverArgs.cockroachBinary,
startCmd,
Expand All @@ -568,7 +558,6 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) {
serverArgs.httpPorts[i],
),
"--listening-url-file=" + nodes[i].listeningURLFile,
joinArg,
"--external-io-dir=" + serverArgs.externalIODir,
}
} else {
Expand All @@ -589,11 +578,10 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) {

// We only need initArgs if we're creating a testserver
// with multiple nodes.
initArgs = []string{
initArgs := []string{
serverArgs.cockroachBinary,
"init",
secureOpt,
fmt.Sprintf("--host=localhost:%d", hostPort),
}

states := make([]int, serverArgs.numNodes)
Expand All @@ -611,6 +599,10 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) {
nodes: nodes,
}
ts.pgURL = make([]pgURLChan, serverArgs.numNodes)
for i := range ts.pgURL {
ts.pgURL[i].started = make(chan struct{})
ts.pgURL[i].set = make(chan struct{})
}

if err := ts.Start(); err != nil {
return nil, fmt.Errorf("%s Start failed: %w", testserverMessagePrefix, err)
Expand Down Expand Up @@ -676,9 +668,10 @@ func (ts *testServerImpl) setPGURLForNode(nodeNum int, u *url.URL) {
}

func (ts *testServerImpl) WaitForInitFinishForNode(nodeIdx int) error {
pgURL := ts.PGURLForNode(nodeIdx).String()
for i := 0; i < ts.serverArgs.initTimeoutSeconds*10; i++ {
err := func() error {
db, err := sql.Open("postgres", ts.PGURLForNode(nodeIdx).String())
db, err := sql.Open("postgres", pgURL)
if err != nil {
return err
}
Expand All @@ -695,7 +688,7 @@ func (ts *testServerImpl) WaitForInitFinishForNode(nodeIdx int) error {
if err == nil {
return nil
}
log.Printf("%s: WaitForInitFinishForNode %d: Trying again after error: %v", testserverMessagePrefix, nodeIdx, err)
log.Printf("%s: WaitForInitFinishForNode %d (%s): Trying again after error: %v", testserverMessagePrefix, nodeIdx, pgURL, err)
time.Sleep(time.Millisecond * 100)
}
log.Printf(
Expand Down Expand Up @@ -867,7 +860,10 @@ func (ts *testServerImpl) Stop() {
}

func (ts *testServerImpl) CockroachInit() error {
ts.initCmd = exec.Command(ts.initCmdArgs[0], ts.initCmdArgs[1:]...)
// The port must be computed here, since it may not be known until after
// a node is started (if the listen port is 0).
args := append(ts.initCmdArgs, fmt.Sprintf("--host=localhost:%s", ts.PGURL().Port()))
ts.initCmd = exec.Command(args[0], args[1:]...)
ts.initCmd.Env = []string{
"COCKROACH_MAX_OFFSET=1ns",
"COCKROACH_TRUST_CLIENT_PROVIDED_SQL_REMOTE_ADDR=true",
Expand All @@ -880,7 +876,7 @@ func (ts *testServerImpl) CockroachInit() error {

err := ts.initCmd.Start()
if ts.initCmd.Process != nil {
log.Printf("process %d started: %s", ts.initCmd.Process.Pid, strings.Join(ts.initCmdArgs, " "))
log.Printf("process %d started: %s", ts.initCmd.Process.Pid, strings.Join(args, " "))
}
if err != nil {
return err
Expand Down
6 changes: 1 addition & 5 deletions testserver/testserver_test.go
Expand Up @@ -630,9 +630,6 @@ func TestUpgradeNode(t *testing.T) {
testserver.CockroachBinaryPathOpt(absPathOldBinary),
testserver.UpgradeCockroachBinaryPathOpt(absPathNewBinary),
testserver.StoreOnDiskOpt(),
testserver.AddListenAddrPortOpt(26257),
testserver.AddListenAddrPortOpt(26258),
testserver.AddListenAddrPortOpt(26259),
)
require.NoError(t, err)
defer ts.Stop()
Expand Down Expand Up @@ -700,14 +697,13 @@ func TestUpgradeNode(t *testing.T) {
}
}

var wg = sync.WaitGroup{}

// testFlockWithDownloadPassing is to test the flock over downloaded CRDB binary with
// two goroutines, the second goroutine waits for the first goroutine to
// finish downloading the CRDB binary into a local file.
func testFlockWithDownloadPassing(
t *testing.T, opts ...testserver.TestServerOpt,
) (*sql.DB, func()) {
var wg = sync.WaitGroup{}

localFile, err := getLocalFile(false)
if err != nil {
Expand Down
49 changes: 46 additions & 3 deletions testserver/testservernode.go
Expand Up @@ -17,6 +17,7 @@ package testserver
import (
"fmt"
"log"
"os"
"os/exec"
"strings"
"syscall"
Expand All @@ -37,6 +38,15 @@ func (ts *testServerImpl) StopNode(nodeNum int) error {
return err
}
}
// Reset the pgURL, since it could change if the node is started later;
// specifically, if the listen port is 0 then the port will change.
ts.pgURL[nodeNum] = pgURLChan{}
ts.pgURL[nodeNum].started = make(chan struct{})
ts.pgURL[nodeNum].set = make(chan struct{})

if err := os.Remove(ts.nodes[nodeNum].listeningURLFile); err != nil {
return err
}

return nil
}
Expand All @@ -47,7 +57,40 @@ func (ts *testServerImpl) StartNode(i int) error {
return fmt.Errorf("node %d already running", i)
}
ts.mu.RUnlock()
ts.nodes[i].startCmd = exec.Command(ts.nodes[i].startCmdArgs[0], ts.nodes[i].startCmdArgs[1:]...)

// We need to compute the join addresses here. since if the listen port is
// 0, then the actual port will not be known until a node is started.
var joinAddrs []string
for otherNodeID := range ts.nodes {
if i == otherNodeID {
continue
}
if ts.serverArgs.listenAddrPorts[otherNodeID] != 0 {
joinAddrs = append(joinAddrs, fmt.Sprintf("localhost:%d", ts.serverArgs.listenAddrPorts[otherNodeID]))
continue
}
select {
case <-ts.pgURL[otherNodeID].started:
// PGURLForNode will block until the URL is ready. If something
// goes wrong, the goroutine waiting on pollListeningURLFile
// will time out.
joinAddrs = append(joinAddrs, fmt.Sprintf("localhost:%s", ts.PGURLForNode(otherNodeID).Port()))
default:
// If the other node hasn't started yet, don't add the join arg.
}
}
joinArg := fmt.Sprintf("--join=%s", strings.Join(joinAddrs, ","))

args := ts.nodes[i].startCmdArgs
if len(ts.nodes) > 1 {
if len(joinAddrs) == 0 {
// The start command always requires a --join arg, so we fake one
// if we don't have any yet.
joinArg = "--join=localhost:0"
}
args = append(args, joinArg)
}
ts.nodes[i].startCmd = exec.Command(args[0], args[1:]...)

currCmd := ts.nodes[i].startCmd
currCmd.Env = []string{
Expand Down Expand Up @@ -85,8 +128,9 @@ func (ts *testServerImpl) StartNode(i int) error {

log.Printf("executing: %s", currCmd)
err := currCmd.Start()
close(ts.pgURL[i].started)
if currCmd.Process != nil {
log.Printf("process %d started: %s", currCmd.Process.Pid, strings.Join(ts.nodes[i].startCmdArgs, " "))
log.Printf("process %d started. env=%s; cmd: %s", currCmd.Process.Pid, currCmd.Env, strings.Join(args, " "))
}
if err != nil {
log.Print(err.Error())
Expand All @@ -104,7 +148,6 @@ func (ts *testServerImpl) StartNode(i int) error {
capturedI := i

if ts.pgURL[capturedI].u == nil {
ts.pgURL[capturedI].set = make(chan struct{})
go func() {
if err := ts.pollListeningURLFile(capturedI); err != nil {
log.Printf("%s failed to poll listening URL file: %v", testserverMessagePrefix, err)
Expand Down

0 comments on commit 2c9d026

Please sign in to comment.