diff --git a/testserver/tenant.go b/testserver/tenant.go index 7add8b8..e22bbf5 100644 --- a/testserver/tenant.go +++ b/testserver/tenant.go @@ -254,6 +254,7 @@ func (ts *testServerImpl) NewTenantServer(proxy bool) (TestServer, error) { tenantURL := ts.pgURL[0].orig tenantURL.Host = sqlAddr tenant.pgURL = make([]pgURLChan, 1) + tenant.pgURL[0].started = make(chan struct{}) tenant.pgURL[0].set = make(chan struct{}) tenant.setPGURL(&tenantURL) diff --git a/testserver/testserver.go b/testserver/testserver.go index 7e4ac25..3a06175 100644 --- a/testserver/testserver.go +++ b/testserver/testserver.go @@ -123,6 +123,9 @@ type TestServer interface { } type pgURLChan struct { + // started will be closed after the start command is executed. + started chan struct{} + // set will be closed once the URL is available after startup. set chan struct{} u *url.URL // The original URL is preserved here if we are using a custom password. @@ -419,19 +422,14 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) { serverArgs.cockroachBinary = customBinaryEnv } - // For backwards compatibility, in the 3 node case where no args are - // specified, default to ports 26257, 26258, 26259. 
- if serverArgs.numNodes == 3 && len(serverArgs.listenAddrPorts) == 0 { - serverArgs.listenAddrPorts = []int{26257, 26258, 26259} - } else if serverArgs.numNodes != 1 && len(serverArgs.listenAddrPorts) != serverArgs.numNodes { + if len(serverArgs.listenAddrPorts) == 0 { + serverArgs.listenAddrPorts = make([]int, serverArgs.numNodes) + } + if serverArgs.numNodes != 1 && len(serverArgs.listenAddrPorts) != serverArgs.numNodes { panic(fmt.Sprintf("need to specify a port for each node using AddListenAddrPortOpt, got %d nodes, need %d ports", serverArgs.numNodes, len(serverArgs.listenAddrPorts))) } - if len(serverArgs.listenAddrPorts) == 0 { - serverArgs.listenAddrPorts = []int{0} - } - var err error if serverArgs.cockroachBinary != "" { log.Printf("Using custom cockroach binary: %s", serverArgs.cockroachBinary) @@ -522,13 +520,6 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) { } nodes := make([]nodeInfo, serverArgs.numNodes) - var initArgs []string - joinAddrs := make([]string, 3) - hostPort := serverArgs.listenAddrPorts[0] - for i, port := range serverArgs.listenAddrPorts { - joinAddrs[i] = fmt.Sprintf("localhost:%d", port) - } - if len(serverArgs.httpPorts) == 0 { serverArgs.httpPorts = make([]int, serverArgs.numNodes) } @@ -551,7 +542,6 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) { nodes[i].listeningURLFile = filepath.Join(nodeBaseDir, "listen-url") nodes[i].state = stateNew if serverArgs.numNodes > 1 { - joinArg := fmt.Sprintf("--join=%s", strings.Join(joinAddrs, ",")) nodes[i].startCmdArgs = []string{ serverArgs.cockroachBinary, startCmd, @@ -568,7 +558,6 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) { serverArgs.httpPorts[i], ), "--listening-url-file=" + nodes[i].listeningURLFile, - joinArg, "--external-io-dir=" + serverArgs.externalIODir, } } else { @@ -589,11 +578,10 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) { // We only need initArgs if we're creating a testserver // with multiple 
nodes. - initArgs = []string{ + initArgs := []string{ serverArgs.cockroachBinary, "init", secureOpt, - fmt.Sprintf("--host=localhost:%d", hostPort), } states := make([]int, serverArgs.numNodes) @@ -611,6 +599,10 @@ func NewTestServer(opts ...TestServerOpt) (TestServer, error) { nodes: nodes, } ts.pgURL = make([]pgURLChan, serverArgs.numNodes) + for i := range ts.pgURL { + ts.pgURL[i].started = make(chan struct{}) + ts.pgURL[i].set = make(chan struct{}) + } if err := ts.Start(); err != nil { return nil, fmt.Errorf("%s Start failed: %w", testserverMessagePrefix, err) @@ -676,9 +668,10 @@ func (ts *testServerImpl) setPGURLForNode(nodeNum int, u *url.URL) { } func (ts *testServerImpl) WaitForInitFinishForNode(nodeIdx int) error { + pgURL := ts.PGURLForNode(nodeIdx).String() for i := 0; i < ts.serverArgs.initTimeoutSeconds*10; i++ { err := func() error { - db, err := sql.Open("postgres", ts.PGURLForNode(nodeIdx).String()) + db, err := sql.Open("postgres", pgURL) if err != nil { return err } @@ -695,7 +688,7 @@ func (ts *testServerImpl) WaitForInitFinishForNode(nodeIdx int) error { if err == nil { return nil } - log.Printf("%s: WaitForInitFinishForNode %d: Trying again after error: %v", testserverMessagePrefix, nodeIdx, err) + log.Printf("%s: WaitForInitFinishForNode %d (%s): Trying again after error: %v", testserverMessagePrefix, nodeIdx, pgURL, err) time.Sleep(time.Millisecond * 100) } log.Printf( @@ -867,7 +860,10 @@ func (ts *testServerImpl) Stop() { } func (ts *testServerImpl) CockroachInit() error { - ts.initCmd = exec.Command(ts.initCmdArgs[0], ts.initCmdArgs[1:]...) + // The port must be computed here, since it may not be known until after + // a node is started (if the listen port is 0). + args := append(ts.initCmdArgs, fmt.Sprintf("--host=localhost:%s", ts.PGURL().Port())) + ts.initCmd = exec.Command(args[0], args[1:]...) 
ts.initCmd.Env = []string{ "COCKROACH_MAX_OFFSET=1ns", "COCKROACH_TRUST_CLIENT_PROVIDED_SQL_REMOTE_ADDR=true", @@ -880,7 +876,7 @@ func (ts *testServerImpl) CockroachInit() error { err := ts.initCmd.Start() if ts.initCmd.Process != nil { - log.Printf("process %d started: %s", ts.initCmd.Process.Pid, strings.Join(ts.initCmdArgs, " ")) + log.Printf("process %d started: %s", ts.initCmd.Process.Pid, strings.Join(args, " ")) } if err != nil { return err diff --git a/testserver/testserver_test.go b/testserver/testserver_test.go index b09f90c..6b5b04b 100644 --- a/testserver/testserver_test.go +++ b/testserver/testserver_test.go @@ -630,9 +630,6 @@ func TestUpgradeNode(t *testing.T) { testserver.CockroachBinaryPathOpt(absPathOldBinary), testserver.UpgradeCockroachBinaryPathOpt(absPathNewBinary), testserver.StoreOnDiskOpt(), - testserver.AddListenAddrPortOpt(26257), - testserver.AddListenAddrPortOpt(26258), - testserver.AddListenAddrPortOpt(26259), ) require.NoError(t, err) defer ts.Stop() @@ -700,14 +697,13 @@ func TestUpgradeNode(t *testing.T) { } } -var wg = sync.WaitGroup{} - // testFlockWithDownloadPassing is to test the flock over downloaded CRDB binary with // two goroutines, the second goroutine waits for the first goroutine to // finish downloading the CRDB binary into a local file. func testFlockWithDownloadPassing( t *testing.T, opts ...testserver.TestServerOpt, ) (*sql.DB, func()) { + var wg = sync.WaitGroup{} localFile, err := getLocalFile(false) if err != nil { diff --git a/testserver/testservernode.go b/testserver/testservernode.go index ba6fdeb..23668bb 100644 --- a/testserver/testservernode.go +++ b/testserver/testservernode.go @@ -17,6 +17,7 @@ package testserver import ( "fmt" "log" + "os" "os/exec" "strings" "syscall" @@ -37,6 +38,15 @@ func (ts *testServerImpl) StopNode(nodeNum int) error { return err } } + // Reset the pgURL, since it could change if the node is started later; + // specifically, if the listen port is 0 then the port will change. 
+ ts.pgURL[nodeNum] = pgURLChan{} + ts.pgURL[nodeNum].started = make(chan struct{}) + ts.pgURL[nodeNum].set = make(chan struct{}) + + if err := os.Remove(ts.nodes[nodeNum].listeningURLFile); err != nil { + return err + } return nil } @@ -47,7 +57,40 @@ func (ts *testServerImpl) StartNode(i int) error { return fmt.Errorf("node %d already running", i) } ts.mu.RUnlock() - ts.nodes[i].startCmd = exec.Command(ts.nodes[i].startCmdArgs[0], ts.nodes[i].startCmdArgs[1:]...) + + // We need to compute the join addresses here, since if the listen port is + // 0, then the actual port will not be known until a node is started. + var joinAddrs []string + for otherNodeID := range ts.nodes { + if i == otherNodeID { + continue + } + if ts.serverArgs.listenAddrPorts[otherNodeID] != 0 { + joinAddrs = append(joinAddrs, fmt.Sprintf("localhost:%d", ts.serverArgs.listenAddrPorts[otherNodeID])) + continue + } + select { + case <-ts.pgURL[otherNodeID].started: + // PGURLForNode will block until the URL is ready. If something + // goes wrong, the goroutine waiting on pollListeningURLFile + // will time out. + joinAddrs = append(joinAddrs, fmt.Sprintf("localhost:%s", ts.PGURLForNode(otherNodeID).Port())) + default: + // If the other node hasn't started yet, don't add the join arg. + } + } + joinArg := fmt.Sprintf("--join=%s", strings.Join(joinAddrs, ",")) + + args := ts.nodes[i].startCmdArgs + if len(ts.nodes) > 1 { + if len(joinAddrs) == 0 { + // The start command always requires a --join arg, so we fake one + // if we don't have any yet. + joinArg = "--join=localhost:0" + } + args = append(args, joinArg) + } + ts.nodes[i].startCmd = exec.Command(args[0], args[1:]...)
currCmd := ts.nodes[i].startCmd currCmd.Env = []string{ @@ -85,8 +128,9 @@ func (ts *testServerImpl) StartNode(i int) error { log.Printf("executing: %s", currCmd) err := currCmd.Start() + close(ts.pgURL[i].started) if currCmd.Process != nil { - log.Printf("process %d started: %s", currCmd.Process.Pid, strings.Join(ts.nodes[i].startCmdArgs, " ")) + log.Printf("process %d started. env=%s; cmd: %s", currCmd.Process.Pid, currCmd.Env, strings.Join(args, " ")) } if err != nil { log.Print(err.Error()) @@ -104,7 +148,6 @@ func (ts *testServerImpl) StartNode(i int) error { capturedI := i if ts.pgURL[capturedI].u == nil { - ts.pgURL[capturedI].set = make(chan struct{}) go func() { if err := ts.pollListeningURLFile(capturedI); err != nil { log.Printf("%s failed to poll listening URL file: %v", testserverMessagePrefix, err)