Skip to content

Commit

Permalink
p2p/enode: update to blocking update scheme
Browse files Browse the repository at this point in the history
This changes the local ENR update mechanism to block for up to 1ms in
LocalNode.Node when the previous update was less than 1ms ago.

To make this work, the granularity of sequence numbers is increased to
millisecond.
  • Loading branch information
fjl committed May 27, 2021
1 parent 136be59 commit 34c4e57
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 101 deletions.
53 changes: 27 additions & 26 deletions p2p/enode/localnode.go
Expand Up @@ -36,21 +36,16 @@ const (
iptrackMinStatements = 10
iptrackWindow = 5 * time.Minute
iptrackContactWindow = 10 * time.Minute
)

var (
// recordUpdateThrottle is the time needed to wait between two updates to an ENR
// record. Modifications in between are queued up and published together.
recordUpdateThrottle = time.Second
// time needed to wait between two updates to the local ENR
recordUpdateThrottle = time.Millisecond
)

// LocalNode produces the signed node record of a local node, i.e. a node run in the
// current process. Setting ENR entries via the Set method updates the record. A new version
// of the record is signed on demand when the Node method is called.
type LocalNode struct {
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date
prev atomic.Value // holds a non-nil node pointer while the record is thottled on an update
update time.Time // timestamp when the record was last updated to prevent sequence number bloat
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date

id ID
key *ecdsa.PrivateKey
Expand All @@ -59,6 +54,7 @@ type LocalNode struct {
// everything below is protected by a lock
mu sync.RWMutex
seq uint64
update time.Time // timestamp when the record was last updated
entries map[string]enr.Entry
endpoint4 lnEndpoint
endpoint6 lnEndpoint
Expand All @@ -85,9 +81,8 @@ func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
},
}
ln.seq = db.localSeq(ln.id)
ln.prev.Store((*Node)(nil))
ln.update = time.Now()
ln.cur.Store((*Node)(nil))

return ln
}

