diff --git a/vault/expiration.go b/vault/expiration.go
index 97183e525031f..9d088f1480279 100644
--- a/vault/expiration.go
+++ b/vault/expiration.go
@@ -112,6 +112,8 @@ type ExpirationManager struct {
 	leaseCount  int
 	pendingLock sync.RWMutex
 
+	// A sync.Lock for every active leaseID
+	lockPerLease sync.Map
 	// Track expired leases that have been determined to be irrevocable (without
 	// manual intervention). We retain a subset of the lease info in memory
 	irrevocable sync.Map
@@ -392,6 +394,8 @@ func NewExpirationManager(c *Core, view *BarrierView, e ExpireLeaseStrategy, log
 		leaseCount: 0,
 		tidyLock:   new(int32),
 
+		lockPerLease: sync.Map{},
+
 		uniquePolicies:      make(map[string][]string),
 		emptyUniquePolicies: time.NewTicker(7 * 24 * time.Hour),
 
@@ -642,7 +646,11 @@ func (m *ExpirationManager) Tidy(ctx context.Context) error {
 			if revokeLease {
 				// Force the revocation and skip going through the token store
 				// again
+
+				leaseLock := m.lockForLeaseID(leaseID)
+				leaseLock.Lock()
 				err = m.revokeCommon(ctx, leaseID, true, true)
+				leaseLock.Unlock()
 				if err != nil {
 					tidyErrors = multierror.Append(tidyErrors, fmt.Errorf("failed to revoke an invalid lease with ID %q: %w", leaseID, err))
 					return
@@ -905,6 +913,14 @@ func (m *ExpirationManager) Revoke(ctx context.Context, leaseID string) error {
 // it triggers a return of a 202.
 func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) error {
 	defer metrics.MeasureSince([]string{"expire", "lazy-revoke"}, time.Now())
+	return m.lazyRevokeInternal(ctx, leaseID)
+}
+
+// Mark a lease as expiring immediately
+func (m *ExpirationManager) lazyRevokeInternal(ctx context.Context, leaseID string) error {
+	leaseLock := m.lockForLeaseID(leaseID)
+	leaseLock.Lock()
+	defer leaseLock.Unlock()
 
 	// Load the entry
 	le, err := m.loadEntry(ctx, leaseID)
@@ -918,16 +934,10 @@ func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) erro
 	}
 
 	le.ExpireTime = time.Now()
-	{
-		m.pendingLock.Lock()
-		if err := m.persistEntry(ctx, le); err != nil {
-			m.pendingLock.Unlock()
-			return err
-		}
-
-		m.updatePendingInternal(le)
-		m.pendingLock.Unlock()
+	if err := m.persistEntry(ctx, le); err != nil {
+		return err
 	}
+	m.updatePending(le)
 
 	return nil
 }
@@ -937,6 +947,16 @@ func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) erro
 func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, force, skipToken bool) error {
 	defer metrics.MeasureSince([]string{"expire", "revoke-common"}, time.Now())
 
+	if !skipToken {
+		// Acquire lock for this lease
+		// If skipToken is true, then we're either being (1) called via RevokeByToken, so
+		// probably the lock is already held, and if we re-acquire we get deadlock, or
+		// (2) called by tidy, in which case the lock is held by the tidy thread.
+		leaseLock := m.lockForLeaseID(leaseID)
+		leaseLock.Lock()
+		defer leaseLock.Unlock()
+	}
+
 	// Load the entry
 	le, err := m.loadEntry(ctx, leaseID)
 	if err != nil {
@@ -966,6 +986,9 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo
 		return err
 	}
 
+	// Lease has been removed, also remove the in-memory lock.
+	m.deleteLockForLease(leaseID)
+
 	// Delete the secondary index, but only if it's a leased secret (not auth)
 	if le.Secret != nil {
 		var indexToken string
@@ -1029,7 +1052,8 @@ func (m *ExpirationManager) RevokePrefix(ctx context.Context, prefix string, syn
 // RevokeByToken is used to revoke all the secrets issued with a given token.
 // This is done by using the secondary index. It also removes the lease entry
 // for the token itself. As a result it should *ONLY* ever be called from the
-// token store's revokeSalted function.
+// token store's revokeInternal function.
+// (NB: it's called by token tidy as well.)
 func (m *ExpirationManager) RevokeByToken(ctx context.Context, te *logical.TokenEntry) error {
 	defer metrics.MeasureSince([]string{"expire", "revoke-by-token"}, time.Now())
 	tokenNS, err := NamespaceByID(ctx, te.NamespaceID, m.core)
@@ -1047,31 +1071,12 @@ func (m *ExpirationManager) RevokeByToken(ctx context.Context, te *logical.Token
 		return fmt.Errorf("failed to scan for leases: %w", err)
 	}
 
-	// Revoke all the keys
+	// Revoke all the keys by marking them expired
 	for _, leaseID := range existing {
-		// Load the entry
-		le, err := m.loadEntry(ctx, leaseID)
+		err := m.lazyRevokeInternal(ctx, leaseID)
 		if err != nil {
 			return err
 		}
-
-		// If there's a lease, set expiration to now, persist, and call
-		// updatePending to hand off revocation to the expiration manager's pending
-		// timer map
-		if le != nil {
-			le.ExpireTime = time.Now()
-
-			{
-				m.pendingLock.Lock()
-				if err := m.persistEntry(ctx, le); err != nil {
-					m.pendingLock.Unlock()
-					return err
-				}
-
-				m.updatePendingInternal(le)
-				m.pendingLock.Unlock()
-			}
-		}
 	}
 
 	// te.Path should never be empty, but we check just in case
@@ -1137,6 +1142,7 @@ func (m *ExpirationManager) revokePrefixCommon(ctx context.Context, prefix strin
 	// Revoke all the keys
 	for idx, suffix := range existing {
 		leaseID := prefix + suffix
+		// No need to acquire per-lease lock here, one of these two will do it.
 		switch {
 		case sync:
 			if err := m.revokeCommon(ctx, leaseID, force, false); err != nil {
@@ -1157,6 +1163,11 @@ func (m *ExpirationManager) revokePrefixCommon(ctx context.Context, prefix strin
 func (m *ExpirationManager) Renew(ctx context.Context, leaseID string, increment time.Duration) (*logical.Response, error) {
 	defer metrics.MeasureSince([]string{"expire", "renew"}, time.Now())
 
+	// Acquire lock for this lease
+	leaseLock := m.lockForLeaseID(leaseID)
+	leaseLock.Lock()
+	defer leaseLock.Unlock()
+
 	// Load the entry
 	le, err := m.loadEntry(ctx, leaseID)
 	if err != nil {
@@ -1249,18 +1260,13 @@ func (m *ExpirationManager) Renew(ctx context.Context, leaseID string, increment
 		}
 	}
 
-	{
-		m.pendingLock.Lock()
-		if err := m.persistEntry(ctx, le); err != nil {
-			m.pendingLock.Unlock()
-			return nil, err
-		}
-
-		// Update the expiration time
-		m.updatePendingInternal(le)
-		m.pendingLock.Unlock()
+	if err := m.persistEntry(ctx, le); err != nil {
+		return nil, err
 	}
+	// Update the expiration time
+	m.updatePending(le)
+
 	// Return the response
 	return resp, nil
 }
@@ -1299,6 +1305,11 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request
 		leaseID = fmt.Sprintf("%s.%s", leaseID, ns.ID)
 	}
 
+	// Acquire lock for this lease
+	leaseLock := m.lockForLeaseID(leaseID)
+	leaseLock.Lock()
+	defer leaseLock.Unlock()
+
 	// Load the entry
 	le, err := m.loadEntry(ctx, leaseID)
 	if err != nil {
@@ -1368,17 +1379,10 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request
 	le.ExpireTime = resp.Auth.ExpirationTime()
 	le.LastRenewalTime = time.Now()
 
-	{
-		m.pendingLock.Lock()
-		if err := m.persistEntry(ctx, le); err != nil {
-			m.pendingLock.Unlock()
-			return nil, err
-		}
-
-		// Update the expiration time
-		m.updatePendingInternal(le)
-		m.pendingLock.Unlock()
+	if err := m.persistEntry(ctx, le); err != nil {
+		return nil, err
 	}
+	m.updatePending(le)
 
 	retResp.Auth = resp.Auth
 	return retResp, nil
@@ -1467,6 +1471,8 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
 			if err := m.removeIndexByToken(ctx, le, indexToken); err != nil {
 				retErr = multierror.Append(retErr, fmt.Errorf("an additional error was encountered removing lease indexes associated with the newly-generated secret: %w", err))
 			}
+
+			m.deleteLockForLease(leaseID)
 		}
 	}()
 
@@ -1485,6 +1491,14 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
 		}
 	}
 
+	// Acquire the lock here so persistEntry and updatePending are atomic,
+	// although it is *very unlikely* that anybody could grab the lease ID
+	// before this function returns. (They could find it in an index, or
+	// find it in a list.)
+	leaseLock := m.lockForLeaseID(leaseID)
+	leaseLock.Lock()
+	defer leaseLock.Unlock()
+
 	// Encode the entry
 	if err := m.persistEntry(ctx, le); err != nil {
 		return "", err
@@ -1570,6 +1584,10 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE
 		Version:     1,
 	}
 
+	leaseLock := m.lockForLeaseID(leaseID)
+	leaseLock.Lock()
+	defer leaseLock.Unlock()
+
 	// Encode the entry
 	if err := m.persistEntry(ctx, &le); err != nil {
 		return err
@@ -1716,6 +1734,30 @@ func (m *ExpirationManager) uniquePoliciesGc() {
 	}
 }
 
+// Placing a lock in pendingMap means that we need to work very hard on reload
+// to only create one lock. Instead, we'll create locks on-demand in an atomic fashion.
+//
+// Acquiring a lock from a leaseEntry is a bad idea because it could change
+// between loading and acquiring the lock. So we only provide an ID-based map, and the
+// locking discipline should be:
+// 1. Lock lease
+// 2. Load, or attempt to load, leaseEntry
+// 3. Modify leaseEntry and pendingMap (atomic wrt operations on this lease)
+// 4. Unlock lease
+//
+// The lock must be removed from the map when the lease is deleted, or is
+// found to not exist in storage. loadEntry does this whenever it returns
+// nil, but we should also do it in revokeCommon().
+func (m *ExpirationManager) lockForLeaseID(id string) *sync.Mutex {
+	mutex := &sync.Mutex{}
+	lock, _ := m.lockPerLease.LoadOrStore(id, mutex)
+	return lock.(*sync.Mutex)
+}
+
+func (m *ExpirationManager) deleteLockForLease(id string) {
+	m.lockPerLease.Delete(id)
+}
+
 // updatePending is used to update a pending invocation for a lease
 func (m *ExpirationManager) updatePending(le *leaseEntry) {
 	m.pendingLock.Lock()
@@ -1908,7 +1950,16 @@ func (m *ExpirationManager) loadEntry(ctx context.Context, leaseID string) (*lea
 	} else {
 		ctx = namespace.ContextWithNamespace(ctx, namespace.RootNamespace)
 	}
-	return m.loadEntryInternal(ctx, leaseID, restoreMode, true)
+
+	// If a lease entry is nil, proactively delete the lease lock, in case we
+	// created one erroneously.
+	// If there was an error, we don't know whether the lease entry exists or not.
+	leaseEntry, err := m.loadEntryInternal(ctx, leaseID, restoreMode, true)
+	if err == nil && leaseEntry == nil {
+		m.deleteLockForLease(leaseID)
+	}
+	return leaseEntry, err
+
 }
 
 // loadEntryInternal is used when you need to load an entry but also need to
@@ -2133,6 +2184,15 @@ func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(ctx context.Cont
 
 	// If there's no associated leaseEntry for the token, we create one
 	if le == nil {
+
+		// Acquire the lock here so persistEntry and updatePending are atomic,
+		// although it is *very unlikely* that anybody could grab the lease ID
+		// before this function returns. (They could find it in an index, or
+		// find it in a list.)
+		leaseLock := m.lockForLeaseID(leaseID)
+		leaseLock.Lock()
+		defer leaseLock.Unlock()
+
 		auth := &logical.Auth{
			ClientToken: te.ID,
			LeaseOptions: logical.LeaseOptions{
@@ -2159,6 +2219,7 @@ func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(ctx context.Cont
 
 		// Encode the entry
 		if err := m.persistEntry(ctx, le); err != nil {
+			m.deleteLockForLease(leaseID)
 			return "", err
 		}
 	}
diff --git a/vault/expiration_test.go b/vault/expiration_test.go
index 2c94d249d4204..add6696791e70 100644
--- a/vault/expiration_test.go
+++ b/vault/expiration_test.go
@@ -612,6 +612,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
 		Path:        "prod/aws/" + pathUUID,
 		ClientToken: "root",
 	}
+	req.SetTokenEntry(&logical.TokenEntry{ID: "root", NamespaceID: "root"})
 	resp := &logical.Response{
 		Secret: &logical.Secret{
 			LeaseOptions: logical.LeaseOptions{
@@ -623,7 +624,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
 			"secret_key": "abcd",
 		},
 	}
-	_, err = exp.Register(context.Background(), req, resp)
+	_, err = exp.Register(namespace.RootContext(nil), req, resp)
 	if err != nil {
 		b.Fatalf("err: %v", err)
 	}
@@ -646,6 +647,52 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
 	b.StopTimer()
 }
 
+func BenchmarkExpiration_Create_Leases(b *testing.B) {
+	logger := logging.NewVaultLogger(log.Trace)
+	inm, err := inmem.NewInmem(nil, logger)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	c, _, _ := TestCoreUnsealedBackend(b, inm)
+	exp := c.expiration
+	noop := &NoopBackend{}
+	view := NewBarrierView(c.barrier, "logical/")
+	meUUID, err := uuid.GenerateUUID()
+	if err != nil {
+		b.Fatal(err)
+	}
+	err = exp.router.Mount(noop, "prod/aws/", &MountEntry{Path: "prod/aws/", Type: "noop", UUID: meUUID, Accessor: "noop-accessor", namespace: namespace.RootNamespace}, view)
+	if err != nil {
+		b.Fatal(err)
+	}
+	req := &logical.Request{
+		Operation:   logical.ReadOperation,
+		ClientToken: "root",
+	}
+	req.SetTokenEntry(&logical.TokenEntry{ID: "root", NamespaceID: "root"})
+	resp := &logical.Response{
+		Secret: &logical.Secret{
+			LeaseOptions: logical.LeaseOptions{
+				TTL: 400 * time.Second,
+			},
+		},
+		Data: map[string]interface{}{
+			"access_key": "xyz",
+			"secret_key": "abcd",
+		},
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		req.Path = fmt.Sprintf("prod/aws/%d", i)
+		_, err = exp.Register(namespace.RootContext(nil), req, resp)
+		if err != nil {
+			b.Fatalf("err: %v", err)
+		}
+	}
+}
+
 func TestExpiration_Restore(t *testing.T) {
 	c, _, _ := TestCoreUnsealed(t)
 	exp := c.expiration
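Aside (not part of the diff): a minimal, self-contained Go sketch of the pattern this change introduces, one on-demand mutex per lease ID created with sync.Map.LoadOrStore, used under the lock -> load -> modify -> unlock discipline described in the lockForLeaseID comment. The leaseLocks type and the lockFor/release names are hypothetical stand-ins for illustration, not Vault APIs.

package main

import (
	"fmt"
	"sync"
)

// leaseLocks holds one *sync.Mutex per lease ID, created lazily.
type leaseLocks struct {
	locks sync.Map // leaseID -> *sync.Mutex
}

// lockFor returns the mutex for id, creating it atomically if needed.
// If two goroutines race on the same id, LoadOrStore guarantees both
// receive the same mutex, so the map never needs to be pre-populated
// (or rebuilt on restore).
func (l *leaseLocks) lockFor(id string) *sync.Mutex {
	m, _ := l.locks.LoadOrStore(id, &sync.Mutex{})
	return m.(*sync.Mutex)
}

// release drops the mutex once the lease is deleted or found missing,
// mirroring deleteLockForLease in the diff.
func (l *leaseLocks) release(id string) {
	l.locks.Delete(id)
}

func main() {
	var registry leaseLocks
	var wg sync.WaitGroup
	counter := 0

	// Many goroutines operating on the same lease ID serialize on one mutex:
	// lock the lease, then load/modify its entry, then unlock.
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			lock := registry.lockFor("lease-1")
			lock.Lock()
			counter++ // stands in for "load entry, persist, update pending"
			lock.Unlock()
		}()
	}
	wg.Wait()
	registry.release("lease-1")
	fmt.Println("operations applied:", counter) // always 100
}

The design point the sketch tries to show: because LoadOrStore is atomic, lock creation itself needs no global lock, and deleting the entry when a lease disappears keeps the map from growing without bound.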