Acquire a per-lease lock to make renew and revoke atomic wrt each other. #11122

Merged
merged 13 commits on Jun 10, 2021
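The change guards each lease's load/persist/update-pending sequence with a mutex keyed by lease ID, so a renew and a revoke racing on the same lease serialize instead of interleaving. A minimal standalone sketch of that idea, using the same sync.Map/LoadOrStore approach as the lockForLeaseID helper added below (everything else here is illustrative, not Vault code):

```go
package main

import (
	"fmt"
	"sync"
)

// One mutex per lease ID, created on demand. LoadOrStore is atomic, so two
// goroutines asking for "lease-1" at the same moment get the same *sync.Mutex.
var lockPerLease sync.Map

func lockForLeaseID(id string) *sync.Mutex {
	lock, _ := lockPerLease.LoadOrStore(id, &sync.Mutex{})
	return lock.(*sync.Mutex)
}

func main() {
	var wg sync.WaitGroup
	for _, op := range []string{"renew", "revoke"} {
		wg.Add(1)
		go func(op string) {
			defer wg.Done()
			l := lockForLeaseID("lease-1")
			l.Lock()
			defer l.Unlock()
			// With the lock held, the load-entry / persist / update-pending
			// steps for this lease cannot interleave with the other operation.
			fmt.Println(op, "has exclusive access to lease-1")
		}(op)
	}
	wg.Wait()
}
```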
162 changes: 110 additions & 52 deletions vault/expiration.go
@@ -112,6 +112,8 @@ type ExpirationManager struct {
leaseCount int
pendingLock sync.RWMutex

// A sync.Mutex for every active leaseID
lockPerLease sync.Map
// Track expired leases that have been determined to be irrevocable (without
// manual intervention). We retain a subset of the lease info in memory
irrevocable sync.Map
@@ -392,6 +394,8 @@ func NewExpirationManager(c *Core, view *BarrierView, e ExpireLeaseStrategy, log
leaseCount: 0,
tidyLock: new(int32),

lockPerLease: sync.Map{},

uniquePolicies: make(map[string][]string),
emptyUniquePolicies: time.NewTicker(7 * 24 * time.Hour),

@@ -905,6 +909,14 @@ func (m *ExpirationManager) Revoke(ctx context.Context, leaseID string) error {
// it triggers a return of a 202.
func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) error {
defer metrics.MeasureSince([]string{"expire", "lazy-revoke"}, time.Now())
return m.lazyRevokeInternal(ctx, leaseID)
}

// Mark a lease as expiring immediately
func (m *ExpirationManager) lazyRevokeInternal(ctx context.Context, leaseID string) error {
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
@@ -918,16 +930,10 @@ func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) erro
}

le.ExpireTime = time.Now()
{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return err
}

m.updatePendingInternal(le)
m.pendingLock.Unlock()
if err := m.persistEntry(ctx, le); err != nil {
return err
}
m.updatePending(le)

return nil
}
@@ -937,6 +943,17 @@ func (m *ExpirationManager) LazyRevoke(ctx context.Context, leaseID string) erro
func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, force, skipToken bool) error {
defer metrics.MeasureSince([]string{"expire", "revoke-common"}, time.Now())

if !skipToken {
// Acquire lock for this lease
// If skipToken is true, then we're either being (1) called via RevokeByToken, so
// probably the lock is already held, and if we re-acquire we get deadlock, or
// (2) called by tidy, in which case the lock is not held.
// Is it worth separating those cases out, or is (2) OK to proceed unlocked?
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()
}
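As an aside on the deadlock concern in the comment above: Go's sync.Mutex is not reentrant, so re-locking a mutex the current goroutine already holds blocks forever. A tiny standalone illustration (hypothetical, not part of this PR) that uses TryLock from Go 1.18+ to show this without actually hanging:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var mu sync.Mutex
	mu.Lock()
	// A second mu.Lock() here would block this goroutine forever, because
	// sync.Mutex is not reentrant. TryLock makes that visible non-fatally:
	fmt.Println("re-acquire succeeded?", mu.TryLock()) // prints: false
	mu.Unlock()
}
```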

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
if err != nil {
@@ -966,6 +983,9 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo
return err
}

// Lease has been removed, also remove the in-memory lock.
m.deleteLockForLease(leaseID)

// Delete the secondary index, but only if it's a leased secret (not auth)
if le.Secret != nil {
var indexToken string
@@ -1029,7 +1049,8 @@ func (m *ExpirationManager) RevokePrefix(ctx context.Context, prefix string, syn
// RevokeByToken is used to revoke all the secrets issued with a given token.
// This is done by using the secondary index. It also removes the lease entry
// for the token itself. As a result it should *ONLY* ever be called from the
// token store's revokeSalted function.
// token store's revokeInternal function.
// (NB: it's called by token tidy as well.)
func (m *ExpirationManager) RevokeByToken(ctx context.Context, te *logical.TokenEntry) error {
defer metrics.MeasureSince([]string{"expire", "revoke-by-token"}, time.Now())
tokenNS, err := NamespaceByID(ctx, te.NamespaceID, m.core)
@@ -1047,31 +1068,12 @@ func (m *ExpirationManager) RevokeByToken(ctx context.Context, te *logical.Token
return fmt.Errorf("failed to scan for leases: %w", err)
}

// Revoke all the keys
// Revoke all the keys by marking them expired
for _, leaseID := range existing {
// Load the entry
le, err := m.loadEntry(ctx, leaseID)
err := m.lazyRevokeInternal(ctx, leaseID)
if err != nil {
return err
}

// If there's a lease, set expiration to now, persist, and call
// updatePending to hand off revocation to the expiration manager's pending
// timer map
if le != nil {
le.ExpireTime = time.Now()

{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return err
}

m.updatePendingInternal(le)
m.pendingLock.Unlock()
}
}
}

// te.Path should never be empty, but we check just in case
@@ -1137,6 +1139,7 @@ func (m *ExpirationManager) revokePrefixCommon(ctx context.Context, prefix strin
// Revoke all the keys
for idx, suffix := range existing {
leaseID := prefix + suffix
// No need to acquire per-lease lock here, one of these two will do it.
switch {
case sync:
if err := m.revokeCommon(ctx, leaseID, force, false); err != nil {
@@ -1157,6 +1160,11 @@ func (m *ExpirationManager) revokePrefixCommon(ctx context.Context, prefix strin
func (m *ExpirationManager) Renew(ctx context.Context, leaseID string, increment time.Duration) (*logical.Response, error) {
defer metrics.MeasureSince([]string{"expire", "renew"}, time.Now())

// Acquire lock for this lease
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
if err != nil {
@@ -1249,18 +1257,13 @@ func (m *ExpirationManager) Renew(ctx context.Context, leaseID string, increment
}
}

{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return nil, err
}

// Update the expiration time
m.updatePendingInternal(le)
m.pendingLock.Unlock()
if err := m.persistEntry(ctx, le); err != nil {
return nil, err
}

// Update the expiration time
m.updatePending(le)

// Return the response
return resp, nil
}
@@ -1299,6 +1302,11 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request
leaseID = fmt.Sprintf("%s.%s", leaseID, ns.ID)
}

// Acquire lock for this lease
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Load the entry
le, err := m.loadEntry(ctx, leaseID)
if err != nil {
@@ -1368,17 +1376,10 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request
le.ExpireTime = resp.Auth.ExpirationTime()
le.LastRenewalTime = time.Now()

{
m.pendingLock.Lock()
if err := m.persistEntry(ctx, le); err != nil {
m.pendingLock.Unlock()
return nil, err
}

// Update the expiration time
m.updatePendingInternal(le)
m.pendingLock.Unlock()
if err := m.persistEntry(ctx, le); err != nil {
return nil, err
}
m.updatePending(le)

retResp.Auth = resp.Auth
return retResp, nil
@@ -1467,6 +1468,8 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
if err := m.removeIndexByToken(ctx, le, indexToken); err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("an additional error was encountered removing lease indexes associated with the newly-generated secret: %w", err))
}

m.deleteLockForLease(leaseID)
}
}()

@@ -1485,6 +1488,14 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request,
}
}

// Acquire the lock here so persistEntry and updatePending are atomic,
// although it is *very unlikely* that anybody could grab the lease ID
// before this function returns. (They could find it in an index, or
// find it in a list.)
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Encode the entry
if err := m.persistEntry(ctx, le); err != nil {
return "", err
@@ -1570,6 +1581,10 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE
Version: 1,
}

leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

// Encode the entry
if err := m.persistEntry(ctx, &le); err != nil {
return err
@@ -1716,6 +1731,30 @@ func (m *ExpirationManager) uniquePoliciesGc() {
}
}

// Placing a lock in pendingMap means that we need to work very hard on reload
// to only create one lock. Instead, we'll create locks on-demand in an atomic fashion.
//
// Acquiring a lock from a leaseEntry is a bad idea because it could change
// between loading and acquiring the lock. So we only provide an ID-based map, and the
// locking discipline should be:
// 1. Lock lease
// 2. Load, or attempt to load, leaseEntry
// 3. Modify leaseEntry and pendingMap (atomic wrt operations on this lease)
// 4. Unlock lease
//
// The lock must be removed from the map when the lease is deleted, or is
// found to not exist in storage. loadEntry does this whenever it returns
// nil, but we should also do it in revokeCommon().
func (m *ExpirationManager) lockForLeaseID(id string) *sync.Mutex {
mutex := &sync.Mutex{}
lock, _ := m.lockPerLease.LoadOrStore(id, mutex)
return lock.(*sync.Mutex)
}

func (m *ExpirationManager) deleteLockForLease(id string) {
m.lockPerLease.Delete(id)
}
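A compact, self-contained sketch of that discipline from a caller's point of view, including dropping the lock when the entry turns out not to exist; the types, toy storage, and revoke function here are illustrative stand-ins rather than the Vault implementation:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

// expirationManager is a stand-in for the real ExpirationManager: a toy
// lease "storage" plus the per-lease lock map described above.
type expirationManager struct {
	lockPerLease sync.Map             // leaseID -> *sync.Mutex
	mu           sync.Mutex           // guards storage
	storage      map[string]time.Time // leaseID -> expire time
}

func (m *expirationManager) lockForLeaseID(id string) *sync.Mutex {
	lock, _ := m.lockPerLease.LoadOrStore(id, &sync.Mutex{})
	return lock.(*sync.Mutex)
}

func (m *expirationManager) deleteLockForLease(id string) {
	m.lockPerLease.Delete(id)
}

// revoke follows the discipline: (1) lock the lease, (2) load the entry,
// (3) modify storage and in-memory state, (4) unlock via defer. If the entry
// is missing, the lock created for it is dropped again, as loadEntry does.
func (m *expirationManager) revoke(id string) error {
	leaseLock := m.lockForLeaseID(id)
	leaseLock.Lock()
	defer leaseLock.Unlock()

	m.mu.Lock()
	_, ok := m.storage[id]
	m.mu.Unlock()
	if !ok {
		m.deleteLockForLease(id)
		return errors.New("lease not found: " + id)
	}

	m.mu.Lock()
	delete(m.storage, id)
	m.mu.Unlock()

	// The lease is gone, so its lock can be removed too (as revokeCommon does).
	m.deleteLockForLease(id)
	return nil
}

func main() {
	m := &expirationManager{storage: map[string]time.Time{
		"lease-1": time.Now().Add(time.Hour),
	}}
	fmt.Println("first revoke:", m.revoke("lease-1"))  // <nil>
	fmt.Println("second revoke:", m.revoke("lease-1")) // lease not found: lease-1
}
```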

// updatePending is used to update a pending invocation for a lease
func (m *ExpirationManager) updatePending(le *leaseEntry) {
m.pendingLock.Lock()
@@ -1908,7 +1947,16 @@ func (m *ExpirationManager) loadEntry(ctx context.Context, leaseID string) (*lea
} else {
ctx = namespace.ContextWithNamespace(ctx, namespace.RootNamespace)
}
return m.loadEntryInternal(ctx, leaseID, restoreMode, true)

// If a lease entry is nil, proactively delete the lease lock, in case we
// created one erroneously.
// If there was an error, we don't know whether the lease entry exists or not.
leaseEntry, err := m.loadEntryInternal(ctx, leaseID, restoreMode, true)
if err == nil && leaseEntry == nil {
m.deleteLockForLease(leaseID)
}
return leaseEntry, err

}

// loadEntryInternal is used when you need to load an entry but also need to
Expand Down Expand Up @@ -2133,6 +2181,15 @@ func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(ctx context.Cont

// If there's no associated leaseEntry for the token, we create one
if le == nil {

// Acquire the lock here so persistEntry and updatePending are atomic,
// although it is *very unlikely* that anybody could grab the lease ID
// before this function returns. (They could find it in an index, or
// find it in a list.)
leaseLock := m.lockForLeaseID(leaseID)
leaseLock.Lock()
defer leaseLock.Unlock()

auth := &logical.Auth{
ClientToken: te.ID,
LeaseOptions: logical.LeaseOptions{
@@ -2159,6 +2216,7 @@ func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(ctx context.Cont

// Encode the entry
if err := m.persistEntry(ctx, le); err != nil {
m.deleteLockForLease(leaseID)
return "", err
}
}
49 changes: 48 additions & 1 deletion vault/expiration_test.go
@@ -612,6 +612,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
Path: "prod/aws/" + pathUUID,
ClientToken: "root",
}
req.SetTokenEntry(&logical.TokenEntry{ID: "root", NamespaceID: "root"})
resp := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
@@ -623,7 +624,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
"secret_key": "abcd",
},
}
_, err = exp.Register(context.Background(), req, resp)
_, err = exp.Register(namespace.RootContext(nil), req, resp)
if err != nil {
b.Fatalf("err: %v", err)
}
@@ -646,6 +647,52 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
b.StopTimer()
}

func BenchmarkExpiration_Create_Leases(b *testing.B) {
logger := logging.NewVaultLogger(log.Trace)
inm, err := inmem.NewInmem(nil, logger)
if err != nil {
b.Fatal(err)
}

c, _, _ := TestCoreUnsealedBackend(b, inm)
exp := c.expiration
noop := &NoopBackend{}
view := NewBarrierView(c.barrier, "logical/")
meUUID, err := uuid.GenerateUUID()
if err != nil {
b.Fatal(err)
}
err = exp.router.Mount(noop, "prod/aws/", &MountEntry{Path: "prod/aws/", Type: "noop", UUID: meUUID, Accessor: "noop-accessor", namespace: namespace.RootNamespace}, view)
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
req := &logical.Request{
Operation: logical.ReadOperation,
Path: fmt.Sprintf("prod/aws/%d", i),
ClientToken: "root",
}
req.SetTokenEntry(&logical.TokenEntry{ID: "root", NamespaceID: "root"})
resp := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
TTL: 400 * time.Second,
},
},
Data: map[string]interface{}{
"access_key": "xyz",
"secret_key": "abcd",
},
}
_, err = exp.Register(namespace.RootContext(nil), req, resp)
if err != nil {
b.Fatalf("err: %v", err)
}
}
}

func TestExpiration_Restore(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
exp := c.expiration