Expand All @@ -103,18 +98,8 @@ func (ln *LocalNode) Node() *Node {
if n != nil {
return n
}
// Record was invalidated, check for a previous version and use that unless we
// are allowed to update.
if n = ln.prev.Load().(*Node); n != nil {
ln.mu.RLock()
throttle := time.Since(ln.update) < recordUpdateThrottle
ln.mu.RUnlock()

if throttle {
return n
}
}
// Record was invalidated a long time ago, sign a new copy

// Record was invalidated, sign a new copy.
ln.mu.Lock()
defer ln.mu.Unlock()

Expand All @@ -123,9 +108,19 @@ func (ln *LocalNode) Node() *Node {
if n = ln.cur.Load().(*Node); n != nil {
return n
}

// The initial sequence number is the current timestamp in milliseconds. To ensure
// that the initial sequence number will always be higher than any previous sequence
// number (assuming the clock is correct), we want to avoid updating the record faster
// than once per ms. So we need to sleep here until the next possible update time has
// arrived.
lastChange := time.Since(ln.update)
if lastChange < recordUpdateThrottle {
time.Sleep(recordUpdateThrottle - lastChange)
}

ln.sign()
ln.update = time.Now()

return ln.cur.Load().(*Node)
}

Expand Down Expand Up @@ -294,9 +289,6 @@ func predictAddr(t *netutil.IPTracker) (net.IP, int) {
}

func (ln *LocalNode) invalidate() {
if n := ln.cur.Load().(*Node); n != nil {
ln.prev.Store(n)
}
ln.cur.Store((*Node)(nil))
}

Expand Down Expand Up @@ -326,3 +318,12 @@ func (ln *LocalNode) bumpSeq() {
ln.seq++
ln.db.storeLocalSeq(ln.id, ln.seq)
}

// nowMilliseconds gives the current timestamp at millisecond precision.
func nowMilliseconds() uint64 {
ns := time.Now().UnixNano()
if ns < 0 {
return 0
}
return uint64(ns / 1000 / 1000)
}
93 changes: 19 additions & 74 deletions p2p/enode/localnode_test.go
Expand Up @@ -20,7 +20,6 @@ import (
"math/rand"
"net"
"testing"
"time"

"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enr"
Expand All @@ -34,10 +33,6 @@ func newLocalNodeForTesting() (*LocalNode, *DB) {
}

func TestLocalNode(t *testing.T) {
// Disable throttling for this test
defer func(throttle time.Duration) { recordUpdateThrottle = throttle }(recordUpdateThrottle)
recordUpdateThrottle = 0

ln, db := newLocalNodeForTesting()
defer db.Close()

Expand All @@ -54,49 +49,45 @@ func TestLocalNode(t *testing.T) {
}
}

// This test checks that the sequence number is persisted between restarts.
func TestLocalNodeSeqPersist(t *testing.T) {
// Disable throttling for this test
defer func(throttle time.Duration) { recordUpdateThrottle = throttle }(recordUpdateThrottle)
recordUpdateThrottle = 0

timestamp := uint64(time.Now().Unix())
timestamp := nowMilliseconds()

ln, db := newLocalNodeForTesting()
defer db.Close()

if s := ln.Node().Seq(); s != timestamp+1 {
t.Fatalf("wrong initial seq %d, want %d", s, timestamp+1)
initialSeq := ln.Node().Seq()
if initialSeq < timestamp {
t.Fatalf("wrong initial seq %d, want at least %d", initialSeq, timestamp)
}

ln.Set(enr.WithEntry("x", uint(1)))
if s := ln.Node().Seq(); s != timestamp+2 {
t.Fatalf("wrong seq %d after set, want 2", s)
if s := ln.Node().Seq(); s != initialSeq+1 {
t.Fatalf("wrong seq %d after set, want %d", s, initialSeq+1)
}

// Create a new instance, it should reload the sequence number.
// The number increases just after that because a new record is
// created without the "x" entry.
ln2 := NewLocalNode(db, ln.key)
if s := ln2.Node().Seq(); s != timestamp+3 {
t.Fatalf("wrong seq %d on new instance, want 3", s)
if s := ln2.Node().Seq(); s != initialSeq+2 {
t.Fatalf("wrong seq %d on new instance, want %d", s, initialSeq+2)
}

finalSeq := ln2.Node().Seq()

// Create a new instance with a different node key on the same database.
// This should reset the sequence number.
key, _ := crypto.GenerateKey()
ln3 := NewLocalNode(db, key)
if s := ln3.Node().Seq(); s != timestamp+1 {
t.Fatalf("wrong seq %d on instance with changed key, want 1", s)
if s := ln3.Node().Seq(); s < finalSeq {
t.Fatalf("wrong seq %d on instance with changed key, want >= %d", s, finalSeq)
}
}

// This test checks behavior of the endpoint predictor.
func TestLocalNodeEndpoint(t *testing.T) {
// Disable throttling for this test
defer func(throttle time.Duration) { recordUpdateThrottle = throttle }(recordUpdateThrottle)
recordUpdateThrottle = 0

var (
timestamp = uint64(time.Now().Unix())
fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80}
predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81}
staticIP = net.IP{127, 0, 1, 2}
Expand All @@ -107,78 +98,32 @@ func TestLocalNodeEndpoint(t *testing.T) {
// Nothing is set initially.
assert.Equal(t, net.IP(nil), ln.Node().IP())
assert.Equal(t, 0, ln.Node().UDP())
assert.Equal(t, uint64(timestamp+1), ln.Node().Seq())
initialSeq := ln.Node().Seq()

// Set up fallback address.
ln.SetFallbackIP(fallback.IP)
ln.SetFallbackUDP(fallback.Port)
assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(timestamp+2), ln.Node().Seq())
assert.Equal(t, initialSeq+1, ln.Node().Seq())

// Add endpoint statements from random hosts.
for i := 0; i < iptrackMinStatements; i++ {
assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(timestamp+2), ln.Node().Seq())
assert.Equal(t, initialSeq+1, ln.Node().Seq())

from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90}
rand.Read(from.IP)
ln.UDPEndpointStatement(from, predicted)
}
assert.Equal(t, predicted.IP, ln.Node().IP())
assert.Equal(t, predicted.Port, ln.Node().UDP())
assert.Equal(t, uint64(timestamp+3), ln.Node().Seq())
assert.Equal(t, initialSeq+2, ln.Node().Seq())

// Static IP overrides prediction.
ln.SetStaticIP(staticIP)
assert.Equal(t, staticIP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(timestamp+4), ln.Node().Seq())
}

// Tests that multiple updates to a node record are throttled until the specified
// timeout expires.
func TestLocalNodeThrottling(t *testing.T) {
var n uint

// Create and retrieve an initial node record to force an update
ln, db := newLocalNodeForTesting()
defer db.Close()

ln.Set(enr.WithEntry("x", uint(3)))
ln.Set(enr.WithEntry("y", uint(2)))
ln.Set(enr.WithEntry("z", uint(1)))

timestamp := uint64(time.Now().Unix())
if s := ln.Node().Seq(); s != timestamp+1 {
t.Fatalf("wrong initial seq %d, want %d", s, timestamp+1)
}
ln.Node().Load(enr.WithEntry("x", &n))
assert.Equal(t, uint(3), n)
ln.Node().Load(enr.WithEntry("y", &n))
assert.Equal(t, uint(2), n)
ln.Node().Load(enr.WithEntry("z", &n))
assert.Equal(t, uint(1), n)

// Trigger a set of updates and ensure they don't publish yet
ln.Set(enr.WithEntry("x", uint(1)))
ln.Delete(enr.WithEntry("y", uint(2)))
ln.Set(enr.WithEntry("z", uint(3)))

ln.Node().Load(enr.WithEntry("x", &n))
assert.Equal(t, uint(3), n)
ln.Node().Load(enr.WithEntry("y", &n))
assert.Equal(t, uint(2), n)
ln.Node().Load(enr.WithEntry("z", &n))
assert.Equal(t, uint(1), n)

// Wait for the timeout to trigger and check again
time.Sleep(recordUpdateThrottle)

ln.Node().Load(enr.WithEntry("x", &n))
assert.Equal(t, uint(1), n)
ln.Node().Load(enr.WithEntry("z", &n))
assert.Equal(t, uint(3), n)
assert.EqualError(t, ln.Node().Load(enr.WithEntry("y", &n)), "missing ENR key \"y\"")
assert.Equal(t, initialSeq+3, ln.Node().Seq())
}
2 changes: 1 addition & 1 deletion p2p/enode/nodedb.go
Expand Up @@ -434,7 +434,7 @@ func (db *DB) localSeq(id ID) uint64 {
if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 {
return seq
}
return uint64(time.Now().Unix())
return nowMilliseconds()
}

// storeLocalSeq stores the local record sequence counter.
Expand Down

0 comments on commit 34c4e57

Please sign in to comment.