diff --git a/p2p/discover/v5wire/encoding_test.go b/p2p/discover/v5wire/encoding_test.go index d9c807e0a8dbb..b13041d1b9c68 100644 --- a/p2p/discover/v5wire/encoding_test.go +++ b/p2p/discover/v5wire/encoding_test.go @@ -512,9 +512,6 @@ func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock. db, _ := enode.OpenDB("") n.ln = enode.NewLocalNode(db, key) n.ln.SetStaticIP(ip) - if n.ln.Node().Seq() != 1 { - panic(fmt.Errorf("unexpected seq %d", n.ln.Node().Seq())) - } n.c = NewCodec(n.ln, key, clock) } diff --git a/p2p/enode/localnode.go b/p2p/enode/localnode.go index d8aa02a77e266..4827b6c0aedcd 100644 --- a/p2p/enode/localnode.go +++ b/p2p/enode/localnode.go @@ -36,20 +36,25 @@ const ( iptrackMinStatements = 10 iptrackWindow = 5 * time.Minute iptrackContactWindow = 10 * time.Minute + + // time needed to wait between two updates to the local ENR + recordUpdateThrottle = time.Millisecond ) // LocalNode produces the signed node record of a local node, i.e. a node run in the // current process. Setting ENR entries via the Set method updates the record. A new version // of the record is signed on demand when the Node method is called. type LocalNode struct { - cur atomic.Value // holds a non-nil node pointer while the record is up-to-date. + cur atomic.Value // holds a non-nil node pointer while the record is up-to-date + id ID key *ecdsa.PrivateKey db *DB // everything below is protected by a lock - mu sync.Mutex + mu sync.RWMutex seq uint64 + update time.Time // timestamp when the record was last updated entries map[string]enr.Entry endpoint4 lnEndpoint endpoint6 lnEndpoint @@ -76,7 +81,8 @@ func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode { }, } ln.seq = db.localSeq(ln.id) - ln.invalidate() + ln.update = time.Now() + ln.cur.Store((*Node)(nil)) return ln } @@ -87,14 +93,34 @@ func (ln *LocalNode) Database() *DB { // Node returns the current version of the local node record. func (ln *LocalNode) Node() *Node { + // If we have a valid record, return that n := ln.cur.Load().(*Node) if n != nil { return n } + // Record was invalidated, sign a new copy. ln.mu.Lock() defer ln.mu.Unlock() + + // Double check the current record, since multiple goroutines might be waiting + // on the write mutex. + if n = ln.cur.Load().(*Node); n != nil { + return n + } + + // The initial sequence number is the current timestamp in milliseconds. To ensure + // that the initial sequence number will always be higher than any previous sequence + // number (assuming the clock is correct), we want to avoid updating the record faster + // than once per ms. So we need to sleep here until the next possible update time has + // arrived. + lastChange := time.Since(ln.update) + if lastChange < recordUpdateThrottle { + time.Sleep(recordUpdateThrottle - lastChange) + } + ln.sign() + ln.update = time.Now() return ln.cur.Load().(*Node) } @@ -114,6 +140,10 @@ func (ln *LocalNode) ID() ID { // Set puts the given entry into the local record, overwriting any existing value. // Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll // be overwritten by the endpoint predictor. +// +// Since node record updates are throttled to one per second, Set is asynchronous. +// Any update will be queued up and published when at least one second passes from +// the last change. func (ln *LocalNode) Set(e enr.Entry) { ln.mu.Lock() defer ln.mu.Unlock() @@ -288,3 +318,12 @@ func (ln *LocalNode) bumpSeq() { ln.seq++ ln.db.storeLocalSeq(ln.id, ln.seq) } + +// nowMilliseconds gives the current timestamp at millisecond precision. +func nowMilliseconds() uint64 { + ns := time.Now().UnixNano() + if ns < 0 { + return 0 + } + return uint64(ns / 1000 / 1000) +} diff --git a/p2p/enode/localnode_test.go b/p2p/enode/localnode_test.go index 00746a8d277af..312df813bba4f 100644 --- a/p2p/enode/localnode_test.go +++ b/p2p/enode/localnode_test.go @@ -49,32 +49,39 @@ func TestLocalNode(t *testing.T) { } } +// This test checks that the sequence number is persisted between restarts. func TestLocalNodeSeqPersist(t *testing.T) { + timestamp := nowMilliseconds() + ln, db := newLocalNodeForTesting() defer db.Close() - if s := ln.Node().Seq(); s != 1 { - t.Fatalf("wrong initial seq %d, want 1", s) + initialSeq := ln.Node().Seq() + if initialSeq < timestamp { + t.Fatalf("wrong initial seq %d, want at least %d", initialSeq, timestamp) } + ln.Set(enr.WithEntry("x", uint(1))) - if s := ln.Node().Seq(); s != 2 { - t.Fatalf("wrong seq %d after set, want 2", s) + if s := ln.Node().Seq(); s != initialSeq+1 { + t.Fatalf("wrong seq %d after set, want %d", s, initialSeq+1) } // Create a new instance, it should reload the sequence number. // The number increases just after that because a new record is // created without the "x" entry. ln2 := NewLocalNode(db, ln.key) - if s := ln2.Node().Seq(); s != 3 { - t.Fatalf("wrong seq %d on new instance, want 3", s) + if s := ln2.Node().Seq(); s != initialSeq+2 { + t.Fatalf("wrong seq %d on new instance, want %d", s, initialSeq+2) } + finalSeq := ln2.Node().Seq() + // Create a new instance with a different node key on the same database. // This should reset the sequence number. key, _ := crypto.GenerateKey() ln3 := NewLocalNode(db, key) - if s := ln3.Node().Seq(); s != 1 { - t.Fatalf("wrong seq %d on instance with changed key, want 1", s) + if s := ln3.Node().Seq(); s < finalSeq { + t.Fatalf("wrong seq %d on instance with changed key, want >= %d", s, finalSeq) } } @@ -91,20 +98,20 @@ func TestLocalNodeEndpoint(t *testing.T) { // Nothing is set initially. assert.Equal(t, net.IP(nil), ln.Node().IP()) assert.Equal(t, 0, ln.Node().UDP()) - assert.Equal(t, uint64(1), ln.Node().Seq()) + initialSeq := ln.Node().Seq() // Set up fallback address. ln.SetFallbackIP(fallback.IP) ln.SetFallbackUDP(fallback.Port) assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.Port, ln.Node().UDP()) - assert.Equal(t, uint64(2), ln.Node().Seq()) + assert.Equal(t, initialSeq+1, ln.Node().Seq()) // Add endpoint statements from random hosts. for i := 0; i < iptrackMinStatements; i++ { assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.Port, ln.Node().UDP()) - assert.Equal(t, uint64(2), ln.Node().Seq()) + assert.Equal(t, initialSeq+1, ln.Node().Seq()) from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90} rand.Read(from.IP) @@ -112,11 +119,11 @@ func TestLocalNodeEndpoint(t *testing.T) { } assert.Equal(t, predicted.IP, ln.Node().IP()) assert.Equal(t, predicted.Port, ln.Node().UDP()) - assert.Equal(t, uint64(3), ln.Node().Seq()) + assert.Equal(t, initialSeq+2, ln.Node().Seq()) // Static IP overrides prediction. ln.SetStaticIP(staticIP) assert.Equal(t, staticIP, ln.Node().IP()) assert.Equal(t, fallback.Port, ln.Node().UDP()) - assert.Equal(t, uint64(4), ln.Node().Seq()) + assert.Equal(t, initialSeq+3, ln.Node().Seq()) } diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index d62f383f0bc9e..d1712f75974a7 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -427,9 +427,14 @@ func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error { return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) } -// LocalSeq retrieves the local record sequence counter. +// localSeq retrieves the local record sequence counter, defaulting to the current +// timestamp if no previous exists. This ensures that wiping all data associated +// with a node (apart from its key) will not generate already used sequence nums. func (db *DB) localSeq(id ID) uint64 { - return db.fetchUint64(localItemKey(id, dbLocalSeq)) + if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 { + return seq + } + return nowMilliseconds() } // storeLocalSeq stores the local record sequence counter.