Skip to content

Commit

Permalink
Propagate context in the Storage interface methods (#155)
Browse files Browse the repository at this point in the history
* Add context propagation to the Storage interface

Signed-off-by: Dave Henderson <dhenderson@gmail.com>

* Bump to Go 1.17

* Minor cleanup

* filestorage: Honor context cancellation in List()

Co-authored-by: Matthew Holt <mholt@users.noreply.github.com>
  • Loading branch information
hairyhenderson and mholt committed Mar 7, 2022
1 parent 2d11419 commit 9a56fcd
Show file tree
Hide file tree
Showing 17 changed files with 188 additions and 158 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest, macos-latest, windows-latest ]
go: [ '1.16', '1.17' ]
go: [ '1.17' ]

runs-on: ${{ matrix.os }}

Expand Down
34 changes: 17 additions & 17 deletions account.go
Expand Up @@ -37,21 +37,21 @@ import (

// getAccount either loads or creates a new account, depending on if
// an account can be found in storage for the given CA + email combo.
func (am *ACMEManager) getAccount(ca, email string) (acme.Account, error) {
acct, err := am.loadAccount(ca, email)
func (am *ACMEManager) getAccount(ctx context.Context, ca, email string) (acme.Account, error) {
acct, err := am.loadAccount(ctx, ca, email)
if errors.Is(err, fs.ErrNotExist) {
return am.newAccount(email)
}
return acct, err
}

// loadAccount loads an account from storage, but does not create a new one.
func (am *ACMEManager) loadAccount(ca, email string) (acme.Account, error) {
regBytes, err := am.config.Storage.Load(am.storageKeyUserReg(ca, email))
func (am *ACMEManager) loadAccount(ctx context.Context, ca, email string) (acme.Account, error) {
regBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserReg(ca, email))
if err != nil {
return acme.Account{}, err
}
keyBytes, err := am.config.Storage.Load(am.storageKeyUserPrivateKey(ca, email))
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(ca, email))
if err != nil {
return acme.Account{}, err
}
Expand Down Expand Up @@ -99,18 +99,18 @@ func (am *ACMEManager) GetAccount(ctx context.Context, privateKeyPEM []byte) (ac
// If it does not exist, an error of type fs.ErrNotExist is returned. This is not very efficient
// for lots of accounts.
func (am *ACMEManager) loadAccountByKey(ctx context.Context, privateKeyPEM []byte) (acme.Account, error) {
accountList, err := am.config.Storage.List(am.storageKeyUsersPrefix(am.CA), false)
accountList, err := am.config.Storage.List(ctx, am.storageKeyUsersPrefix(am.CA), false)
if err != nil {
return acme.Account{}, err
}
for _, accountFolderKey := range accountList {
email := path.Base(accountFolderKey)
keyBytes, err := am.config.Storage.Load(am.storageKeyUserPrivateKey(am.CA, email))
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email))
if err != nil {
return acme.Account{}, err
}
if bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) {
return am.loadAccount(am.CA, email)
return am.loadAccount(ctx, am.CA, email)
}
}
return acme.Account{}, fs.ErrNotExist
Expand All @@ -137,7 +137,7 @@ func (am *ACMEManager) lookUpAccount(ctx context.Context, privateKeyPEM []byte)
}

// save the account details to storage
err = am.saveAccount(client.Directory, account)
err = am.saveAccount(ctx, client.Directory, account)
if err != nil {
return account, fmt.Errorf("could not save account to storage: %v", err)
}
Expand All @@ -147,7 +147,7 @@ func (am *ACMEManager) lookUpAccount(ctx context.Context, privateKeyPEM []byte)

