Skip to content

Commit

Permalink
p2p/enode: use unix timestamp as base ENR sequence number
Browse files Browse the repository at this point in the history
  • Loading branch information
karalabe committed Jul 31, 2019
1 parent 96ab8e1 commit 5e64cc0
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 16 deletions.
46 changes: 42 additions & 4 deletions p2p/enode/localnode.go
Expand Up @@ -38,17 +38,26 @@ const (
iptrackContactWindow = 10 * time.Minute
)

var (
// recordUpdateThrottle is the time needed to wait between two updates to an ENR
// record. Modifications in between are queued up and published together.
recordUpdateThrottle = time.Second
)

// LocalNode produces the signed node record of a local node, i.e. a node run in the
// current process. Setting ENR entries via the Set method updates the record. A new version
// of the record is signed on demand when the Node method is called.
type LocalNode struct {
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date.
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date
prev atomic.Value // holds a non-nil node pointer while the record is thottled on an update
update time.Time // timestamp when the record was last updated to prevent sequence number bloat

id ID
key *ecdsa.PrivateKey
db *DB

// everything below is protected by a lock
mu sync.Mutex
mu sync.RWMutex
seq uint64
entries map[string]enr.Entry
endpoint4 lnEndpoint
Expand Down Expand Up @@ -76,7 +85,9 @@ func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
},
}
ln.seq = db.localSeq(ln.id)
ln.invalidate()
ln.prev.Store((*Node)(nil))
ln.cur.Store((*Node)(nil))

return ln
}

