diff --git a/testserver/binaries.go b/testserver/binaries.go index 1e93873..beaf690 100644 --- a/testserver/binaries.go +++ b/testserver/binaries.go @@ -57,12 +57,12 @@ const updatesUrl = "https://register.cockroachdb.com/api/updates" var muslRE = regexp.MustCompile(`(?i)\bmusl\b`) -// GetDownloadResponse return the http response of a CRDB download. -// It creates the url for downloading a CRDB binary for current runtime OS, -// makes a request to this url, and return the response. -// nonStable should only be used if desiredVersion is not specified. If nonStable -// is used, the latest cockroach binary will be used. -func GetDownloadResponse(desiredVersion string, nonStable bool) (*http.Response, string, error) { +// GetDownloadURL returns the URL of a CRDB download. It creates the URL for +// downloading a CRDB binary for current runtime OS. If desiredVersion is +// specified, it will return the URL of the specified version. Otherwise, it +// will return the URL of the latest stable cockroach binary. If nonStable is +// true, the latest cockroach binary will be used. +func GetDownloadURL(desiredVersion string, nonStable bool) (string, string, error) { goos := runtime.GOOS if goos == "linux" { goos += func() string { @@ -100,21 +100,31 @@ func GetDownloadResponse(desiredVersion string, nonStable bool) (*http.Response, // For the latest stable CRDB, we use the url provided in the CRDB release page. dbUrl, desiredVersion, err = getLatestStableVersionInfo() if err != nil { - return nil, "", err + return dbUrl, "", err } } - log.Printf("GET %s", dbUrl) - response, err := http.Get(dbUrl) + return dbUrl, desiredVersion, nil +} + +// DownloadFromURL starts a download of the cockroach binary from the given URL. +func DownloadFromURL(downloadURL string) (*http.Response, error) { + log.Printf("GET %s", downloadURL) + response, err := http.Get(downloadURL) if err != nil { - return nil, "", err + return nil, err } if response.StatusCode != 200 { - return nil, "", fmt.Errorf("error downloading %s: %d (%s)", dbUrl, - response.StatusCode, response.Status) + return nil, fmt.Errorf( + "error downloading %s: %d (%s)", + downloadURL, + response.StatusCode, + response.Status, + ) } - return response, desiredVersion, nil + + return response, nil } // DownloadBinary saves the latest version of CRDB into a local binary file, @@ -124,19 +134,30 @@ func GetDownloadResponse(desiredVersion string, nonStable bool) (*http.Response, // To download the latest STABLE version of CRDB, set `nonStable` to false. // To download the bleeding edge version of CRDB, set `nonStable` to true. func DownloadBinary(tc *TestConfig, desiredVersion string, nonStable bool) (string, error) { - response, desiredVersion, err := GetDownloadResponse(desiredVersion, nonStable) + dbUrl, desiredVersion, err := GetDownloadURL(desiredVersion, nonStable) if err != nil { return "", err } - defer func() { _ = response.Body.Close() }() + filename, err := GetDownloadFilename(desiredVersion) + if err != nil { + return "", err + } + localFile := filepath.Join(os.TempDir(), filename) + + // Short circuit if the file already exists and is in the finished state. + info, err := os.Stat(localFile) + if err == nil && info.Mode().Perm() == finishedFileMode { + return localFile, nil + } - filename, err := GetDownloadFilename(response, nonStable, desiredVersion) + response, err := DownloadFromURL(dbUrl) if err != nil { return "", err } - localFile := filepath.Join(os.TempDir(), filename) + defer func() { _ = response.Body.Close() }() + for { info, err := os.Stat(localFile) if os.IsNotExist(err) { @@ -229,7 +250,7 @@ func DownloadBinary(tc *TestConfig, desiredVersion string, nonStable bool) (stri // GetDownloadFilename returns the local filename of the downloaded CRDB binary file. func GetDownloadFilename( - response *http.Response, nonStableDB bool, desiredVersion string, + desiredVersion string, ) (string, error) { filename := fmt.Sprintf("cockroach-%s", desiredVersion) if runtime.GOOS == "windows" { diff --git a/testserver/testserver_test.go b/testserver/testserver_test.go index 01e5e77..9bfa018 100644 --- a/testserver/testserver_test.go +++ b/testserver/testserver_test.go @@ -762,11 +762,11 @@ func testFlockWithDownloadKilled(t *testing.T) (*sql.DB, func()) { // getLocalFile returns the to-be-downloaded CRDB binary's local path. func getLocalFile(nonStable bool) (string, error) { - response, latestStableVersion, err := testserver.GetDownloadResponse("", nonStable) + _, latestStableVersion, err := testserver.GetDownloadURL("", nonStable) if err != nil { return "", err } - filename, err := testserver.GetDownloadFilename(response, nonStable, latestStableVersion) + filename, err := testserver.GetDownloadFilename(latestStableVersion) if err != nil { return "", err }