diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5af2e646..047a8703 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest, macos-latest, windows-latest ] - go: [ '1.16', '1.17' ] + go: [ '1.17' ] runs-on: ${{ matrix.os }} diff --git a/account.go b/account.go index 80be2ed0..f7cef7c4 100644 --- a/account.go +++ b/account.go @@ -37,8 +37,8 @@ import ( // getAccount either loads or creates a new account, depending on if // an account can be found in storage for the given CA + email combo. -func (am *ACMEManager) getAccount(ca, email string) (acme.Account, error) { - acct, err := am.loadAccount(ca, email) +func (am *ACMEManager) getAccount(ctx context.Context, ca, email string) (acme.Account, error) { + acct, err := am.loadAccount(ctx, ca, email) if errors.Is(err, fs.ErrNotExist) { return am.newAccount(email) } @@ -46,12 +46,12 @@ func (am *ACMEManager) getAccount(ca, email string) (acme.Account, error) { } // loadAccount loads an account from storage, but does not create a new one. -func (am *ACMEManager) loadAccount(ca, email string) (acme.Account, error) { - regBytes, err := am.config.Storage.Load(am.storageKeyUserReg(ca, email)) +func (am *ACMEManager) loadAccount(ctx context.Context, ca, email string) (acme.Account, error) { + regBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserReg(ca, email)) if err != nil { return acme.Account{}, err } - keyBytes, err := am.config.Storage.Load(am.storageKeyUserPrivateKey(ca, email)) + keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(ca, email)) if err != nil { return acme.Account{}, err } @@ -99,18 +99,18 @@ func (am *ACMEManager) GetAccount(ctx context.Context, privateKeyPEM []byte) (ac // If it does not exist, an error of type fs.ErrNotExist is returned. This is not very efficient // for lots of accounts. func (am *ACMEManager) loadAccountByKey(ctx context.Context, privateKeyPEM []byte) (acme.Account, error) { - accountList, err := am.config.Storage.List(am.storageKeyUsersPrefix(am.CA), false) + accountList, err := am.config.Storage.List(ctx, am.storageKeyUsersPrefix(am.CA), false) if err != nil { return acme.Account{}, err } for _, accountFolderKey := range accountList { email := path.Base(accountFolderKey) - keyBytes, err := am.config.Storage.Load(am.storageKeyUserPrivateKey(am.CA, email)) + keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email)) if err != nil { return acme.Account{}, err } if bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) { - return am.loadAccount(am.CA, email) + return am.loadAccount(ctx, am.CA, email) } } return acme.Account{}, fs.ErrNotExist @@ -137,7 +137,7 @@ func (am *ACMEManager) lookUpAccount(ctx context.Context, privateKeyPEM []byte) } // save the account details to storage - err = am.saveAccount(client.Directory, account) + err = am.saveAccount(ctx, client.Directory, account) if err != nil { return account, fmt.Errorf("could not save account to storage: %v", err) } @@ -147,7 +147,7 @@ func (am *ACMEManager) lookUpAccount(ctx context.Context, privateKeyPEM []byte) // saveAccount persists an ACME account's info and private key to storage. // It does NOT register the account via ACME or prompt the user. -func (am *ACMEManager) saveAccount(ca string, account acme.Account) error { +func (am *ACMEManager) saveAccount(ctx context.Context, ca string, account acme.Account) error { regBytes, err := json.MarshalIndent(account, "", "\t") if err != nil { return err @@ -168,7 +168,7 @@ func (am *ACMEManager) saveAccount(ca string, account acme.Account) error { value: keyBytes, }, } - return storeTx(am.config.Storage, all) + return storeTx(ctx, am.config.Storage, all) } // getEmail does everything it can to obtain an email address @@ -178,7 +178,7 @@ func (am *ACMEManager) saveAccount(ca string, account acme.Account) error { // the consequences of an empty email.) This function MAY prompt // the user for input. If allowPrompts is false, the user // will NOT be prompted and an empty email may be returned. -func (am *ACMEManager) getEmail(allowPrompts bool) error { +func (am *ACMEManager) getEmail(ctx context.Context, allowPrompts bool) error { leEmail := am.Email // First try package default email, or a discovered email address @@ -194,7 +194,7 @@ func (am *ACMEManager) getEmail(allowPrompts bool) error { // Then try to get most recent user email from storage var gotRecentEmail bool if leEmail == "" { - leEmail, gotRecentEmail = am.mostRecentAccountEmail(am.CA) + leEmail, gotRecentEmail = am.mostRecentAccountEmail(ctx, am.CA) } if !gotRecentEmail && leEmail == "" && allowPrompts { // Looks like there is no email address readily available, @@ -331,8 +331,8 @@ func (*ACMEManager) emailUsername(email string) string { // in storage. Since this is part of a complex sequence to get a user // account, errors here are discarded to simplify code flow in // the caller, and errors are not important here anyway. -func (am *ACMEManager) mostRecentAccountEmail(caURL string) (string, bool) { - accountList, err := am.config.Storage.List(am.storageKeyUsersPrefix(caURL), false) +func (am *ACMEManager) mostRecentAccountEmail(ctx context.Context, caURL string) (string, bool) { + accountList, err := am.config.Storage.List(ctx, am.storageKeyUsersPrefix(caURL), false) if err != nil || len(accountList) == 0 { return "", false } @@ -342,7 +342,7 @@ func (am *ACMEManager) mostRecentAccountEmail(caURL string) (string, bool) { stats := make(map[string]KeyInfo) for i := 0; i < len(accountList); i++ { u := accountList[i] - keyInfo, err := am.config.Storage.Stat(u) + keyInfo, err := am.config.Storage.Stat(ctx, u) if err != nil { continue } @@ -370,7 +370,7 @@ func (am *ACMEManager) mostRecentAccountEmail(caURL string) (string, bool) { return "", false } - account, err := am.getAccount(caURL, path.Base(accountList[0])) + account, err := am.getAccount(ctx, caURL, path.Base(accountList[0])) if err != nil { return "", false } diff --git a/account_test.go b/account_test.go index 2689da09..e4ea11e4 100644 --- a/account_test.go +++ b/account_test.go @@ -16,6 +16,7 @@ package certmagic import ( "bytes" + "context" "os" "path/filepath" "reflect" @@ -50,6 +51,8 @@ func TestNewAccount(t *testing.T) { } func TestSaveAccount(t *testing.T) { + ctx := context.Background() + am := &ACMEManager{CA: dummyCA} testConfig := &Config{ Issuers: []Issuer{am}, @@ -72,17 +75,19 @@ func TestSaveAccount(t *testing.T) { t.Fatalf("Error creating account: %v", err) } - err = am.saveAccount(am.CA, account) + err = am.saveAccount(ctx, am.CA, account) if err != nil { t.Fatalf("Error saving account: %v", err) } - _, err = am.getAccount(am.CA, email) + _, err = am.getAccount(ctx, am.CA, email) if err != nil { t.Errorf("Cannot access account data, error: %v", err) } } func TestGetAccountDoesNotAlreadyExist(t *testing.T) { + ctx := context.Background() + am := &ACMEManager{CA: dummyCA} testConfig := &Config{ Issuers: []Issuer{am}, @@ -91,7 +96,7 @@ func TestGetAccountDoesNotAlreadyExist(t *testing.T) { } am.config = testConfig - account, err := am.getAccount(am.CA, "account_does_not_exist@foobar.com") + account, err := am.getAccount(ctx, am.CA, "account_does_not_exist@foobar.com") if err != nil { t.Fatalf("Error getting account: %v", err) } @@ -102,6 +107,8 @@ func TestGetAccountDoesNotAlreadyExist(t *testing.T) { } func TestGetAccountAlreadyExists(t *testing.T) { + ctx := context.Background() + am := &ACMEManager{CA: dummyCA} testConfig := &Config{ Issuers: []Issuer{am}, @@ -125,13 +132,13 @@ func TestGetAccountAlreadyExists(t *testing.T) { if err != nil { t.Fatalf("Error creating account: %v", err) } - err = am.saveAccount(am.CA, account) + err = am.saveAccount(ctx, am.CA, account) if err != nil { t.Fatalf("Error saving account: %v", err) } // Expect to load account from disk - loadedAccount, err := am.getAccount(am.CA, email) + loadedAccount, err := am.getAccount(ctx, am.CA, email) if err != nil { t.Fatalf("Error getting account: %v", err) } @@ -148,6 +155,8 @@ func TestGetAccountAlreadyExists(t *testing.T) { } func TestGetEmailFromPackageDefault(t *testing.T) { + ctx := context.Background() + DefaultACME.Email = "tEsT2@foo.com" defer func() { DefaultACME.Email = "" @@ -162,7 +171,7 @@ func TestGetEmailFromPackageDefault(t *testing.T) { } am.config = testConfig - err := am.getEmail(true) + err := am.getEmail(ctx, true) if err != nil { t.Fatalf("getEmail error: %v", err) } @@ -173,6 +182,8 @@ func TestGetEmailFromPackageDefault(t *testing.T) { } func TestGetEmailFromUserInput(t *testing.T) { + ctx := context.Background() + am := &ACMEManager{CA: dummyCA} testConfig := &Config{ Issuers: []Issuer{am}, @@ -192,7 +203,7 @@ func TestGetEmailFromUserInput(t *testing.T) { email := "test3@foo.com" stdin = bytes.NewBufferString(email + "\n") - err := am.getEmail(true) + err := am.getEmail(ctx, true) if err != nil { t.Fatalf("getEmail error: %v", err) } @@ -205,6 +216,8 @@ func TestGetEmailFromUserInput(t *testing.T) { } func TestGetEmailFromRecent(t *testing.T) { + ctx := context.Background() + am := &ACMEManager{CA: dummyCA} testConfig := &Config{ Issuers: []Issuer{am}, @@ -233,7 +246,7 @@ func TestGetEmailFromRecent(t *testing.T) { if err != nil { t.Fatalf("Error creating user %d: %v", i, err) } - err = am.saveAccount(am.CA, account) + err = am.saveAccount(ctx, am.CA, account) if err != nil { t.Fatalf("Error saving user %d: %v", i, err) } @@ -250,7 +263,7 @@ func TestGetEmailFromRecent(t *testing.T) { t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) } } - err := am.getEmail(true) + err := am.getEmail(ctx, true) if err != nil { t.Fatalf("getEmail error: %v", err) } diff --git a/acmeclient.go b/acmeclient.go index a22dc19a..5104193f 100644 --- a/acmeclient.go +++ b/acmeclient.go @@ -61,7 +61,7 @@ func (am *ACMEManager) newACMEClientWithAccount(ctx context.Context, useTestCA, if am.AccountKeyPEM != "" { account, err = am.GetAccount(ctx, []byte(am.AccountKeyPEM)) } else { - account, err = am.getAccount(client.Directory, am.Email) + account, err = am.getAccount(ctx, client.Directory, am.Email) } if err != nil { return nil, fmt.Errorf("getting ACME account: %v", err) @@ -116,7 +116,7 @@ func (am *ACMEManager) newACMEClientWithAccount(ctx context.Context, useTestCA, } // persist the account to storage - err = am.saveAccount(client.Directory, account) + err = am.saveAccount(ctx, client.Directory, account) if err != nil { return nil, fmt.Errorf("could not save account %v: %v", account.Contact, err) } diff --git a/acmemanager.go b/acmemanager.go index 82b6cc12..d5f15ec7 100644 --- a/acmemanager.go +++ b/acmemanager.go @@ -217,7 +217,7 @@ func (*ACMEManager) issuerKey(ca string) string { // renewing a certificate with ACME, and returns whether this // batch is eligible for certificates if using Let's Encrypt. // It also ensures that an email address is available. -func (am *ACMEManager) PreCheck(_ context.Context, names []string, interactive bool) error { +func (am *ACMEManager) PreCheck(ctx context.Context, names []string, interactive bool) error { publicCA := strings.Contains(am.CA, "api.letsencrypt.org") || strings.Contains(am.CA, "acme.zerossl.com") if publicCA { for _, name := range names { @@ -226,7 +226,7 @@ func (am *ACMEManager) PreCheck(_ context.Context, names []string, interactive b } } } - return am.getEmail(interactive) + return am.getEmail(ctx, interactive) } // Issue implements the Issuer interface. It obtains a certificate for the given csr using diff --git a/certificates.go b/certificates.go index d60a46dd..47bcd71d 100644 --- a/certificates.go +++ b/certificates.go @@ -15,6 +15,7 @@ package certmagic import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -115,8 +116,8 @@ func (cert Certificate) HasTag(tag string) bool { // This is a lower-level method; normally you'll call Manage() instead. // // This method is safe for concurrent use. -func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) { - cert, err := cfg.loadManagedCertificate(domain) +func (cfg *Config) CacheManagedCertificate(ctx context.Context, domain string) (Certificate, error) { + cert, err := cfg.loadManagedCertificate(ctx, domain) if err != nil { return cert, err } @@ -128,12 +129,12 @@ func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) { // loadManagedCertificate loads the managed certificate for domain from any // of the configured issuers' storage locations, but it does not add it to // the cache. It just loads from storage and returns it. -func (cfg *Config) loadManagedCertificate(domain string) (Certificate, error) { - certRes, err := cfg.loadCertResourceAnyIssuer(domain) +func (cfg *Config) loadManagedCertificate(ctx context.Context, domain string) (Certificate, error) { + certRes, err := cfg.loadCertResourceAnyIssuer(ctx, domain) if err != nil { return Certificate{}, err } - cert, err := cfg.makeCertificateWithOCSP(certRes.CertificatePEM, certRes.PrivateKeyPEM) + cert, err := cfg.makeCertificateWithOCSP(ctx, certRes.CertificatePEM, certRes.PrivateKeyPEM) if err != nil { return cert, err } @@ -147,8 +148,8 @@ func (cfg *Config) loadManagedCertificate(domain string) (Certificate, error) { // the in-memory cache. // // This method is safe for concurrent use. -func (cfg *Config) CacheUnmanagedCertificatePEMFile(certFile, keyFile string, tags []string) error { - cert, err := cfg.makeCertificateFromDiskWithOCSP(cfg.Storage, certFile, keyFile) +func (cfg *Config) CacheUnmanagedCertificatePEMFile(ctx context.Context, certFile, keyFile string, tags []string) error { + cert, err := cfg.makeCertificateFromDiskWithOCSP(ctx, cfg.Storage, certFile, keyFile) if err != nil { return err } @@ -162,13 +163,13 @@ func (cfg *Config) CacheUnmanagedCertificatePEMFile(certFile, keyFile string, ta // It staples OCSP if possible. // // This method is safe for concurrent use. -func (cfg *Config) CacheUnmanagedTLSCertificate(tlsCert tls.Certificate, tags []string) error { +func (cfg *Config) CacheUnmanagedTLSCertificate(ctx context.Context, tlsCert tls.Certificate, tags []string) error { var cert Certificate err := fillCertFromLeaf(&cert, tlsCert) if err != nil { return err } - err = stapleOCSP(cfg.OCSP, cfg.Storage, &cert, nil) + err = stapleOCSP(ctx, cfg.OCSP, cfg.Storage, &cert, nil) if err != nil && cfg.Logger != nil { cfg.Logger.Warn("stapling OCSP", zap.Error(err)) } @@ -182,8 +183,8 @@ func (cfg *Config) CacheUnmanagedTLSCertificate(tlsCert tls.Certificate, tags [] // of the certificate and key, then caches it in memory. // // This method is safe for concurrent use. -func (cfg *Config) CacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte, tags []string) error { - cert, err := cfg.makeCertificateWithOCSP(certBytes, keyBytes) +func (cfg *Config) CacheUnmanagedCertificatePEMBytes(ctx context.Context, certBytes, keyBytes []byte, tags []string) error { + cert, err := cfg.makeCertificateWithOCSP(ctx, certBytes, keyBytes) if err != nil { return err } @@ -197,7 +198,7 @@ func (cfg *Config) CacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte, // certificate and key files. It fills out all the fields in // the certificate except for the Managed and OnDemand flags. // (It is up to the caller to set those.) It staples OCSP. -func (cfg Config) makeCertificateFromDiskWithOCSP(storage Storage, certFile, keyFile string) (Certificate, error) { +func (cfg Config) makeCertificateFromDiskWithOCSP(ctx context.Context, storage Storage, certFile, keyFile string) (Certificate, error) { certPEMBlock, err := os.ReadFile(certFile) if err != nil { return Certificate{}, err @@ -206,17 +207,17 @@ func (cfg Config) makeCertificateFromDiskWithOCSP(storage Storage, certFile, key if err != nil { return Certificate{}, err } - return cfg.makeCertificateWithOCSP(certPEMBlock, keyPEMBlock) + return cfg.makeCertificateWithOCSP(ctx, certPEMBlock, keyPEMBlock) } // makeCertificateWithOCSP is the same as makeCertificate except that it also // staples OCSP to the certificate. -func (cfg Config) makeCertificateWithOCSP(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { +func (cfg Config) makeCertificateWithOCSP(ctx context.Context, certPEMBlock, keyPEMBlock []byte) (Certificate, error) { cert, err := makeCertificate(certPEMBlock, keyPEMBlock) if err != nil { return cert, err } - err = stapleOCSP(cfg.OCSP, cfg.Storage, &cert, certPEMBlock) + err = stapleOCSP(ctx, cfg.OCSP, cfg.Storage, &cert, certPEMBlock) if err != nil && cfg.Logger != nil { cfg.Logger.Warn("stapling OCSP", zap.Error(err), zap.Strings("identifiers", cert.Names)) } @@ -310,8 +311,8 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error { // means that another instance renewed the certificate in the // meantime, and it would be a good idea to simply load the cert // into our cache rather than repeating the renewal process again. -func (cfg *Config) managedCertInStorageExpiresSoon(cert Certificate) (bool, error) { - certRes, err := cfg.loadCertResourceAnyIssuer(cert.Names[0]) +func (cfg *Config) managedCertInStorageExpiresSoon(ctx context.Context, cert Certificate) (bool, error) { + certRes, err := cfg.loadCertResourceAnyIssuer(ctx, cert.Names[0]) if err != nil { return false, err } @@ -324,11 +325,11 @@ func (cfg *Config) managedCertInStorageExpiresSoon(cert Certificate) (bool, erro // with the new one, so that all configurations that used the old cert now point // to the new cert. It assumes that the new certificate for oldCert.Names[0] is // already in storage. It returns the newly-loaded certificate if successful. -func (cfg *Config) reloadManagedCertificate(oldCert Certificate) (Certificate, error) { +func (cfg *Config) reloadManagedCertificate(ctx context.Context, oldCert Certificate) (Certificate, error) { if cfg.Logger != nil { cfg.Logger.Info("reloading managed certificate", zap.Strings("identifiers", oldCert.Names)) } - newCert, err := cfg.loadManagedCertificate(oldCert.Names[0]) + newCert, err := cfg.loadManagedCertificate(ctx, oldCert.Names[0]) if err != nil { return Certificate{}, fmt.Errorf("loading managed certificate for %v from storage: %v", oldCert.Names, err) } diff --git a/certmagic.go b/certmagic.go index 7944a666..b909f235 100644 --- a/certmagic.go +++ b/certmagic.go @@ -51,7 +51,7 @@ import ( // HTTPS serves mux for all domainNames using the HTTP // and HTTPS ports, redirecting all HTTP requests to HTTPS. -// It uses the Default config. +// It uses the Default config and a background context. // // This high-level convenience function is opinionated and // applies sane defaults for production use, including @@ -66,6 +66,8 @@ import ( // Calling this function signifies your acceptance to // the CA's Subscriber Agreement and/or Terms of Service. func HTTPS(domainNames []string, mux http.Handler) error { + ctx := context.Background() + if mux == nil { mux = http.DefaultServeMux } @@ -73,7 +75,7 @@ func HTTPS(domainNames []string, mux http.Handler) error { DefaultACME.Agreed = true cfg := NewDefault() - err := cfg.ManageSync(context.Background(), domainNames) + err := cfg.ManageSync(ctx, domainNames) if err != nil { return err } @@ -124,6 +126,7 @@ func HTTPS(domainNames []string, mux http.Handler) error { ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, IdleTimeout: 5 * time.Second, + BaseContext: func(listener net.Listener) context.Context { return ctx }, } if len(cfg.Issuers) > 0 { if am, ok := cfg.Issuers[0].(*ACMEManager); ok { @@ -136,6 +139,7 @@ func HTTPS(domainNames []string, mux http.Handler) error { WriteTimeout: 2 * time.Minute, IdleTimeout: 5 * time.Minute, Handler: mux, + BaseContext: func(listener net.Listener) context.Context { return ctx }, } log.Printf("%v Serving HTTP->HTTPS on %s and %s", diff --git a/config.go b/config.go index 88830782..84779f79 100644 --- a/config.go +++ b/config.go @@ -289,7 +289,7 @@ func (cfg *Config) ClientCredentials(ctx context.Context, identifiers []string) } var chains []tls.Certificate for _, id := range identifiers { - certRes, err := cfg.loadCertResourceAnyIssuer(id) + certRes, err := cfg.loadCertResourceAnyIssuer(ctx, id) if err != nil { return chains, err } @@ -328,7 +328,7 @@ func (cfg *Config) manageAll(ctx context.Context, domainNames []string, async bo func (cfg *Config) manageOne(ctx context.Context, domainName string, async bool) error { // first try loading existing certificate from storage - cert, err := cfg.CacheManagedCertificate(domainName) + cert, err := cfg.CacheManagedCertificate(ctx, domainName) if err != nil { if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("%s: caching certificate: %v", domainName, err) @@ -344,7 +344,7 @@ func (cfg *Config) manageOne(ctx context.Context, domainName string, async bool) if err != nil { return fmt.Errorf("%s: obtaining certificate: %w", domainName, err) } - cert, err = cfg.CacheManagedCertificate(domainName) + cert, err = cfg.CacheManagedCertificate(ctx, domainName) if err != nil { return fmt.Errorf("%s: caching certificate after obtaining it: %v", domainName, err) } @@ -390,7 +390,7 @@ func (cfg *Config) manageOne(ctx context.Context, domainName string, async bool) return fmt.Errorf("%s: renewing certificate: %w", domainName, err) } // successful renewal, so update in-memory cache - _, err = cfg.reloadManagedCertificate(cert) + _, err = cfg.reloadManagedCertificate(ctx, cert) if err != nil { return fmt.Errorf("%s: reloading renewed certificate into memory: %v", domainName, err) } @@ -449,13 +449,13 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool } // if storage has all resources for this certificate, obtain is a no-op - if cfg.storageHasCertResourcesAnyIssuer(name) { + if cfg.storageHasCertResourcesAnyIssuer(ctx, name) { return nil } // ensure storage is writeable and readable // TODO: this is not necessary every time; should only perform check once every so often for each storage, which may require some global state... - err := cfg.checkStorage() + err := cfg.checkStorage(ctx) if err != nil { return fmt.Errorf("failed storage check: %v - storage is probably misconfigured", err) } @@ -476,7 +476,7 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool if log != nil { log.Info("releasing lock", zap.String("identifier", name)) } - if err := releaseLock(cfg.Storage, lockKey); err != nil { + if err := releaseLock(ctx, cfg.Storage, lockKey); err != nil { if log != nil { log.Error("unable to unlock", zap.String("identifier", name), @@ -491,7 +491,7 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool f := func(ctx context.Context) error { // check if obtain is still needed -- might have been obtained during lock - if cfg.storageHasCertResourcesAnyIssuer(name) { + if cfg.storageHasCertResourcesAnyIssuer(ctx, name) { if log != nil { log.Info("certificate already exists in storage", zap.String("identifier", name)) } @@ -500,7 +500,7 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool // if storage has a private key already, use it; otherwise, // we'll generate our own - privKey, privKeyPEM, issuers, err := cfg.reusePrivateKey(name) + privKey, privKeyPEM, issuers, err := cfg.reusePrivateKey(ctx, name) if err != nil { return err } @@ -568,7 +568,7 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool PrivateKeyPEM: privKeyPEM, IssuerData: issuedCert.Metadata, } - err = cfg.saveCertResource(issuerUsed, certRes) + err = cfg.saveCertResource(ctx, issuerUsed, certRes) if err != nil { return fmt.Errorf("[%s] Obtain: saving assets: %v", name, err) } @@ -596,7 +596,7 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool // as well as the reordered list of issuers to use instead of cfg.Issuers (because if a key // is found, that issuer should be tried first, so it is moved to the front in a copy of // cfg.Issuers). -func (cfg *Config) reusePrivateKey(domain string) (privKey crypto.PrivateKey, privKeyPEM []byte, issuers []Issuer, err error) { +func (cfg *Config) reusePrivateKey(ctx context.Context, domain string) (privKey crypto.PrivateKey, privKeyPEM []byte, issuers []Issuer, err error) { // make a copy of cfg.Issuers so that if we have to reorder elements, we don't // inadvertently mutate the configured issuers (see append calls below) issuers = make([]Issuer, len(cfg.Issuers)) @@ -605,7 +605,7 @@ func (cfg *Config) reusePrivateKey(domain string) (privKey crypto.PrivateKey, pr for i, issuer := range issuers { // see if this issuer location in storage has a private key for the domain privateKeyStorageKey := StorageKeys.SitePrivateKey(issuer.IssuerKey(), domain) - privKeyPEM, err = cfg.Storage.Load(privateKeyStorageKey) + privKeyPEM, err = cfg.Storage.Load(ctx, privateKeyStorageKey) if errors.Is(err, fs.ErrNotExist) { err = nil // obviously, it's OK to not have a private key; so don't prevent obtaining a cert continue @@ -632,9 +632,9 @@ func (cfg *Config) reusePrivateKey(domain string) (privKey crypto.PrivateKey, pr // storageHasCertResourcesAnyIssuer returns true if storage has all the // certificate resources in storage from any configured issuer. It checks // all configured issuers in order. -func (cfg *Config) storageHasCertResourcesAnyIssuer(name string) bool { +func (cfg *Config) storageHasCertResourcesAnyIssuer(ctx context.Context, name string) bool { for _, iss := range cfg.Issuers { - if cfg.storageHasCertResources(iss, name) { + if cfg.storageHasCertResources(ctx, iss, name) { return true } } @@ -666,7 +666,7 @@ func (cfg *Config) renewCert(ctx context.Context, name string, force, interactiv // ensure storage is writeable and readable // TODO: this is not necessary every time; should only perform check once every so often for each storage, which may require some global state... - err := cfg.checkStorage() + err := cfg.checkStorage(ctx) if err != nil { return fmt.Errorf("failed storage check: %v - storage is probably misconfigured", err) } @@ -687,7 +687,7 @@ func (cfg *Config) renewCert(ctx context.Context, name string, force, interactiv if log != nil { log.Info("releasing lock", zap.String("identifier", name)) } - if err := releaseLock(cfg.Storage, lockKey); err != nil { + if err := releaseLock(ctx, cfg.Storage, lockKey); err != nil { if log != nil { log.Error("unable to unlock", zap.String("identifier", name), @@ -702,7 +702,7 @@ func (cfg *Config) renewCert(ctx context.Context, name string, force, interactiv f := func(ctx context.Context) error { // prepare for renewal (load PEM cert, key, and meta) - certRes, err := cfg.loadCertResourceAnyIssuer(name) + certRes, err := cfg.loadCertResourceAnyIssuer(ctx, name) if err != nil { return err } @@ -784,7 +784,7 @@ func (cfg *Config) renewCert(ctx context.Context, name string, force, interactiv PrivateKeyPEM: certRes.PrivateKeyPEM, IssuerData: issuedCert.Metadata, } - err = cfg.saveCertResource(issuerUsed, newCertRes) + err = cfg.saveCertResource(ctx, issuerUsed, newCertRes) if err != nil { return fmt.Errorf("[%s] Renew: saving assets: %v", name, err) } @@ -854,12 +854,12 @@ func (cfg *Config) RevokeCert(ctx context.Context, domain string, reason int, in return fmt.Errorf("issuer %d (%s) is not a Revoker", i, issuerKey) } - certRes, err := cfg.loadCertResource(issuer, domain) + certRes, err := cfg.loadCertResource(ctx, issuer, domain) if err != nil { return err } - if !cfg.Storage.Exists(StorageKeys.SitePrivateKey(issuerKey, domain)) { + if !cfg.Storage.Exists(ctx, StorageKeys.SitePrivateKey(issuerKey, domain)) { return fmt.Errorf("private key not found for %s", certRes.SANs) } @@ -870,7 +870,7 @@ func (cfg *Config) RevokeCert(ctx context.Context, domain string, reason int, in cfg.emit("cert_revoked", domain) - err = cfg.deleteSiteAssets(issuerKey, domain) + err = cfg.deleteSiteAssets(ctx, issuerKey, domain) if err != nil { return fmt.Errorf("certificate revoked, but unable to fully clean up assets from issuer %s: %v", issuerKey, err) } @@ -923,7 +923,7 @@ func (cfg *Config) TLSConfig() *tls.Config { // indicates whether challenge info was loaded from external storage. If true, the // challenge is being solved in a distributed fashion; if false, from internal memory. // If no matching challenge information can be found, an error is returned. -func (cfg *Config) getChallengeInfo(identifier string) (Challenge, bool, error) { +func (cfg *Config) getChallengeInfo(ctx context.Context, identifier string) (Challenge, bool, error) { // first, check if our process initiated this challenge; if so, just return it chalData, ok := GetACMEChallenge(identifier) if ok { @@ -943,7 +943,7 @@ func (cfg *Config) getChallengeInfo(identifier string) (Challenge, bool, error) } tokenKey = ds.challengeTokensKey(identifier) var err error - chalInfoBytes, err = cfg.Storage.Load(tokenKey) + chalInfoBytes, err = cfg.Storage.Load(ctx, tokenKey) if err == nil { break } @@ -968,19 +968,19 @@ func (cfg *Config) getChallengeInfo(identifier string) (Challenge, bool, error) // to a random key, and then loading those bytes and // comparing the loaded value. If this fails, the provided // cfg.Storage mechanism should not be used. -func (cfg *Config) checkStorage() error { +func (cfg *Config) checkStorage(ctx context.Context) error { key := fmt.Sprintf("rw_test_%d", weakrand.Int()) contents := make([]byte, 1024*10) // size sufficient for one or two ACME resources _, err := weakrand.Read(contents) if err != nil { return err } - err = cfg.Storage.Store(key, contents) + err = cfg.Storage.Store(ctx, key, contents) if err != nil { return err } defer func() { - deleteErr := cfg.Storage.Delete(key) + deleteErr := cfg.Storage.Delete(ctx, key) if deleteErr != nil { if cfg.Logger != nil { cfg.Logger.Error("deleting test key from storage", @@ -993,7 +993,7 @@ func (cfg *Config) checkStorage() error { err = deleteErr } }() - loaded, err := cfg.Storage.Load(key) + loaded, err := cfg.Storage.Load(ctx, key) if err != nil { return err } @@ -1007,33 +1007,33 @@ func (cfg *Config) checkStorage() error { // associated with cfg's certificate cache has all the // resources related to the certificate for domain: the // certificate, the private key, and the metadata. -func (cfg *Config) storageHasCertResources(issuer Issuer, domain string) bool { +func (cfg *Config) storageHasCertResources(ctx context.Context, issuer Issuer, domain string) bool { issuerKey := issuer.IssuerKey() certKey := StorageKeys.SiteCert(issuerKey, domain) keyKey := StorageKeys.SitePrivateKey(issuerKey, domain) metaKey := StorageKeys.SiteMeta(issuerKey, domain) - return cfg.Storage.Exists(certKey) && - cfg.Storage.Exists(keyKey) && - cfg.Storage.Exists(metaKey) + return cfg.Storage.Exists(ctx, certKey) && + cfg.Storage.Exists(ctx, keyKey) && + cfg.Storage.Exists(ctx, metaKey) } // deleteSiteAssets deletes the folder in storage containing the // certificate, private key, and metadata file for domain from the // issuer with the given issuer key. -func (cfg *Config) deleteSiteAssets(issuerKey, domain string) error { - err := cfg.Storage.Delete(StorageKeys.SiteCert(issuerKey, domain)) +func (cfg *Config) deleteSiteAssets(ctx context.Context, issuerKey, domain string) error { + err := cfg.Storage.Delete(ctx, StorageKeys.SiteCert(issuerKey, domain)) if err != nil { return fmt.Errorf("deleting certificate file: %v", err) } - err = cfg.Storage.Delete(StorageKeys.SitePrivateKey(issuerKey, domain)) + err = cfg.Storage.Delete(ctx, StorageKeys.SitePrivateKey(issuerKey, domain)) if err != nil { return fmt.Errorf("deleting private key: %v", err) } - err = cfg.Storage.Delete(StorageKeys.SiteMeta(issuerKey, domain)) + err = cfg.Storage.Delete(ctx, StorageKeys.SiteMeta(issuerKey, domain)) if err != nil { return fmt.Errorf("deleting metadata file: %v", err) } - err = cfg.Storage.Delete(StorageKeys.CertsSitePrefix(issuerKey, domain)) + err = cfg.Storage.Delete(ctx, StorageKeys.CertsSitePrefix(issuerKey, domain)) if err != nil { return fmt.Errorf("deleting site asset folder: %v", err) } diff --git a/config_test.go b/config_test.go index 1bb30e36..3f177018 100644 --- a/config_test.go +++ b/config_test.go @@ -15,6 +15,7 @@ package certmagic import ( + "context" "os" "reflect" "testing" @@ -23,6 +24,8 @@ import ( ) func TestSaveCertResource(t *testing.T) { + ctx := context.Background() + am := &ACMEManager{CA: "https://example.com/acme/directory"} testConfig := &Config{ Issuers: []Issuer{am}, @@ -53,7 +56,7 @@ func TestSaveCertResource(t *testing.T) { issuerKey: am.IssuerKey(), } - err := testConfig.saveCertResource(am, cert) + err := testConfig.saveCertResource(ctx, am, cert) if err != nil { t.Fatalf("Expected no error, got: %v", err) } @@ -64,7 +67,7 @@ func TestSaveCertResource(t *testing.T) { "url": "https://example.com/cert", } - siteData, err := testConfig.loadCertResource(am, domain) + siteData, err := testConfig.loadCertResource(ctx, am, domain) if err != nil { t.Fatalf("Expected no error reading site, got: %v", err) } diff --git a/crypto.go b/crypto.go index 62f5f17f..475155f9 100644 --- a/crypto.go +++ b/crypto.go @@ -15,6 +15,7 @@ package certmagic import ( + "context" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -142,7 +143,7 @@ func fastHash(input []byte) string { // saveCertResource saves the certificate resource to disk. This // includes the certificate file itself, the private key, and the // metadata file. -func (cfg *Config) saveCertResource(issuer Issuer, cert CertificateResource) error { +func (cfg *Config) saveCertResource(ctx context.Context, issuer Issuer, cert CertificateResource) error { metaBytes, err := json.MarshalIndent(cert, "", "\t") if err != nil { return fmt.Errorf("encoding certificate metadata: %v", err) @@ -166,19 +167,19 @@ func (cfg *Config) saveCertResource(issuer Issuer, cert CertificateResource) err }, } - return storeTx(cfg.Storage, all) + return storeTx(ctx, cfg.Storage, all) } // loadCertResourceAnyIssuer loads and returns the certificate resource from any // of the configured issuers. If multiple are found (e.g. if there are 3 issuers // configured, and all 3 have a resource matching certNamesKey), then the newest // (latest NotBefore date) resource will be chosen. -func (cfg *Config) loadCertResourceAnyIssuer(certNamesKey string) (CertificateResource, error) { +func (cfg *Config) loadCertResourceAnyIssuer(ctx context.Context, certNamesKey string) (CertificateResource, error) { // we can save some extra decoding steps if there's only one issuer, since // we don't need to compare potentially multiple available resources to // select the best one, when there's only one choice anyway if len(cfg.Issuers) == 1 { - return cfg.loadCertResource(cfg.Issuers[0], certNamesKey) + return cfg.loadCertResource(ctx, cfg.Issuers[0], certNamesKey) } type decodedCertResource struct { @@ -192,7 +193,7 @@ func (cfg *Config) loadCertResourceAnyIssuer(certNamesKey string) (CertificateRe // load and decode all certificate resources found with the // configured issuers so we can sort by newest for _, issuer := range cfg.Issuers { - certRes, err := cfg.loadCertResource(issuer, certNamesKey) + certRes, err := cfg.loadCertResource(ctx, issuer, certNamesKey) if err != nil { if errors.Is(err, fs.ErrNotExist) { // not a problem, but we need to remember the error @@ -238,7 +239,7 @@ func (cfg *Config) loadCertResourceAnyIssuer(certNamesKey string) (CertificateRe } // loadCertResource loads a certificate resource from the given issuer's storage location. -func (cfg *Config) loadCertResource(issuer Issuer, certNamesKey string) (CertificateResource, error) { +func (cfg *Config) loadCertResource(ctx context.Context, issuer Issuer, certNamesKey string) (CertificateResource, error) { certRes := CertificateResource{issuerKey: issuer.IssuerKey()} normalizedName, err := idna.ToASCII(certNamesKey) @@ -246,17 +247,17 @@ func (cfg *Config) loadCertResource(issuer Issuer, certNamesKey string) (Certifi return CertificateResource{}, fmt.Errorf("converting '%s' to ASCII: %v", certNamesKey, err) } - keyBytes, err := cfg.Storage.Load(StorageKeys.SitePrivateKey(certRes.issuerKey, normalizedName)) + keyBytes, err := cfg.Storage.Load(ctx, StorageKeys.SitePrivateKey(certRes.issuerKey, normalizedName)) if err != nil { return CertificateResource{}, err } certRes.PrivateKeyPEM = keyBytes - certBytes, err := cfg.Storage.Load(StorageKeys.SiteCert(certRes.issuerKey, normalizedName)) + certBytes, err := cfg.Storage.Load(ctx, StorageKeys.SiteCert(certRes.issuerKey, normalizedName)) if err != nil { return CertificateResource{}, err } certRes.CertificatePEM = certBytes - metaBytes, err := cfg.Storage.Load(StorageKeys.SiteMeta(certRes.issuerKey, normalizedName)) + metaBytes, err := cfg.Storage.Load(ctx, StorageKeys.SiteMeta(certRes.issuerKey, normalizedName)) if err != nil { return CertificateResource{}, err } diff --git a/filestorage.go b/filestorage.go index dc49d44f..ede25c07 100644 --- a/filestorage.go +++ b/filestorage.go @@ -70,13 +70,13 @@ type FileStorage struct { } // Exists returns true if key exists in s. -func (s *FileStorage) Exists(key string) bool { +func (s *FileStorage) Exists(_ context.Context, key string) bool { _, err := os.Stat(s.Filename(key)) return !errors.Is(err, fs.ErrNotExist) } // Store saves value at key. -func (s *FileStorage) Store(key string, value []byte) error { +func (s *FileStorage) Store(_ context.Context, key string, value []byte) error { filename := s.Filename(key) err := os.MkdirAll(filepath.Dir(filename), 0700) if err != nil { @@ -86,17 +86,17 @@ func (s *FileStorage) Store(key string, value []byte) error { } // Load retrieves the value at key. -func (s *FileStorage) Load(key string) ([]byte, error) { +func (s *FileStorage) Load(_ context.Context, key string) ([]byte, error) { return os.ReadFile(s.Filename(key)) } // Delete deletes the value at key. -func (s *FileStorage) Delete(key string) error { +func (s *FileStorage) Delete(_ context.Context, key string) error { return os.Remove(s.Filename(key)) } // List returns all keys that match prefix. -func (s *FileStorage) List(prefix string, recursive bool) ([]string, error) { +func (s *FileStorage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) { var keys []string walkPrefix := s.Filename(prefix) @@ -110,6 +110,9 @@ func (s *FileStorage) List(prefix string, recursive bool) ([]string, error) { if fpath == walkPrefix { return nil } + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } suffix, err := filepath.Rel(walkPrefix, fpath) if err != nil { @@ -127,7 +130,7 @@ func (s *FileStorage) List(prefix string, recursive bool) ([]string, error) { } // Stat returns information about key. -func (s *FileStorage) Stat(key string) (KeyInfo, error) { +func (s *FileStorage) Stat(_ context.Context, key string) (KeyInfo, error) { fi, err := os.Stat(s.Filename(key)) if err != nil { return KeyInfo{}, err @@ -212,7 +215,7 @@ func (s *FileStorage) Lock(ctx context.Context, key string) error { } // Unlock releases the lock for name. -func (s *FileStorage) Unlock(key string) error { +func (s *FileStorage) Unlock(_ context.Context, key string) error { return os.Remove(s.lockFilename(key)) } diff --git a/handshake.go b/handshake.go index f53c0368..5a37533e 100644 --- a/handshake.go +++ b/handshake.go @@ -294,12 +294,12 @@ func (cfg *Config) getCertDuringHandshake(hello *tls.ClientHelloInfo, loadIfNece // it might be a good idea to check with the DecisionFunc or allowlist first before even loading the certificate // from storage, since if we can't renew it, why should we even try serving it (it will just get evicted after // we get a return value of false anyway)? - loadedCert, err := cfg.CacheManagedCertificate(name) + loadedCert, err := cfg.CacheManagedCertificate(ctx, name) if errors.Is(err, fs.ErrNotExist) { // If no exact match, try a wildcard variant, which is something we can still use labels := strings.Split(name, ".") labels[0] = "*" - loadedCert, err = cfg.CacheManagedCertificate(strings.Join(labels, ".")) + loadedCert, err = cfg.CacheManagedCertificate(ctx, strings.Join(labels, ".")) } if err == nil { if log != nil { @@ -497,7 +497,7 @@ func (cfg *Config) handshakeMaintenance(ctx context.Context, hello *tls.ClientHe zap.Time("next_update", cert.ocsp.NextUpdate)) } - err := stapleOCSP(cfg.OCSP, cfg.Storage, &cert, nil) + err := stapleOCSP(ctx, cfg.OCSP, cfg.Storage, &cert, nil) if err != nil { // An error with OCSP stapling is not the end of the world, and in fact, is // quite common considering not all certs have issuer URLs that support it. @@ -654,7 +654,7 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien // even though the recursive nature of the dynamic cert loading // would just call this function anyway, we do it here to // make the replacement as atomic as possible. - newCert, err = cfg.CacheManagedCertificate(name) + newCert, err = cfg.CacheManagedCertificate(ctx, name) if err != nil { if log != nil { log.Error("loading renewed certificate", zap.String("server_name", name), zap.Error(err)) @@ -752,7 +752,7 @@ func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.Cli // solving). True is returned if the challenge is being solved distributed (there // is no semantic difference with distributed solving; it is mainly for logging). func (cfg *Config) getTLSALPNChallengeCert(clientHello *tls.ClientHelloInfo) (*tls.Certificate, bool, error) { - chalData, distributed, err := cfg.getChallengeInfo(clientHello.ServerName) + chalData, distributed, err := cfg.getChallengeInfo(clientHello.Context(), clientHello.ServerName) if err != nil { return nil, distributed, err } diff --git a/httphandler.go b/httphandler.go index d17cfaab..60b33cd3 100644 --- a/httphandler.go +++ b/httphandler.go @@ -71,7 +71,7 @@ func (am *ACMEManager) distributedHTTPChallengeSolver(w http.ResponseWriter, r * return false } host := hostOnly(r.Host) - chalInfo, distributed, err := am.config.getChallengeInfo(host) + chalInfo, distributed, err := am.config.getChallengeInfo(r.Context(), host) if err != nil { if am.Logger != nil { am.Logger.Error("looking up info for HTTP challenge", diff --git a/maintain.go b/maintain.go index 0802dc8f..dbfde011 100644 --- a/maintain.go +++ b/maintain.go @@ -154,7 +154,7 @@ func (certCache *Cache) RenewManagedCertificates(ctx context.Context) error { // instance that didn't coordinate with this one; if so, just load it (this // might happen if another instance already renewed it - kinda sloppy but checking disk // first is a simple way to possibly drastically reduce rate limit problems) - storedCertExpiring, err := cfg.managedCertInStorageExpiresSoon(cert) + storedCertExpiring, err := cfg.managedCertInStorageExpiresSoon(ctx, cert) if err != nil { // hmm, weird, but not a big deal, maybe it was deleted or something if log != nil { @@ -191,7 +191,7 @@ func (certCache *Cache) RenewManagedCertificates(ctx context.Context) error { cfg := configs[oldCert.Names[0]] // crucially, this happens OUTSIDE a lock on the certCache - _, err := cfg.reloadManagedCertificate(oldCert) + _, err := cfg.reloadManagedCertificate(ctx, oldCert) if err != nil { if log != nil { log.Error("loading renewed certificate", @@ -264,7 +264,7 @@ func (certCache *Cache) queueRenewalTask(ctx context.Context, oldCert Certificat // successful renewal, so update in-memory cache by loading // renewed certificate so it will be used with handshakes - _, err = cfg.reloadManagedCertificate(oldCert) + _, err = cfg.reloadManagedCertificate(ctx, oldCert) if err != nil { return ErrNoRetry{fmt.Errorf("%v %v", oldCert.Names, err)} } @@ -355,7 +355,7 @@ func (certCache *Cache) updateOCSPStaples(ctx context.Context) { continue } - err := stapleOCSP(qe.cfg.OCSP, qe.cfg.Storage, &cert, nil) + err := stapleOCSP(ctx, qe.cfg.OCSP, qe.cfg.Storage, &cert, nil) if err != nil { if cert.ocsp != nil { // if there was no staple before, that's fine; otherwise we should log the error @@ -439,7 +439,7 @@ func CleanStorage(ctx context.Context, storage Storage, opts CleanStorageOptions } func deleteOldOCSPStaples(ctx context.Context, storage Storage) error { - ocspKeys, err := storage.List(prefixOCSP, false) + ocspKeys, err := storage.List(ctx, prefixOCSP, false) if err != nil { // maybe just hasn't been created yet; no big deal return nil @@ -451,7 +451,7 @@ func deleteOldOCSPStaples(ctx context.Context, storage Storage) error { return ctx.Err() default: } - ocspBytes, err := storage.Load(key) + ocspBytes, err := storage.Load(ctx, key) if err != nil { log.Printf("[ERROR] While deleting old OCSP staples, unable to load staple file: %v", err) continue @@ -459,7 +459,7 @@ func deleteOldOCSPStaples(ctx context.Context, storage Storage) error { resp, err := ocsp.ParseResponse(ocspBytes, nil) if err != nil { // contents are invalid; delete it - err = storage.Delete(key) + err = storage.Delete(ctx, key) if err != nil { log.Printf("[ERROR] Purging corrupt staple file %s: %v", key, err) } @@ -467,7 +467,7 @@ func deleteOldOCSPStaples(ctx context.Context, storage Storage) error { } if time.Now().After(resp.NextUpdate) { // response has expired; delete it - err = storage.Delete(key) + err = storage.Delete(ctx, key) if err != nil { log.Printf("[ERROR] Purging expired staple file %s: %v", key, err) } @@ -477,14 +477,14 @@ func deleteOldOCSPStaples(ctx context.Context, storage Storage) error { } func deleteExpiredCerts(ctx context.Context, storage Storage, gracePeriod time.Duration) error { - issuerKeys, err := storage.List(prefixCerts, false) + issuerKeys, err := storage.List(ctx, prefixCerts, false) if err != nil { // maybe just hasn't been created yet; no big deal return nil } for _, issuerKey := range issuerKeys { - siteKeys, err := storage.List(issuerKey, false) + siteKeys, err := storage.List(ctx, issuerKey, false) if err != nil { log.Printf("[ERROR] Listing contents of %s: %v", issuerKey, err) continue @@ -498,7 +498,7 @@ func deleteExpiredCerts(ctx context.Context, storage Storage, gracePeriod time.D default: } - siteAssets, err := storage.List(siteKey, false) + siteAssets, err := storage.List(ctx, siteKey, false) if err != nil { log.Printf("[ERROR] Listing contents of %s: %v", siteKey, err) continue @@ -509,7 +509,7 @@ func deleteExpiredCerts(ctx context.Context, storage Storage, gracePeriod time.D continue } - certFile, err := storage.Load(assetKey) + certFile, err := storage.Load(ctx, assetKey) if err != nil { return fmt.Errorf("loading certificate file %s: %v", assetKey, err) } @@ -531,7 +531,7 @@ func deleteExpiredCerts(ctx context.Context, storage Storage, gracePeriod time.D baseName + ".json", } { log.Printf("[INFO] Deleting %s because resource expired", relatedAsset) - err := storage.Delete(relatedAsset) + err := storage.Delete(ctx, relatedAsset) if err != nil { log.Printf("[ERROR] Cleaning up asset related to expired certificate for %s: %s: %v", baseName, relatedAsset, err) @@ -541,13 +541,13 @@ func deleteExpiredCerts(ctx context.Context, storage Storage, gracePeriod time.D } // update listing; if folder is empty, delete it - siteAssets, err = storage.List(siteKey, false) + siteAssets, err = storage.List(ctx, siteKey, false) if err != nil { continue } if len(siteAssets) == 0 { log.Printf("[INFO] Deleting %s because key is empty", siteKey) - err := storage.Delete(siteKey) + err := storage.Delete(ctx, siteKey) if err != nil { return fmt.Errorf("deleting empty site folder %s: %v", siteKey, err) } @@ -583,7 +583,7 @@ func (cfg *Config) forceRenew(ctx context.Context, logger *zap.Logger, cert Cert // new key, so we'll have to do an obtain instead var obtainInsteadOfRenew bool if cert.ocsp != nil && cert.ocsp.RevocationReason == acme.ReasonKeyCompromise { - err := cfg.moveCompromisedPrivateKey(cert, logger) + err := cfg.moveCompromisedPrivateKey(ctx, cert, logger) if err != nil && logger != nil { logger.Error("could not remove compromised private key from use", zap.Strings("identifiers", cert.Names), @@ -617,28 +617,28 @@ func (cfg *Config) forceRenew(ctx context.Context, logger *zap.Logger, cert Cert return cert, fmt.Errorf("unable to forcefully get new certificate for %v: %w", cert.Names, err) } - return cfg.reloadManagedCertificate(cert) + return cfg.reloadManagedCertificate(ctx, cert) } // moveCompromisedPrivateKey moves the private key for cert to a ".compromised" file // by copying the data to the new file, then deleting the old one. -func (cfg *Config) moveCompromisedPrivateKey(cert Certificate, logger *zap.Logger) error { +func (cfg *Config) moveCompromisedPrivateKey(ctx context.Context, cert Certificate, logger *zap.Logger) error { privKeyStorageKey := StorageKeys.SitePrivateKey(cert.issuerKey, cert.Names[0]) - privKeyPEM, err := cfg.Storage.Load(privKeyStorageKey) + privKeyPEM, err := cfg.Storage.Load(ctx, privKeyStorageKey) if err != nil { return err } compromisedPrivKeyStorageKey := privKeyStorageKey + ".compromised" - err = cfg.Storage.Store(compromisedPrivKeyStorageKey, privKeyPEM) + err = cfg.Storage.Store(ctx, compromisedPrivKeyStorageKey, privKeyPEM) if err != nil { // better safe than sorry: as a last resort, try deleting the key so it won't be reused - cfg.Storage.Delete(privKeyStorageKey) + cfg.Storage.Delete(ctx, privKeyStorageKey) return err } - err = cfg.Storage.Delete(privKeyStorageKey) + err = cfg.Storage.Delete(ctx, privKeyStorageKey) if err != nil { return err } diff --git a/ocsp.go b/ocsp.go index 3104c08f..5b2bb406 100644 --- a/ocsp.go +++ b/ocsp.go @@ -16,6 +16,7 @@ package certmagic import ( "bytes" + "context" "crypto/x509" "encoding/pem" "fmt" @@ -38,7 +39,7 @@ import ( // // Errors here are not necessarily fatal, it could just be that the // certificate doesn't have an issuer URL. -func stapleOCSP(ocspConfig OCSPConfig, storage Storage, cert *Certificate, pemBundle []byte) error { +func stapleOCSP(ctx context.Context, ocspConfig OCSPConfig, storage Storage, cert *Certificate, pemBundle []byte) error { if ocspConfig.DisableStapling { return nil } @@ -60,7 +61,7 @@ func stapleOCSP(ocspConfig OCSPConfig, storage Storage, cert *Certificate, pemBu // First try to load OCSP staple from storage and see if // we can still use it. ocspStapleKey := StorageKeys.OCSPStaple(cert, pemBundle) - cachedOCSP, err := storage.Load(ocspStapleKey) + cachedOCSP, err := storage.Load(ctx, ocspStapleKey) if err == nil { resp, err := ocsp.ParseResponse(cachedOCSP, nil) if err == nil { @@ -76,7 +77,7 @@ func stapleOCSP(ocspConfig OCSPConfig, storage Storage, cert *Certificate, pemBu // because we loaded it by name, whereas the maintenance routine // just iterates the list of files, even if somehow a non-staple // file gets in the folder. in this case we are sure it is corrupt.) - err := storage.Delete(ocspStapleKey) + err := storage.Delete(ctx, ocspStapleKey) if err != nil { log.Printf("[WARNING] Unable to delete invalid OCSP staple file: %v", err) } @@ -115,7 +116,7 @@ func stapleOCSP(ocspConfig OCSPConfig, storage Storage, cert *Certificate, pemBu if ocspResp.Status == ocsp.Good { cert.Certificate.OCSPStaple = ocspBytes if gotNewOCSP { - err := storage.Store(ocspStapleKey, ocspBytes) + err := storage.Store(ctx, ocspStapleKey, ocspBytes) if err != nil { return fmt.Errorf("unable to write OCSP staple file for %v: %v", cert.Names, err) } diff --git a/solvers.go b/solvers.go index 287d2be6..5521c120 100644 --- a/solvers.go +++ b/solvers.go @@ -72,13 +72,13 @@ func (s *httpSolver) Present(ctx context.Context, _ acme.Challenge) error { // successfully bound socket, so save listener and start key auth HTTP server si.listener = ln - go s.serve(si) + go s.serve(ctx, si) return nil } // serve is an HTTP server that serves only HTTP challenge responses. -func (s *httpSolver) serve(si *solverInfo) { +func (s *httpSolver) serve(ctx context.Context, si *solverInfo) { defer func() { if err := recover(); err != nil { buf := make([]byte, stackTraceBufferSize) @@ -87,7 +87,10 @@ func (s *httpSolver) serve(si *solverInfo) { } }() defer close(si.done) - httpServer := &http.Server{Handler: s.acmeManager.HTTPChallengeHandler(http.NewServeMux())} + httpServer := &http.Server{ + Handler: s.acmeManager.HTTPChallengeHandler(http.NewServeMux()), + BaseContext: func(listener net.Listener) context.Context { return ctx }, + } httpServer.SetKeepAlivesEnabled(false) err := httpServer.Serve(si.listener) if err != nil && atomic.LoadInt32(&s.closed) != 1 { @@ -484,7 +487,7 @@ func (dhs distributedSolver) Present(ctx context.Context, chal acme.Challenge) e return err } - err = dhs.storage.Store(dhs.challengeTokensKey(challengeKey(chal)), infoBytes) + err = dhs.storage.Store(ctx, dhs.challengeTokensKey(challengeKey(chal)), infoBytes) if err != nil { return err } @@ -507,7 +510,7 @@ func (dhs distributedSolver) Wait(ctx context.Context, challenge acme.Challenge) // CleanUp invokes the underlying solver's CleanUp method // and also cleans up any assets saved to storage. func (dhs distributedSolver) CleanUp(ctx context.Context, chal acme.Challenge) error { - err := dhs.storage.Delete(dhs.challengeTokensKey(challengeKey(chal))) + err := dhs.storage.Delete(ctx, dhs.challengeTokensKey(challengeKey(chal))) if err != nil { return err } diff --git a/storage.go b/storage.go index 354c21e9..28b68740 100644 --- a/storage.go +++ b/storage.go @@ -37,36 +37,37 @@ import ( // The Load, Delete, List, and Stat methods should return // fs.ErrNotExist if the key does not exist. // -// Implementations of Storage must be safe for concurrent use. +// Implementations of Storage must be safe for concurrent use +// and honor context cancellations. type Storage interface { // Locker provides atomic synchronization // operations, making Storage safe to share. Locker // Store puts value at key. - Store(key string, value []byte) error + Store(ctx context.Context, key string, value []byte) error // Load retrieves the value at key. - Load(key string) ([]byte, error) + Load(ctx context.Context, key string) ([]byte, error) // Delete deletes key. An error should be // returned only if the key still exists // when the method returns. - Delete(key string) error + Delete(ctx context.Context, key string) error // Exists returns true if the key exists // and there was no error checking. - Exists(key string) bool + Exists(ctx context.Context, key string) bool // List returns all keys that match prefix. // If recursive is true, non-terminal keys // will be enumerated (i.e. "directories" // should be walked); otherwise, only keys // prefixed exactly by prefix will be listed. - List(prefix string, recursive bool) ([]string, error) + List(ctx context.Context, prefix string, recursive bool) ([]string, error) // Stat returns information about key. - Stat(key string) (KeyInfo, error) + Stat(ctx context.Context, key string) (KeyInfo, error) } // Locker facilitates synchronization of certificate tasks across @@ -98,7 +99,7 @@ type Locker interface { // called after a successful call to Lock, and only after the // critical section is finished, even if it errored or timed // out. Unlock cleans up any resources allocated during Lock. - Unlock(key string) error + Unlock(ctx context.Context, key string) error } // KeyInfo holds information about a key in storage. @@ -116,12 +117,12 @@ type KeyInfo struct { } // storeTx stores all the values or none at all. -func storeTx(s Storage, all []keyValue) error { +func storeTx(ctx context.Context, s Storage, all []keyValue) error { for i, kv := range all { - err := s.Store(kv.key, kv.value) + err := s.Store(ctx, kv.key, kv.value) if err != nil { for j := i - 1; j >= 0; j-- { - s.Delete(all[j].key) + s.Delete(ctx, all[j].key) } return err } @@ -215,11 +216,11 @@ func (keys KeyBuilder) Safe(str string) string { // the locks are synchronizing, this should be // called only immediately before process exit. // Errors are only reported if a logger is given. -func CleanUpOwnLocks(logger *zap.Logger) { +func CleanUpOwnLocks(ctx context.Context, logger *zap.Logger) { locksMu.Lock() defer locksMu.Unlock() for lockKey, storage := range locks { - err := storage.Unlock(lockKey) + err := storage.Unlock(ctx, lockKey) if err == nil { delete(locks, lockKey) } else if logger != nil { @@ -242,8 +243,8 @@ func acquireLock(ctx context.Context, storage Storage, lockKey string) error { return err } -func releaseLock(storage Storage, lockKey string) error { - err := storage.Unlock(lockKey) +func releaseLock(ctx context.Context, storage Storage, lockKey string) error { + err := storage.Unlock(ctx, lockKey) if err == nil { locksMu.Lock() delete(locks, lockKey)