// saveAccount persists an ACME account's info and private key to storage.
// It does NOT register the account via ACME or prompt the user.
func (am *ACMEManager) saveAccount(ca string, account acme.Account) error {
func (am *ACMEManager) saveAccount(ctx context.Context, ca string, account acme.Account) error {
regBytes, err := json.MarshalIndent(account, "", "\t")
if err != nil {
return err
Expand All @@ -168,7 +168,7 @@ func (am *ACMEManager) saveAccount(ca string, account acme.Account) error {
value: keyBytes,
},
}
return storeTx(am.config.Storage, all)
return storeTx(ctx, am.config.Storage, all)
}

// getEmail does everything it can to obtain an email address
Expand All @@ -178,7 +178,7 @@ func (am *ACMEManager) saveAccount(ca string, account acme.Account) error {
// the consequences of an empty email.) This function MAY prompt
// the user for input. If allowPrompts is false, the user
// will NOT be prompted and an empty email may be returned.
func (am *ACMEManager) getEmail(allowPrompts bool) error {
func (am *ACMEManager) getEmail(ctx context.Context, allowPrompts bool) error {
leEmail := am.Email

// First try package default email, or a discovered email address
Expand All @@ -194,7 +194,7 @@ func (am *ACMEManager) getEmail(allowPrompts bool) error {
// Then try to get most recent user email from storage
var gotRecentEmail bool
if leEmail == "" {
leEmail, gotRecentEmail = am.mostRecentAccountEmail(am.CA)
leEmail, gotRecentEmail = am.mostRecentAccountEmail(ctx, am.CA)
}
if !gotRecentEmail && leEmail == "" && allowPrompts {
// Looks like there is no email address readily available,
Expand Down Expand Up @@ -331,8 +331,8 @@ func (*ACMEManager) emailUsername(email string) string {
// in storage. Since this is part of a complex sequence to get a user
// account, errors here are discarded to simplify code flow in
// the caller, and errors are not important here anyway.
func (am *ACMEManager) mostRecentAccountEmail(caURL string) (string, bool) {
accountList, err := am.config.Storage.List(am.storageKeyUsersPrefix(caURL), false)
func (am *ACMEManager) mostRecentAccountEmail(ctx context.Context, caURL string) (string, bool) {
accountList, err := am.config.Storage.List(ctx, am.storageKeyUsersPrefix(caURL), false)
if err != nil || len(accountList) == 0 {
return "", false
}
Expand All @@ -342,7 +342,7 @@ func (am *ACMEManager) mostRecentAccountEmail(caURL string) (string, bool) {
stats := make(map[string]KeyInfo)
for i := 0; i < len(accountList); i++ {
u := accountList[i]
keyInfo, err := am.config.Storage.Stat(u)
keyInfo, err := am.config.Storage.Stat(ctx, u)
if err != nil {
continue
}
Expand Down Expand Up @@ -370,7 +370,7 @@ func (am *ACMEManager) mostRecentAccountEmail(caURL string) (string, bool) {
return "", false
}

account, err := am.getAccount(caURL, path.Base(accountList[0]))
account, err := am.getAccount(ctx, caURL, path.Base(accountList[0]))
if err != nil {
return "", false
}
Expand Down
31 changes: 22 additions & 9 deletions account_test.go
Expand Up @@ -16,6 +16,7 @@ package certmagic

import (
"bytes"
"context"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -50,6 +51,8 @@ func TestNewAccount(t *testing.T) {
}

func TestSaveAccount(t *testing.T) {
ctx := context.Background()

am := &ACMEManager{CA: dummyCA}
testConfig := &Config{
Issuers: []Issuer{am},
Expand All @@ -72,17 +75,19 @@ func TestSaveAccount(t *testing.T) {
t.Fatalf("Error creating account: %v", err)
}

err = am.saveAccount(am.CA, account)
err = am.saveAccount(ctx, am.CA, account)
if err != nil {
t.Fatalf("Error saving account: %v", err)
}
_, err = am.getAccount(am.CA, email)
_, err = am.getAccount(ctx, am.CA, email)
if err != nil {
t.Errorf("Cannot access account data, error: %v", err)
}
}

func TestGetAccountDoesNotAlreadyExist(t *testing.T) {
ctx := context.Background()

am := &ACMEManager{CA: dummyCA}
testConfig := &Config{
Issuers: []Issuer{am},
Expand All @@ -91,7 +96,7 @@ func TestGetAccountDoesNotAlreadyExist(t *testing.T) {
}
am.config = testConfig

account, err := am.getAccount(am.CA, "account_does_not_exist@foobar.com")
account, err := am.getAccount(ctx, am.CA, "account_does_not_exist@foobar.com")
if err != nil {
t.Fatalf("Error getting account: %v", err)
}
Expand All @@ -102,6 +107,8 @@ func TestGetAccountDoesNotAlreadyExist(t *testing.T) {
}

func TestGetAccountAlreadyExists(t *testing.T) {
ctx := context.Background()

am := &ACMEManager{CA: dummyCA}
testConfig := &Config{
Issuers: []Issuer{am},
Expand All @@ -125,13 +132,13 @@ func TestGetAccountAlreadyExists(t *testing.T) {
if err != nil {
t.Fatalf("Error creating account: %v", err)
}
err = am.saveAccount(am.CA, account)
err = am.saveAccount(ctx, am.CA, account)
if err != nil {
t.Fatalf("Error saving account: %v", err)
}

// Expect to load account from disk
loadedAccount, err := am.getAccount(am.CA, email)
loadedAccount, err := am.getAccount(ctx, am.CA, email)
if err != nil {
t.Fatalf("Error getting account: %v", err)
}
Expand All @@ -148,6 +155,8 @@ func TestGetAccountAlreadyExists(t *testing.T) {
}

func TestGetEmailFromPackageDefault(t *testing.T) {
ctx := context.Background()

DefaultACME.Email = "tEsT2@foo.com"
defer func() {
DefaultACME.Email = ""
Expand All @@ -162,7 +171,7 @@ func TestGetEmailFromPackageDefault(t *testing.T) {
}
am.config = testConfig

err := am.getEmail(true)
err := am.getEmail(ctx, true)
if err != nil {
t.Fatalf("getEmail error: %v", err)
}
Expand All @@ -173,6 +182,8 @@ func TestGetEmailFromPackageDefault(t *testing.T) {
}

func TestGetEmailFromUserInput(t *testing.T) {
ctx := context.Background()

am := &ACMEManager{CA: dummyCA}
testConfig := &Config{
Issuers: []Issuer{am},
Expand All @@ -192,7 +203,7 @@ func TestGetEmailFromUserInput(t *testing.T) {

email := "test3@foo.com"
stdin = bytes.NewBufferString(email + "\n")
err := am.getEmail(true)
err := am.getEmail(ctx, true)
if err != nil {
t.Fatalf("getEmail error: %v", err)
}
Expand All @@ -205,6 +216,8 @@ func TestGetEmailFromUserInput(t *testing.T) {
}

func TestGetEmailFromRecent(t *testing.T) {
ctx := context.Background()

am := &ACMEManager{CA: dummyCA}
testConfig := &Config{
Issuers: []Issuer{am},
Expand Down Expand Up @@ -233,7 +246,7 @@ func TestGetEmailFromRecent(t *testing.T) {
if err != nil {
t.Fatalf("Error creating user %d: %v", i, err)
}
err = am.saveAccount(am.CA, account)
err = am.saveAccount(ctx, am.CA, account)
if err != nil {
t.Fatalf("Error saving user %d: %v", i, err)
}
Expand All @@ -250,7 +263,7 @@ func TestGetEmailFromRecent(t *testing.T) {
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
}
}
err := am.getEmail(true)
err := am.getEmail(ctx, true)
if err != nil {
t.Fatalf("getEmail error: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions acmeclient.go
Expand Up @@ -61,7 +61,7 @@ func (am *ACMEManager) newACMEClientWithAccount(ctx context.Context, useTestCA,
if am.AccountKeyPEM != "" {
account, err = am.GetAccount(ctx, []byte(am.AccountKeyPEM))
} else {
account, err = am.getAccount(client.Directory, am.Email)
account, err = am.getAccount(ctx, client.Directory, am.Email)
}
if err != nil {
return nil, fmt.Errorf("getting ACME account: %v", err)
Expand Down Expand Up @@ -116,7 +116,7 @@ func (am *ACMEManager) newACMEClientWithAccount(ctx context.Context, useTestCA,
}

// persist the account to storage
err = am.saveAccount(client.Directory, account)
err = am.saveAccount(ctx, client.Directory, account)
if err != nil {
return nil, fmt.Errorf("could not save account %v: %v", account.Contact, err)
}
Expand Down
4 changes: 2 additions & 2 deletions acmemanager.go
Expand Up @@ -217,7 +217,7 @@ func (*ACMEManager) issuerKey(ca string) string {
// renewing a certificate with ACME, and returns whether this
// batch is eligible for certificates if using Let's Encrypt.
// It also ensures that an email address is available.
func (am *ACMEManager) PreCheck(_ context.Context, names []string, interactive bool) error {
func (am *ACMEManager) PreCheck(ctx context.Context, names []string, interactive bool) error {
publicCA := strings.Contains(am.CA, "api.letsencrypt.org") || strings.Contains(am.CA, "acme.zerossl.com")
if publicCA {
for _, name := range names {
Expand All @@ -226,7 +226,7 @@ func (am *ACMEManager) PreCheck(_ context.Context, names []string, interactive b
}
}
}
return am.getEmail(interactive)
return am.getEmail(ctx, interactive)
}

// Issue implements the Issuer interface. It obtains a certificate for the given csr using
Expand Down

0 comments on commit 9a56fcd

Please sign in to comment.