Skip to content

Commit

Permalink
Acquire a per-lease lock to make renew and revoke atomic wrt each oth…
Browse files Browse the repository at this point in the history
…er. (#11122)

* Acquire a per-lease lock to make renew and revoke atomic wrt each other.
This means we don't have to hold pendingLock during I/O.

* Attempted fix for deadlock in token revocation.

* Comment fix.

* Fix error checking in loadEntry.

* Add benchmark

* Add a few additional locking locations

* Improve benchmark slightly

* Update vault/expiration.go

Co-authored-by: swayne275 <swayne275@gmail.com>

* Update vault/expiration.go

Co-authored-by: swayne275 <swayne275@gmail.com>

* Add a lease lock into tidy

Co-authored-by: Scott Miller <smiller@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
Co-authored-by: swayne275 <swayne275@gmail.com>
  • Loading branch information
5 people committed Jun 10, 2021
1 parent f8353a8 commit 5323b68
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 53 deletions.
165 changes: 113 additions & 52 deletions vault/expiration.go
Expand Up @@ -112,6 +112,8 @@ type ExpirationManager struct {
leaseCount int
pendingLock sync.RWMutex

// A sync.Mutex for every active leaseID
lockPerLease sync.Map
// Track expired leases that have been determined to be irrevocable (without
// manual intervention). We retain a subset of the lease info in memory
irrevocable sync.Map
Expand Down Expand Up @@ -392,6 +394,8 @@ func NewExpirationManager(c *Core, view *BarrierView, e ExpireLeaseStrategy, log
leaseCount: 0,
tidyLock: new(int32),

lockPerLease: sync.Map{},

uniquePolicies: make(map[string][]string),
emptyUniquePolicies: time.NewTicker(7 * 24 * time.Hour),

Expand Down Expand Up @@ -642,7 +646,11 @@ func (m *ExpirationManager) Tidy(ctx context.Context) error {
if revokeLease {
// Force the revocation and skip going through the token store
// again

leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
err = m.revokeCommon(ctx, leaseID, true, true)
leaseLock.Unlock()
if err != nil {
tidyErrors = multierror.Append(tidyErrors, fmt.Errorf("failed to revoke an invalid lease with ID %q: %w", leaseID, err))
return
Expand Down Expand Up @@ -905,6 +913,14 @@ func (m *ExpirationManager) Revoke(ctx context.Context, leaseID string) error {
// it triggers a return of a 202.
func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) error {
defer metrics.MeasureSince([]string{"expire", "lazy-revoke"}, time.Now())
return m.lazyRevokeInternal(ctx, leaseID)
}

// Mark a lease as expiring immediately
func (m *ExpirationManager) lazyRevokeInternal(ctx context.Context, leaseID string) error {
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
Expand All @@ -918,16 +934,10 @@ func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) erro
}

le.ExpireTime = time.Now()
{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return err
}

m.updatePendingInternal(le)
m.pendingLock.Unlock()
if err := m.persistEntry(ctx, le); err != nil {
return err
}
m.updatePending(le)

return nil
}
Expand All @@ -937,6 +947,16 @@ func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) erro
func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, force, skipToken bool) error {
defer metrics.MeasureSince([]string{"expire", "revoke-common"}, time.Now())

if !skipToken {
// Acquire lock for this lease
// If skipToken is true, then we're either being (1) called via RevokeByToken, so
// probably the lock is already held, and if we re-acquire we get deadlock, or
// (2) called by tidy, in which case the lock is held by the tidy thread.
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()
}

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
if err != nil {
Expand Down Expand Up @@ -966,6 +986,9 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo
return err
}

// Lease has been removed, also remove the in-memory lock.
m.deleteLockForLease(leaseID)

// Delete the secondary index, but only if it's a leased secret (not auth)
if le.Secret != nil {
var indexToken string
Expand Down Expand Up @@ -1029,7 +1052,8 @@ func (m *ExpirationManager) RevokePrefix(ctx context.Context, prefix string, syn
// RevokeByToken is used to revoke all the secrets issued with a given token.
// This is done by using the secondary index. It also removes the lease entry
// for the token itself. As a result it should *ONLY* ever be called from the
// token store's revokeSalted function.
// token store's revokeInternal function.
// (NB: it's called by token tidy as well.)
func (m *ExpirationManager) RevokeByToken(ctx context.Context, te *logical.TokenEntry) error {
defer metrics.MeasureSince([]string{"expire", "revoke-by-token"}, time.Now())
tokenNS, err := NamespaceByID(ctx, te.NamespaceID, m.core)
Expand All @@ -1047,31 +1071,12 @@ func (m *ExpirationManager) RevokeByToken(ctx context.Context, te *logical.Token
return fmt.Errorf("failed to scan for leases: %w", err)
}

// Revoke all the keys
// Revoke all the keys by marking them expired
for _, leaseID := range existing {
// Load the entry
le, err := m.loadEntry(ctx, leaseID)
err := m.lazyRevokeInternal(ctx, leaseID)
if err != nil {
return err
}

// If there's a lease, set expiration to now, persist, and call
// updatePending to hand off revocation to the expiration manager's pending
// timer map
if le != nil {
le.ExpireTime = time.Now()

{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return err
}

m.updatePendingInternal(le)
m.pendingLock.Unlock()
}
}
}

// te.Path should never be empty, but we check just in case
Expand Down Expand Up @@ -1137,6 +1142,7 @@ func (m *ExpirationManager) revokePrefixCommon(ctx context.Context, prefix strin
// Revoke all the keys
for idx, suffix := range existing {
leaseID := prefix + suffix
// No need to acquire per-lease lock here, one of these two will do it.
switch {
case sync:
if err := m.revokeCommon(ctx, leaseID, force, false); err != nil {
Expand All @@ -1157,6 +1163,11 @@ func (m *ExpirationManager) revokePrefixCommon(ctx context.Context, prefix strin
func (m *ExpirationManager) Renew(ctx context.Context, leaseID string, increment time.Duration) (*logical.Response, error) {
defer metrics.MeasureSince([]string{"expire", "renew"}, time.Now())

// Acquire lock for this lease
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
if err != nil {
Expand Down Expand Up @@ -1249,18 +1260,13 @@ func (m *ExpirationManager) Renew(ctx context.Context, leaseID string, increment
}
}

{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return nil, err
}

// Update the expiration time
m.updatePendingInternal(le)
m.pendingLock.Unlock()
if err := m.persistEntry(ctx, le); err != nil {
return nil, err
}

// Update the expiration time
m.updatePending(le)

// Return the response
return resp, nil
}
Expand Down Expand Up @@ -1299,6 +1305,11 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request
leaseID = fmt.Sprintf("%s.%s", leaseID, ns.ID)
}

// Acquire lock for this lease
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
if err != nil {
Expand Down Expand Up @@ -1368,17 +1379,10 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request
le.ExpireTime = resp.Auth.ExpirationTime()
le.LastRenewalTime = time.Now()

{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return nil, err
}

// Update the expiration time
m.updatePendingInternal(le)
m.pendingLock.Unlock()
if err := m.persistEntry(ctx, le); err != nil {
return nil, err
}
m.updatePending(le)

retResp.Auth = resp.Auth
return retResp, nil
Expand Down Expand Up @@ -1467,6 +1471,8 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
if err := m.removeIndexByToken(ctx, le, indexToken); err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("an additional error was encountered removing lease indexes associated with the newly-generated secret: %w", err))
}

m.deleteLockForLease(leaseID)
}
}()

Expand All @@ -1485,6 +1491,14 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
}
}

// Acquire the lock here so persistEntry and updatePending are atomic,
// although it is *very unlikely* that anybody could grab the lease ID
// before this function returns. (They could find it in an index, or
// find it in a list.)
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Encode the entry
if err := m.persistEntry(ctx, le); err != nil {
return "", err
Expand Down Expand Up @@ -1570,6 +1584,10 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE
Version: 1,
}

leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Encode the entry
if err := m.persistEntry(ctx, &le); err != nil {
return err
Expand Down Expand Up @@ -1716,6 +1734,30 @@ func (m *ExpirationManager) uniquePoliciesGc() {
}
}

// Placing a lock in pendingMap means that we need to work very hard on reload
// to only create one lock. Instead, we'll create locks on-demand in an atomic fashion.
//
// Acquiring a lock from a leaseEntry is a bad idea because it could change
// between loading and acquiring the lock. So we only provide an ID-based map, and the
// locking discipline should be:
// 1. Lock lease
// 2. Load, or attempt to load, leaseEntry
// 3. Modify leaseEntry and pendingMap (atomic wrt operations on this lease)
// 4. Unlock lease
//
// The lock must be removed from the map when the lease is deleted, or is
// found to not exist in storage. loadEntry does this whenever it returns
// nil, but we should also do it in revokeCommon().
//
// lockForLeaseID returns the mutex guarding the given lease ID, creating it
// atomically on first use.
func (m *ExpirationManager) lockForLeaseID(id string) *sync.Mutex {
	// Fast path: don't allocate a throwaway mutex on every call when the
	// lock already exists (the common case for operations on a live lease).
	if lock, ok := m.lockPerLease.Load(id); ok {
		return lock.(*sync.Mutex)
	}
	// Slow path: race to install a new mutex; LoadOrStore guarantees all
	// callers observe the same winner.
	lock, _ := m.lockPerLease.LoadOrStore(id, &sync.Mutex{})
	return lock.(*sync.Mutex)
}

// deleteLockForLease removes the per-lease lock from the map. Call this when
// the lease is deleted or found missing in storage, so the map entry doesn't
// leak; callers currently waiting on the removed mutex still hold a valid
// reference to it.
func (m *ExpirationManager) deleteLockForLease(id string) {
	m.lockPerLease.Delete(id)
}

// updatePending is used to update a pending invocation for a lease
func (m *ExpirationManager) updatePending(le *leaseEntry) {
m.pendingLock.Lock()
Expand Down Expand Up @@ -1908,7 +1950,16 @@ func (m *ExpirationManager) loadEntry(ctx context.Context, leaseID string) (*lea
} else {
ctx = namespace.ContextWithNamespace(ctx, namespace.RootNamespace)
}
return m.loadEntryInternal(ctx, leaseID, restoreMode, true)

// If a lease entry is nil, proactively delete the lease lock, in case we
// created one erroneously.
// If there was an error, we don't know whether the lease entry exists or not.
leaseEntry, err := m.loadEntryInternal(ctx, leaseID, restoreMode, true)
if err == nil && leaseEntry == nil {
m.deleteLockForLease(leaseID)
}
return leaseEntry, err

}

// loadEntryInternal is used when you need to load an entry but also need to
Expand Down Expand Up @@ -2133,6 +2184,15 @@ func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(ctx context.Cont

// If there's no associated leaseEntry for the token, we create one
if le == nil {

// Acquire the lock here so persistEntry and updatePending are atomic,
// although it is *very unlikely* that anybody could grab the lease ID
// before this function returns. (They could find it in an index, or
// find it in a list.)
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

auth := &logical.Auth{
ClientToken: te.ID,
LeaseOptions: logical.LeaseOptions{
Expand All @@ -2159,6 +2219,7 @@ func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(ctx context.Cont

// Encode the entry
if err := m.persistEntry(ctx, le); err != nil {
m.deleteLockForLease(leaseID)
return "", err
}
}
Expand Down
49 changes: 48 additions & 1 deletion vault/expiration_test.go
Expand Up @@ -612,6 +612,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
Path: "prod/aws/" + pathUUID,
ClientToken: "root",
}
req.SetTokenEntry(&logical.TokenEntry{ID: "root", NamespaceID: "root"})
resp := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
Expand All @@ -623,7 +624,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
"secret_key": "abcd",
},
}
_, err = exp.Register(context.Background(), req, resp)
_, err = exp.Register(namespace.RootContext(nil), req, resp)
if err != nil {
b.Fatalf("err: %v", err)
}
Expand All @@ -646,6 +647,52 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
b.StopTimer()
}

// BenchmarkExpiration_Create_Leases measures the per-iteration cost of
// registering a brand-new lease against an in-memory physical backend.
func BenchmarkExpiration_Create_Leases(b *testing.B) {
	logger := logging.NewVaultLogger(log.Trace)
	physBackend, err := inmem.NewInmem(nil, logger)
	if err != nil {
		b.Fatal(err)
	}

	core, _, _ := TestCoreUnsealedBackend(b, physBackend)
	expMgr := core.expiration

	// Mount a noop backend at prod/aws/ so Register has a route to hit.
	mountUUID, err := uuid.GenerateUUID()
	if err != nil {
		b.Fatal(err)
	}
	mountEntry := &MountEntry{
		Path:      "prod/aws/",
		Type:      "noop",
		UUID:      mountUUID,
		Accessor:  "noop-accessor",
		namespace: namespace.RootNamespace,
	}
	if err := expMgr.router.Mount(&NoopBackend{}, "prod/aws/", mountEntry, NewBarrierView(core.barrier, "logical/")); err != nil {
		b.Fatal(err)
	}

	// A single request/response pair is reused; only the path varies per
	// iteration so each Register creates a distinct lease.
	req := &logical.Request{
		Operation:   logical.ReadOperation,
		ClientToken: "root",
	}
	req.SetTokenEntry(&logical.TokenEntry{ID: "root", NamespaceID: "root"})
	resp := &logical.Response{
		Secret: &logical.Secret{
			LeaseOptions: logical.LeaseOptions{
				TTL: 400 * time.Second,
			},
		},
		Data: map[string]interface{}{
			"access_key": "xyz",
			"secret_key": "abcd",
		},
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req.Path = fmt.Sprintf("prod/aws/%d", i)
		if _, err := expMgr.Register(namespace.RootContext(nil), req, resp); err != nil {
			b.Fatalf("err: %v", err)
		}
	}
}

func TestExpiration_Restore(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
exp := c.expiration
Expand Down

0 comments on commit 5323b68

Please sign in to comment.