diff --git a/eth/protocols/eth/peer.go b/eth/protocols/eth/peer.go index 98273ccfc9dc5..1b4cfeb3da7af 100644 --- a/eth/protocols/eth/peer.go +++ b/eth/protocols/eth/peer.go @@ -75,12 +75,12 @@ type Peer struct { head common.Hash // Latest advertised head block hash td *big.Int // Latest advertised head block total difficulty - knownBlocks mapset.Set // Set of block hashes known to be known by this peer + knownBlocks *knownCache // Set of block hashes known to be known by this peer queuedBlocks chan *blockPropagation // Queue of blocks to broadcast to the peer queuedBlockAnns chan *types.Block // Queue of blocks to announce to the peer txpool TxPool // Transaction pool used by the broadcasters for liveness checks - knownTxs mapset.Set // Set of transaction hashes known to be known by this peer + knownTxs *knownCache // Set of transaction hashes known to be known by this peer txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests @@ -96,8 +96,8 @@ func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Pe Peer: p, rw: rw, version: version, - knownTxs: mapset.NewSet(), - knownBlocks: mapset.NewSet(), + knownTxs: newKnownCache(maxKnownTxs), + knownBlocks: newKnownCache(maxKnownBlocks), queuedBlocks: make(chan *blockPropagation, maxQueuedBlocks), queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns), txBroadcast: make(chan []common.Hash), @@ -162,9 +162,6 @@ func (p *Peer) KnownTransaction(hash common.Hash) bool { // never be propagated to this particular peer. func (p *Peer) markBlock(hash common.Hash) { // If we reached the memory allowance, drop a previously known block hash - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } p.knownBlocks.Add(hash) } @@ -172,9 +169,6 @@ func (p *Peer) markBlock(hash common.Hash) { // will never be propagated to this particular peer. func (p *Peer) markTransaction(hash common.Hash) { // If we reached the memory allowance, drop a previously known transaction hash - for p.knownTxs.Cardinality() >= maxKnownTxs { - p.knownTxs.Pop() - } p.knownTxs.Add(hash) } @@ -189,9 +183,6 @@ func (p *Peer) markTransaction(hash common.Hash) { // tests that directly send messages without having to do the asyn queueing. func (p *Peer) SendTransactions(txs types.Transactions) error { // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(txs)) { - p.knownTxs.Pop() - } for _, tx := range txs { p.knownTxs.Add(tx.Hash()) } @@ -205,12 +196,7 @@ func (p *Peer) AsyncSendTransactions(hashes []common.Hash) { select { case p.txBroadcast <- hashes: // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } + p.knownTxs.Add(hashes...) case <-p.term: p.Log().Debug("Dropping transaction propagation", "count", len(hashes)) } @@ -224,12 +210,7 @@ func (p *Peer) AsyncSendTransactions(hashes []common.Hash) { // not be managed directly. func (p *Peer) sendPooledTransactionHashes(hashes []common.Hash) error { // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } + p.knownTxs.Add(hashes...) return p2p.Send(p.rw, NewPooledTransactionHashesMsg, NewPooledTransactionHashesPacket(hashes)) } @@ -240,12 +221,7 @@ func (p *Peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) { select { case p.txAnnounce <- hashes: // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } + p.knownTxs.Add(hashes...) case <-p.term: p.Log().Debug("Dropping transaction announcement", "count", len(hashes)) } @@ -254,12 +230,8 @@ func (p *Peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) { // ReplyPooledTransactionsRLP is the eth/66 version of SendPooledTransactionsRLP. func (p *Peer) ReplyPooledTransactionsRLP(id uint64, hashes []common.Hash, txs []rlp.RawValue) error { // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } + p.knownTxs.Add(hashes...) + // Not packed into PooledTransactionsPacket to avoid RLP decoding return p2p.Send(p.rw, PooledTransactionsMsg, PooledTransactionsRLPPacket66{ RequestId: id, @@ -271,12 +243,8 @@ func (p *Peer) ReplyPooledTransactionsRLP(id uint64, hashes []common.Hash, txs [ // a hash notification. func (p *Peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error { // Mark all the block hashes as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() > max(0, maxKnownBlocks-len(hashes)) { - p.knownBlocks.Pop() - } - for _, hash := range hashes { - p.knownBlocks.Add(hash) - } + p.knownBlocks.Add(hashes...) + request := make(NewBlockHashesPacket, len(hashes)) for i := 0; i < len(hashes); i++ { request[i].Hash = hashes[i] @@ -292,9 +260,6 @@ func (p *Peer) AsyncSendNewBlockHash(block *types.Block) { select { case p.queuedBlockAnns <- block: // Mark all the block hash as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } p.knownBlocks.Add(block.Hash()) default: p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash()) @@ -304,9 +269,6 @@ func (p *Peer) AsyncSendNewBlockHash(block *types.Block) { // SendNewBlock propagates an entire block to a remote peer. func (p *Peer) SendNewBlock(block *types.Block, td *big.Int) error { // Mark all the block hash as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } p.knownBlocks.Add(block.Hash()) return p2p.Send(p.rw, NewBlockMsg, &NewBlockPacket{ Block: block, @@ -320,9 +282,6 @@ func (p *Peer) AsyncSendNewBlock(block *types.Block, td *big.Int) { select { case p.queuedBlocks <- &blockPropagation{block: block, td: td}: // Mark all the block hash as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } p.knownBlocks.Add(block.Hash()) default: p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash()) @@ -465,3 +424,37 @@ func (p *Peer) RequestTxs(hashes []common.Hash) error { GetPooledTransactionsPacket: hashes, }) } + +// knownCache is a cache for known hashes. +type knownCache struct { + hashes mapset.Set + max int +} + +// newKnownCache creates a new knownCache with a max capacity. +func newKnownCache(max int) *knownCache { + return &knownCache{ + max: max, + hashes: mapset.NewSet(), + } +} + +// Add adds a list of elements to the set. +func (k *knownCache) Add(hashes ...common.Hash) { + for k.hashes.Cardinality() > max(0, k.max-len(hashes)) { + k.hashes.Pop() + } + for _, hash := range hashes { + k.hashes.Add(hash) + } +} + +// Contains returns whether the given item is in the set. +func (k *knownCache) Contains(hash common.Hash) bool { + return k.hashes.Contains(hash) +} + +// Cardinality returns the number of elements in the set. +func (k *knownCache) Cardinality() int { + return k.hashes.Cardinality() +} diff --git a/eth/protocols/eth/peer_test.go b/eth/protocols/eth/peer_test.go index 70e9959f82fd6..fc93443708fd3 100644 --- a/eth/protocols/eth/peer_test.go +++ b/eth/protocols/eth/peer_test.go @@ -21,7 +21,9 @@ package eth import ( "crypto/rand" + "testing" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" ) @@ -59,3 +61,28 @@ func (p *testPeer) close() { p.Peer.Close() p.app.Close() } + +func TestPeerSet(t *testing.T) { + size := 5 + s := newKnownCache(size) + + // add 10 items + for i := 0; i < size*2; i++ { + s.Add(common.Hash{byte(i)}) + } + + if s.Cardinality() != size { + t.Fatalf("wrong size, expected %d but found %d", size, s.Cardinality()) + } + + vals := []common.Hash{} + for i := 10; i < 20; i++ { + vals = append(vals, common.Hash{byte(i)}) + } + + // add item in batch + s.Add(vals...) + if s.Cardinality() < size { + t.Fatalf("bad size") + } +}