Expand All @@ -87,14 +98,34 @@ func (ln *LocalNode) Database() *DB {

// Node returns the current version of the local node record.
func (ln *LocalNode) Node() *Node {
// If we have a valid record, return that
n := ln.cur.Load().(*Node)
if n != nil {
return n
}
// Record was invalidated, sign a new copy.
// Record was invalidated, check for a previous version and use that unless we
// are allowed to update.
if n = ln.prev.Load().(*Node); n != nil {
ln.mu.RLock()
throttle := time.Since(ln.update) < recordUpdateThrottle
ln.mu.RUnlock()

if throttle {
return n
}
}
// Record was invalidated a long time ago, sign a new copy
ln.mu.Lock()
defer ln.mu.Unlock()

// Double check the current record, since multiple goroutines might be waiting
// on the write mutex.
if n = ln.cur.Load().(*Node); n != nil {
return n
}
ln.sign()
ln.update = time.Now()

return ln.cur.Load().(*Node)
}

Expand All @@ -114,6 +145,10 @@ func (ln *LocalNode) ID() ID {
// Set puts the given entry into the local record, overwriting any existing value.
// Use Set*IP and SetFallbackUDP to set IP addresses and UDP port, otherwise they'll
// be overwritten by the endpoint predictor.
//
// Since node record updates are throttled to one per second, Set is asynchronous.
// Any update will be queued up and published when at least one second passes from
// the last change.
func (ln *LocalNode) Set(e enr.Entry) {
ln.mu.Lock()
defer ln.mu.Unlock()
Expand Down Expand Up @@ -259,6 +294,9 @@ func predictAddr(t *netutil.IPTracker) (net.IP, int) {
}

func (ln *LocalNode) invalidate() {
if n := ln.cur.Load().(*Node); n != nil {
ln.prev.Store(n)
}
ln.cur.Store((*Node)(nil))
}

Expand Down
82 changes: 72 additions & 10 deletions p2p/enode/localnode_test.go
Expand Up @@ -20,6 +20,7 @@ import (
"math/rand"
"net"
"testing"
"time"

"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enr"
Expand All @@ -33,6 +34,10 @@ func newLocalNodeForTesting() (*LocalNode, *DB) {
}

func TestLocalNode(t *testing.T) {
// Disable throttling for this test
defer func(throttle time.Duration) { recordUpdateThrottle = throttle }(recordUpdateThrottle)
recordUpdateThrottle = 0

ln, db := newLocalNodeForTesting()
defer db.Close()

Expand All @@ -50,37 +55,48 @@ func TestLocalNode(t *testing.T) {
}

func TestLocalNodeSeqPersist(t *testing.T) {
// Disable throttling for this test
defer func(throttle time.Duration) { recordUpdateThrottle = throttle }(recordUpdateThrottle)
recordUpdateThrottle = 0

timestamp := uint64(time.Now().Unix())

ln, db := newLocalNodeForTesting()
defer db.Close()

if s := ln.Node().Seq(); s != 1 {
t.Fatalf("wrong initial seq %d, want 1", s)
if s := ln.Node().Seq(); s != timestamp+1 {
t.Fatalf("wrong initial seq %d, want %d", s, timestamp+1)
}
ln.Set(enr.WithEntry("x", uint(1)))
if s := ln.Node().Seq(); s != 2 {
if s := ln.Node().Seq(); s != timestamp+2 {
t.Fatalf("wrong seq %d after set, want 2", s)
}

// Create a new instance, it should reload the sequence number.
// The number increases just after that because a new record is
// created without the "x" entry.
ln2 := NewLocalNode(db, ln.key)
if s := ln2.Node().Seq(); s != 3 {
if s := ln2.Node().Seq(); s != timestamp+3 {
t.Fatalf("wrong seq %d on new instance, want 3", s)
}

// Create a new instance with a different node key on the same database.
// This should reset the sequence number.
key, _ := crypto.GenerateKey()
ln3 := NewLocalNode(db, key)
if s := ln3.Node().Seq(); s != 1 {
if s := ln3.Node().Seq(); s != timestamp+1 {
t.Fatalf("wrong seq %d on instance with changed key, want 1", s)
}
}

// This test checks behavior of the endpoint predictor.
func TestLocalNodeEndpoint(t *testing.T) {
// Disable throttling for this test
defer func(throttle time.Duration) { recordUpdateThrottle = throttle }(recordUpdateThrottle)
recordUpdateThrottle = 0

var (
timestamp = uint64(time.Now().Unix())
fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80}
predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81}
staticIP = net.IP{127, 0, 1, 2}
Expand All @@ -91,32 +107,78 @@ func TestLocalNodeEndpoint(t *testing.T) {
// Nothing is set initially.
assert.Equal(t, net.IP(nil), ln.Node().IP())
assert.Equal(t, 0, ln.Node().UDP())
assert.Equal(t, uint64(1), ln.Node().Seq())
assert.Equal(t, uint64(timestamp+1), ln.Node().Seq())

// Set up fallback address.
ln.SetFallbackIP(fallback.IP)
ln.SetFallbackUDP(fallback.Port)
assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(2), ln.Node().Seq())
assert.Equal(t, uint64(timestamp+2), ln.Node().Seq())

// Add endpoint statements from random hosts.
for i := 0; i < iptrackMinStatements; i++ {
assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(2), ln.Node().Seq())
assert.Equal(t, uint64(timestamp+2), ln.Node().Seq())

from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90}
rand.Read(from.IP)
ln.UDPEndpointStatement(from, predicted)
}
assert.Equal(t, predicted.IP, ln.Node().IP())
assert.Equal(t, predicted.Port, ln.Node().UDP())
assert.Equal(t, uint64(3), ln.Node().Seq())
assert.Equal(t, uint64(timestamp+3), ln.Node().Seq())

// Static IP overrides prediction.
ln.SetStaticIP(staticIP)
assert.Equal(t, staticIP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, uint64(4), ln.Node().Seq())
assert.Equal(t, uint64(timestamp+4), ln.Node().Seq())
}

// Tests that multiple updates to a node record are throttled until the specified
// timeout expires.
func TestLocalNodeThrottling(t *testing.T) {
var n uint

// Create and retrieve an initial node record to force an update
ln, db := newLocalNodeForTesting()
defer db.Close()

ln.Set(enr.WithEntry("x", uint(3)))
ln.Set(enr.WithEntry("y", uint(2)))
ln.Set(enr.WithEntry("z", uint(1)))

timestamp := uint64(time.Now().Unix())
if s := ln.Node().Seq(); s != timestamp+1 {
t.Fatalf("wrong initial seq %d, want %d", s, timestamp+1)
}
ln.Node().Load(enr.WithEntry("x", &n))
assert.Equal(t, uint(3), n)
ln.Node().Load(enr.WithEntry("y", &n))
assert.Equal(t, uint(2), n)
ln.Node().Load(enr.WithEntry("z", &n))
assert.Equal(t, uint(1), n)

// Trigger a set of updates and ensure they don't publish yet
ln.Set(enr.WithEntry("x", uint(1)))
ln.Delete(enr.WithEntry("y", uint(2)))
ln.Set(enr.WithEntry("z", uint(3)))

ln.Node().Load(enr.WithEntry("x", &n))
assert.Equal(t, uint(3), n)
ln.Node().Load(enr.WithEntry("y", &n))
assert.Equal(t, uint(2), n)
ln.Node().Load(enr.WithEntry("z", &n))
assert.Equal(t, uint(1), n)

// Wait for the timeout to trigger and check again
time.Sleep(recordUpdateThrottle)

ln.Node().Load(enr.WithEntry("x", &n))
assert.Equal(t, uint(1), n)
ln.Node().Load(enr.WithEntry("z", &n))
assert.Equal(t, uint(3), n)
assert.EqualError(t, ln.Node().Load(enr.WithEntry("y", &n)), "missing ENR key \"y\"")
}
9 changes: 7 additions & 2 deletions p2p/enode/nodedb.go
Expand Up @@ -378,9 +378,14 @@ func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error {
return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails))
}

// LocalSeq retrieves the local record sequence counter.
// LocalSeq retrieves the local record sequence counter, defaulting to the current
// timestamp if no previous exists. This ensures that wiping all data associated
// with a node (apart from its key) will not generate already used sequence nums.
func (db *DB) localSeq(id ID) uint64 {
return db.fetchUint64(localItemKey(id, dbLocalSeq))
if seq := db.fetchUint64(localItemKey(id, dbLocalSeq)); seq > 0 {
return seq
}
return uint64(time.Now().Unix())
}

// storeLocalSeq stores the local record sequence counter.
Expand Down

0 comments on commit 5e64cc0

Please sign in to comment.