Merge pull request #2911 from nats-io/fix_lock_inversions
[FIXED] Some lock inversions
kozlovic committed Mar 9, 2022
2 parents 3538aea + 0fae806 commit 5a97ee6
Showing 5 changed files with 63 additions and 43 deletions.
5 changes: 5 additions & 0 deletions locksordering.txt
@@ -0,0 +1,5 @@
Here is a list of some established lock orderings.

In this list, A -> B means that you can have A.Lock() then B.Lock(), but not the opposite.

jetStream -> jsAccount -> Server -> client -> Account
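
To make the rule concrete, here is a minimal, hypothetical sketch (A and B stand in for any adjacent pair in the list above; this is not server code): two goroutines that take the same pair of locks in opposite orders can deadlock, which is exactly what the ordering is meant to prevent.

package main

import "sync"

// A and B stand in for any adjacent pair in the ordering, e.g. jetStream -> jsAccount.
type A struct{ mu sync.Mutex }
type B struct{ mu sync.Mutex }

// respectsOrdering takes A.Lock() before B.Lock(), matching A -> B.
func respectsOrdering(a *A, b *B) {
    a.mu.Lock()
    b.mu.Lock()
    // ... work that needs both locks ...
    b.mu.Unlock()
    a.mu.Unlock()
}

// invertsOrdering takes B before A. Run concurrently with respectsOrdering,
// each goroutine can end up holding one lock while waiting forever for the
// other: a classic lock-inversion deadlock.
func invertsOrdering(a *A, b *B) {
    b.mu.Lock()
    a.mu.Lock()
    a.mu.Unlock()
    b.mu.Unlock()
}

func main() {
    a, b := &A{}, &B{}
    respectsOrdering(a, b)
    // invertsOrdering(a, b) // unsafe if any other goroutine follows A -> B
}
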
94 changes: 56 additions & 38 deletions server/accounts.go
@@ -292,6 +292,27 @@ func (a *Account) nextEventID() string {
    return id
}

// Returns a slice of clients stored in the account, or nil if none are present.
// Lock is held on entry.
func (a *Account) getClientsLocked() []*client {
    if len(a.clients) == 0 {
        return nil
    }
    clients := make([]*client, 0, len(a.clients))
    for c := range a.clients {
        clients = append(clients, c)
    }
    return clients
}

// Returns a slice of clients stored in the account, or nil if none are present.
func (a *Account) getClients() []*client {
    a.mu.RLock()
    clients := a.getClientsLocked()
    a.mu.RUnlock()
    return clients
}
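
A note on the naming convention these two helpers follow (the call sites below are hypothetical, added only to illustrate the split, and assume they live in the same package as Account): the Locked suffix marks the variant that expects a.mu to already be held, while the plain variant takes the read lock itself, so each caller picks exactly one of them and never double-locks.

// Hypothetical call sites, not part of this diff.
func (a *Account) snapshotWhileUnlocked() []*client {
    // Caller does not hold a.mu; getClients does its own locking.
    return a.getClients()
}

func (a *Account) snapshotWhileLocked() []*client {
    a.mu.Lock()
    defer a.mu.Unlock()
    // Caller already holds a.mu, so the Locked variant is required here;
    // calling getClients instead would block on a.mu (sync.RWMutex is not
    // reentrant).
    return a.getClientsLocked()
}
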

// Called to track a remote server and the connections and leafnodes it
// has for this account.
func (a *Account) updateRemoteServer(m *AccountNumConns) []*client {
@@ -312,10 +333,7 @@ func (a *Account) updateRemoteServer(m *AccountNumConns) []*client {
    // conservative and a bit harsh here. Clients will reconnect if we overcompensate.
    var clients []*client
    if mtce {
        clients = make([]*client, 0, len(a.clients))
        for c := range a.clients {
            clients = append(clients, c)
        }
        clients = a.getClientsLocked()
        sort.Slice(clients, func(i, j int) bool {
            return clients[i].start.After(clients[j].start)
        })
@@ -670,9 +688,13 @@ func (a *Account) AddWeightedMappings(src string, dests ...*MapDest) error {

    // If we have connected leafnodes make sure to update.
    if len(a.lleafs) > 0 {
        for _, lc := range a.lleafs {
        leafs := append([]*client(nil), a.lleafs...)
        // Need to release because lock ordering is client -> account
        a.mu.Unlock()
        for _, lc := range leafs {
            lc.forceAddToSmap(src)
        }
        a.mu.Lock()
    }
    return nil
}
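
The comment in this hunk ("lock ordering is client -> account") captures the recurring fix in this commit: copy the slice while holding the account lock, drop the lock before calling into each client or leafnode, and re-acquire it afterwards if account state is still needed. A generic, hypothetical sketch of that pattern (holder and item are stand-ins, not the server's types):

package main

import "sync"

type item struct{ mu sync.Mutex }

// touch takes the item's own lock, so it must not be called while the
// holder's lock is held once the established ordering is item -> holder.
func (it *item) touch() {
    it.mu.Lock()
    defer it.mu.Unlock()
    // ... per-item work ...
}

type holder struct {
    mu    sync.Mutex
    items []*item
}

func (h *holder) updateAll() {
    h.mu.Lock()
    // Snapshot under the lock so iteration does not race with mutation.
    snapshot := append([]*item(nil), h.items...)
    h.mu.Unlock()

    // Safe: no holder lock is held while each item lock is taken.
    for _, it := range snapshot {
        it.touch()
    }
    // If more holder state must be updated, re-acquire h.mu here, mirroring
    // the a.mu.Lock() after the leafnode loop above.
}

func main() {
    h := &holder{items: []*item{{}, {}}}
    h.updateAll()
}
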
@@ -963,8 +985,6 @@ func (a *Account) addServiceExportWithResponseAndAccountPos(
    }

    a.mu.Lock()
    defer a.mu.Unlock()

    if a.exports.services == nil {
        a.exports.services = make(map[string]*serviceExport)
    }
@@ -981,15 +1001,24 @@ func (a *Account) addServiceExportWithResponseAndAccountPos(

    if accounts != nil || accountPos > 0 {
        if err := setExportAuth(&se.exportAuth, subject, accounts, accountPos); err != nil {
            a.mu.Unlock()
            return err
        }
    }
    lrt := a.lowestServiceExportResponseTime()
    se.acc = a
    se.respThresh = DEFAULT_SERVICE_EXPORT_RESPONSE_THRESHOLD
    a.exports.services[subject] = se
    if nlrt := a.lowestServiceExportResponseTime(); nlrt != lrt {
        a.updateAllClientsServiceExportResponseTime(nlrt)

    var clients []*client
    nlrt := a.lowestServiceExportResponseTime()
    if nlrt != lrt && len(a.clients) > 0 {
        clients = a.getClientsLocked()
    }
    // Need to release because lock ordering is client -> Account
    a.mu.Unlock()
    if len(clients) > 0 {
        updateAllClientsServiceExportResponseTime(clients, nlrt)
    }
    return nil
}
@@ -1353,9 +1382,8 @@ func (a *Account) sendTrackingLatency(si *serviceImport, responder *client) bool

// This will check to make sure our response lower threshold is set
// properly in any clients doing rrTracking.
// Lock should be held.
func (a *Account) updateAllClientsServiceExportResponseTime(lrt time.Duration) {
    for c := range a.clients {
func updateAllClientsServiceExportResponseTime(clients []*client, lrt time.Duration) {
    for _, c := range clients {
        c.mu.Lock()
        if c.rrTracking != nil && lrt != c.rrTracking.lrt {
            c.rrTracking.lrt = lrt
@@ -2234,18 +2262,27 @@ func (a *Account) ServiceExportResponseThreshold(export string) (time.Duration,
// from a service export responder.
func (a *Account) SetServiceExportResponseThreshold(export string, maxTime time.Duration) error {
    a.mu.Lock()
    defer a.mu.Unlock()
    if a.isClaimAccount() {
        a.mu.Unlock()
        return fmt.Errorf("claim based accounts can not be updated directly")
    }
    lrt := a.lowestServiceExportResponseTime()
    se := a.getServiceExport(export)
    if se == nil {
        a.mu.Unlock()
        return fmt.Errorf("no export defined for %q", export)
    }
    se.respThresh = maxTime
    if nlrt := a.lowestServiceExportResponseTime(); nlrt != lrt {
        a.updateAllClientsServiceExportResponseTime(nlrt)

    var clients []*client
    nlrt := a.lowestServiceExportResponseTime()
    if nlrt != lrt && len(a.clients) > 0 {
        clients = a.getClientsLocked()
    }
    // Need to release because lock ordering is client -> Account
    a.mu.Unlock()
    if len(clients) > 0 {
        updateAllClientsServiceExportResponseTime(clients, nlrt)
    }
    return nil
}
@@ -2569,10 +2606,7 @@ func (a *Account) streamActivationExpired(exportAcc *Account, subject string) {

    a.mu.Lock()
    si.invalid = true
    clients := make([]*client, 0, len(a.clients))
    for c := range a.clients {
        clients = append(clients, c)
    }
    clients := a.getClientsLocked()
    awcsti := map[string]struct{}{a.Name: {}}
    a.mu.Unlock()
    for _, c := range clients {
@@ -2779,13 +2813,7 @@ func (a *Account) expiredTimeout() {
    a.mu.Unlock()

    // Collect the clients and expire them.
    cs := make([]*client, 0, len(a.clients))
    a.mu.RLock()
    for c := range a.clients {
        cs = append(cs, c)
    }
    a.mu.RUnlock()

    cs := a.getClients()
    for _, c := range cs {
        c.accountAuthExpired()
    }
@@ -3001,16 +3029,6 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
        s.registerSystemImports(a)
    }

    gatherClients := func() []*client {
        a.mu.RLock()
        clients := make([]*client, 0, len(a.clients))
        for c := range a.clients {
            clients = append(clients, c)
        }
        a.mu.RUnlock()
        return clients
    }

    jsEnabled := s.JetStreamEnabled()
    if jsEnabled && a == s.SystemAccount() {
        s.checkJetStreamExports()
@@ -3144,7 +3162,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
    // Now let's apply any needed changes from import/export changes.
    if !a.checkStreamImportsEqual(old) {
        awcsti := map[string]struct{}{a.Name: {}}
        for _, c := range gatherClients() {
        for _, c := range a.getClients() {
            c.processSubsOnConfigReload(awcsti)
        }
    }
@@ -3266,9 +3284,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
    }

    a.updated = time.Now().UTC()
    clients := a.getClientsLocked()
    a.mu.Unlock()

    clients := gatherClients()
    // Sort if we are over the limit.
    if a.MaxTotalConnectionsReached() {
        sort.Slice(clients, func(i, j int) bool {
4 changes: 2 additions & 2 deletions server/jetstream_test.go
@@ -4524,7 +4524,7 @@ func TestJetStreamSnapshotsAPI(t *testing.T) {
    if err != nil {
        t.Fatalf("Expected to find a stream for %q", mname)
    }
    state = mset.state()
    mset.state()
    mset.delete()

    rreq.Config.Name = "NEW_STREAM"
@@ -15195,7 +15195,7 @@ func TestJetStreamPullConsumerHeartBeats(t *testing.T) {
        }
    }()

    start, msgs = time.Now(), doReq(10, 75*time.Millisecond, 350*time.Millisecond, 6)
    msgs = doReq(10, 75*time.Millisecond, 350*time.Millisecond, 6)
    // The first 5 should be msgs, no HBs.
    for i := 0; i < 5; i++ {
        if m := msgs[i].msg; len(m.Header) > 0 {
1 change: 0 additions & 1 deletion server/monitor_test.go
@@ -677,7 +677,6 @@ func TestConnzLastActivity(t *testing.T) {
    if barLA.Equal(nextLA) {
        t.Fatalf("Publish should have triggered update to LastActivity\n")
    }
    barLA = nextLA

    // Message delivery on ncFoo should have triggered as well.
    nextLA = ciFoo.LastActivity
2 changes: 0 additions & 2 deletions test/service_latency_test.go
@@ -361,7 +361,6 @@ func TestServiceLatencyClientRTTSlowerVsServiceRTT(t *testing.T) {
    }

    // Send the request.
    start = time.Now()
    _, err := nc2.Request("ngs.usage", []byte("1h"), time.Second)
    if err != nil {
        t.Fatalf("Expected a response")
@@ -1500,7 +1499,6 @@ func TestServiceLatencyRequestorSharesConfig(t *testing.T) {
        t.Fatalf("Error on server reload: %v", err)
    }

    start = time.Now()
    if _, err = nc2.Request("SVC", []byte("1h"), time.Second); err != nil {
        t.Fatalf("Expected a response")
    }