From 9e17648d8c991c7243ca99123f010c152b87fce7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= Date: Fri, 10 Sep 2021 10:55:48 +0300 Subject: [PATCH] les: duplicate downloader and fetcher to allow progressive refactoring --- eth/api_backend.go | 6 +- ethstats/ethstats.go | 8 +- graphql/graphql.go | 2 +- internal/ethapi/api.go | 2 +- internal/ethapi/backend.go | 5 +- les/api_backend.go | 6 +- les/api_test.go | 7 +- les/client.go | 2 +- les/client_handler.go | 2 +- les/downloader/api.go | 166 +++ les/downloader/downloader.go | 2014 +++++++++++++++++++++++++++++ les/downloader/downloader_test.go | 1622 +++++++++++++++++++++++ les/downloader/events.go | 25 + les/downloader/metrics.go | 45 + les/downloader/modes.go | 81 ++ les/downloader/peer.go | 501 +++++++ les/downloader/queue.go | 913 +++++++++++++ les/downloader/queue_test.go | 452 +++++++ les/downloader/resultstore.go | 194 +++ les/downloader/statesync.go | 615 +++++++++ les/downloader/testchain_test.go | 230 ++++ les/downloader/types.go | 79 ++ les/fetcher.go | 2 +- les/fetcher/block_fetcher.go | 889 +++++++++++++ les/fetcher/block_fetcher_test.go | 896 +++++++++++++ les/handler_test.go | 2 +- les/sync.go | 2 +- 27 files changed, 8746 insertions(+), 22 deletions(-) create mode 100644 les/downloader/api.go create mode 100644 les/downloader/downloader.go create mode 100644 les/downloader/downloader_test.go create mode 100644 les/downloader/events.go create mode 100644 les/downloader/metrics.go create mode 100644 les/downloader/modes.go create mode 100644 les/downloader/peer.go create mode 100644 les/downloader/queue.go create mode 100644 les/downloader/queue_test.go create mode 100644 les/downloader/resultstore.go create mode 100644 les/downloader/statesync.go create mode 100644 les/downloader/testchain_test.go create mode 100644 les/downloader/types.go create mode 100644 les/fetcher/block_fetcher.go create mode 100644 les/fetcher/block_fetcher_test.go diff --git a/eth/api_backend.go 
b/eth/api_backend.go index 49de70e21069e..7b40a7edd3a11 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -21,6 +21,7 @@ import ( "errors" "math/big" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" @@ -30,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -279,8 +279,8 @@ func (b *EthAPIBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.S return b.eth.TxPool().SubscribeNewTxsEvent(ch) } -func (b *EthAPIBackend) Downloader() *downloader.Downloader { - return b.eth.Downloader() +func (b *EthAPIBackend) SyncProgress() ethereum.SyncProgress { + return b.eth.Downloader().Progress() } func (b *EthAPIBackend) SuggestGasTipCap(ctx context.Context) (*big.Int, error) { diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go index 148359110c048..55c0c880f33c9 100644 --- a/ethstats/ethstats.go +++ b/ethstats/ethstats.go @@ -30,12 +30,12 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth/downloader" ethproto "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/les" @@ -67,7 +67,7 @@ type backend interface { HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error) GetTd(ctx context.Context, hash common.Hash) *big.Int Stats() (pending int, queued int) - Downloader() *downloader.Downloader + SyncProgress() 
ethereum.SyncProgress } // fullNodeBackend encompasses the functionality necessary for a full node @@ -777,7 +777,7 @@ func (s *Service) reportStats(conn *connWrapper) error { mining = fullBackend.Miner().Mining() hashrate = int(fullBackend.Miner().Hashrate()) - sync := fullBackend.Downloader().Progress() + sync := fullBackend.SyncProgress() syncing = fullBackend.CurrentHeader().Number.Uint64() >= sync.HighestBlock price, _ := fullBackend.SuggestGasTipCap(context.Background()) @@ -786,7 +786,7 @@ func (s *Service) reportStats(conn *connWrapper) error { gasprice += int(basefee.Uint64()) } } else { - sync := s.backend.Downloader().Progress() + sync := s.backend.SyncProgress() syncing = s.backend.CurrentHeader().Number.Uint64() >= sync.HighestBlock } // Assemble the node stats and send it to the server diff --git a/graphql/graphql.go b/graphql/graphql.go index d35994234ea57..4dd96c4b9db18 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -1248,7 +1248,7 @@ func (s *SyncState) KnownStates() *hexutil.Uint64 { // - pulledStates: number of state entries processed until now // - knownStates: number of known state entries that still need to be pulled func (r *Resolver) Syncing() (*SyncState, error) { - progress := r.backend.Downloader().Progress() + progress := r.backend.SyncProgress() // Return not syncing if the synchronisation already completed if progress.CurrentBlock >= progress.HighestBlock { diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 9a82824ada14f..6997f2c82878d 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -122,7 +122,7 @@ func (s *PublicEthereumAPI) FeeHistory(ctx context.Context, blockCount rpc.Decim // - pulledStates: number of state entries processed until now // - knownStates: number of known state entries that still need to be pulled func (s *PublicEthereumAPI) Syncing() (interface{}, error) { - progress := s.b.Downloader().Progress() + progress := s.b.SyncProgress() // Return not syncing if the 
synchronisation already completed if progress.CurrentBlock >= progress.HighestBlock { diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 9954545821932..1624f49635b33 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -21,6 +21,7 @@ import ( "context" "math/big" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" @@ -29,7 +30,6 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/params" @@ -40,7 +40,8 @@ import ( // both full and light clients) with access to necessary functions. type Backend interface { // General Ethereum API - Downloader() *downloader.Downloader + SyncProgress() ethereum.SyncProgress + SuggestGasTipCap(ctx context.Context) (*big.Int, error) FeeHistory(ctx context.Context, blockCount int, lastBlock rpc.BlockNumber, rewardPercentiles []float64) (*big.Int, [][]*big.Int, []*big.Int, []float64, error) ChainDb() ethdb.Database diff --git a/les/api_backend.go b/les/api_backend.go index 9c80270da0a11..e12984cb49e36 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -21,6 +21,7 @@ import ( "errors" "math/big" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" @@ -30,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -257,8 +257,8 @@ func (b 
*LesApiBackend) SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEven return b.eth.blockchain.SubscribeRemovedLogsEvent(ch) } -func (b *LesApiBackend) Downloader() *downloader.Downloader { - return b.eth.Downloader() +func (b *LesApiBackend) SyncProgress() ethereum.SyncProgress { + return b.eth.Downloader().Progress() } func (b *LesApiBackend) ProtocolVersion() int { diff --git a/les/api_test.go b/les/api_test.go index f7017c5d982e9..6a19b0fe4fbf9 100644 --- a/les/api_test.go +++ b/les/api_test.go @@ -32,8 +32,9 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/eth" - "github.com/ethereum/go-ethereum/eth/downloader" + ethdownloader "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/ethconfig" + "github.com/ethereum/go-ethereum/les/downloader" "github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/node" @@ -494,14 +495,14 @@ func testSim(t *testing.T, serverCount, clientCount int, serverDir, clientDir [] func newLesClientService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { config := ethconfig.Defaults - config.SyncMode = downloader.LightSync + config.SyncMode = (ethdownloader.SyncMode)(downloader.LightSync) config.Ethash.PowMode = ethash.ModeFake return New(stack, &config) } func newLesServerService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { config := ethconfig.Defaults - config.SyncMode = downloader.FullSync + config.SyncMode = (ethdownloader.SyncMode)(downloader.FullSync) config.LightServ = testServerCapacity config.LightPeers = testMaxClients ethereum, err := eth.New(stack, &config) diff --git a/les/client.go b/les/client.go index 1d8a2c6f9a072..5d07c783e99d6 100644 --- a/les/client.go +++ b/les/client.go @@ -30,12 +30,12 @@ import ( "github.com/ethereum/go-ethereum/core/bloombits" 
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/ethconfig" "github.com/ethereum/go-ethereum/eth/filters" "github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/les/downloader" "github.com/ethereum/go-ethereum/les/vflux" vfc "github.com/ethereum/go-ethereum/les/vflux/client" "github.com/ethereum/go-ethereum/light" diff --git a/les/client_handler.go b/les/client_handler.go index 4a550b20745c4..9583bd57ca059 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -28,8 +28,8 @@ import ( "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/protocols/eth" + "github.com/ethereum/go-ethereum/les/downloader" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" diff --git a/les/downloader/api.go b/les/downloader/api.go new file mode 100644 index 0000000000000..2024d23deade6 --- /dev/null +++ b/les/downloader/api.go @@ -0,0 +1,166 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see http://www.gnu.org/licenses/.
+
+package downloader
+
+import (
+	"context"
+	"sync"
+
+	"github.com/ethereum/go-ethereum"
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/rpc"
+)
+
+// PublicDownloaderAPI provides an API which gives information about the current synchronisation status.
+// It offers only methods that operate on data that can be available to anyone without security risks.
+type PublicDownloaderAPI struct {
+	d                         *Downloader                            // downloader whose Progress() is attached to sync-start notifications
+	mux                       *event.TypeMux                         // event mux that eventLoop subscribes to for StartEvent/DoneEvent/FailedEvent
+	installSyncSubscription   chan chan interface{}                  // hands new subscriber channels to eventLoop
+	uninstallSyncSubscription chan *uninstallSyncSubscriptionRequest // hands removal requests to eventLoop
+}
+
+// NewPublicDownloaderAPI creates a new PublicDownloaderAPI. The API has an internal event loop that
+// listens for events from the downloader through the global event mux. In case it receives one of
+// these events it broadcasts it to all syncing subscriptions that are installed through the
+// installSyncSubscription channel.
+func NewPublicDownloaderAPI(d *Downloader, m *event.TypeMux) *PublicDownloaderAPI {
+	api := &PublicDownloaderAPI{
+		d:                         d,
+		mux:                       m,
+		installSyncSubscription:   make(chan chan interface{}),                    // unbuffered: senders block until eventLoop receives
+		uninstallSyncSubscription: make(chan *uninstallSyncSubscriptionRequest), // unbuffered as well
+	}
+
+	go api.eventLoop() // background goroutine; it only returns once the mux subscription yields a nil event
+
+	return api
+}
+
+// eventLoop runs a loop until the event mux closes. It installs and uninstalls
+// sync subscriptions and broadcasts sync status updates to the installed sync subscriptions.
+func (api *PublicDownloaderAPI) eventLoop() {
+	var (
+		sub               = api.mux.Subscribe(StartEvent{}, DoneEvent{}, FailedEvent{}) // NOTE(review): never explicitly unsubscribed in this function
+		syncSubscriptions = make(map[chan interface{}]struct{})                         // set of subscriber channels, only touched by this goroutine
+	)
+
+	for {
+		select {
+		case i := <-api.installSyncSubscription:
+			syncSubscriptions[i] = struct{}{}
+		case u := <-api.uninstallSyncSubscription:
+			delete(syncSubscriptions, u.c)
+			close(u.uninstalled) // tells the waiting Unsubscribe call that u.c will not be written to again
+		case event := <-sub.Chan():
+			if event == nil { // a nil event is the only exit condition of the loop
+				return
+			}
+
+			var notification interface{}
+			switch event.Data.(type) {
+			case StartEvent:
+				notification = &SyncingResult{
+					Syncing: true,
+					Status:  api.d.Progress(),
+				}
+			case DoneEvent, FailedEvent:
+				notification = false // subscribers receive a bare false for both outcomes
+			}
+			// broadcast; each send is a plain blocking send, so a subscriber that stops reading stalls this loop
+			for c := range syncSubscriptions {
+				c <- notification
+			}
+		}
+	}
+}
+
+// Syncing provides information when this node starts synchronising with the Ethereum network and when it's finished.
+func (api *PublicDownloaderAPI) Syncing(ctx context.Context) (*rpc.Subscription, error) {
+	notifier, supported := rpc.NotifierFromContext(ctx)
+	if !supported {
+		return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported
+	}
+
+	rpcSub := notifier.CreateSubscription()
+
+	go func() {
+		statuses := make(chan interface{}) // unbuffered; filled by eventLoop's broadcast
+		sub := api.SubscribeSyncStatus(statuses)
+
+		for {
+			select {
+			case status := <-statuses:
+				notifier.Notify(rpcSub.ID, status) // NOTE(review): any result of Notify is discarded here
+			case <-rpcSub.Err():
+				sub.Unsubscribe() // both exit paths uninstall the subscription before the goroutine ends
+				return
+			case <-notifier.Closed():
+				sub.Unsubscribe()
+				return
+			}
+		}
+	}()
+
+	return rpcSub, nil
+}
+
+// SyncingResult provides information about the current synchronisation status for this node.
+type SyncingResult struct {
+	Syncing bool                  `json:"syncing"`
+	Status  ethereum.SyncProgress `json:"status"`
+}
+
+// uninstallSyncSubscriptionRequest uninstalls a syncing subscription in the API event loop.
+type uninstallSyncSubscriptionRequest struct {
+	c           chan interface{} // subscriber channel to remove from the event loop's set
+	uninstalled chan interface{} // closed by eventLoop once c has been removed
+}
+
+// SyncStatusSubscription represents a syncing subscription.
+type SyncStatusSubscription struct {
+	api       *PublicDownloaderAPI // register subscription in event loop of this api instance
+	c         chan interface{}     // channel where events are broadcasted to
+	unsubOnce sync.Once            // make sure unsubscribe logic is executed once
+}
+
+// Unsubscribe uninstalls the subscription from the PublicDownloaderAPI event loop.
+// The status channel that was passed to SubscribeSyncStatus isn't used anymore
+// after this method returns.
+func (s *SyncStatusSubscription) Unsubscribe() {
+	s.unsubOnce.Do(func() {
+		req := uninstallSyncSubscriptionRequest{s.c, make(chan interface{})}
+		s.api.uninstallSyncSubscription <- &req // blocks until eventLoop receives the request
+
+		for {
+			select {
+			case <-s.c:
+				// drop new status events until uninstall confirmation
+				continue
+			case <-req.uninstalled: // closed by eventLoop after the channel was removed from its set
+				return
+			}
+		}
+	})
+}
+
+// SubscribeSyncStatus creates a subscription that will broadcast new synchronisation updates.
+// The given channel must receive interface values, the result can either be a *SyncingResult (sync started) or false (sync done or failed).
+func (api *PublicDownloaderAPI) SubscribeSyncStatus(status chan interface{}) *SyncStatusSubscription {
+	api.installSyncSubscription <- status // blocks until eventLoop has picked the channel up
+	return &SyncStatusSubscription{api: api, c: status}
+}
diff --git a/les/downloader/downloader.go b/les/downloader/downloader.go
new file mode 100644
index 0000000000000..e7dfc4158e0ed
--- /dev/null
+++ b/les/downloader/downloader.go
@@ -0,0 +1,2014 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see http://www.gnu.org/licenses/.
+
+// This is a temporary package whilst working on the eth/66 blocking refactors.
+// After that work is done, les needs to be refactored to use the new package,
+// or alternatively use a stripped down version of it. Either way, we need to
+// keep the changes scoped so duplicating temporarily seems the sanest.
+package downloader
+
+import (
+	"errors"
+	"fmt"
+	"math/big"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/ethereum/go-ethereum"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/state/snapshot"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/eth/protocols/eth"
+	"github.com/ethereum/go-ethereum/eth/protocols/snap"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
+var (
+	MaxBlockFetch   = 128 // Amount of blocks to be fetched per retrieval request
+	MaxHeaderFetch  = 192 // Amount of block headers to be fetched per retrieval request
+	MaxSkeletonSize = 128 // Number of header fetches needed for a skeleton assembly
+	MaxReceiptFetch = 256 // Amount of transaction receipts to allow fetching per request
+	MaxStateFetch   = 384 // Amount of node state values to allow fetching per request
+
+	maxQueuedHeaders            = 32 * 1024                         // [eth/62] Maximum number of headers to queue for import (DOS protection)
+	maxHeadersProcess           = 2048                              // Number of header download results to import at once into the chain
+	maxResultsProcess           = 2048                              // Number of content download results to import at once into the chain
+	fullMaxForkAncestry  uint64 = params.FullImmutabilityThreshold  // Maximum chain reorganisation (locally redeclared so tests can reduce it)
+	lightMaxForkAncestry uint64 = params.LightImmutabilityThreshold // Maximum chain reorganisation (locally redeclared so tests can reduce it)
+
+	reorgProtThreshold   = 48 // Threshold number of recent blocks to disable mini reorg protection
+	reorgProtHeaderDelay = 2  // Number of headers to delay delivering to cover mini reorgs
+
+	fsHeaderCheckFrequency = 100             // Verification frequency of the downloaded headers during fast sync
+	fsHeaderSafetyNet      = 2048            // Number of headers to discard in case a chain violation is detected
+	fsHeaderForceVerify    = 24              // Number of headers to verify before and after the pivot to accept it
+	fsHeaderContCheck      = 3 * time.Second // Time interval to check for header continuations during state download
+	fsMinFullBlocks        = 64              // Number of blocks to retrieve fully even in fast sync
+)
+
+var (
+	errBusy                    = errors.New("busy")
+	errUnknownPeer             = errors.New("peer is unknown or unhealthy")
+	errBadPeer                 = errors.New("action from bad peer ignored")
+	errStallingPeer            = errors.New("peer is stalling")
+	errUnsyncedPeer            = errors.New("unsynced peer")
+	errNoPeers                 = errors.New("no peers to keep download active")
+	errTimeout                 = errors.New("timeout")
+	errEmptyHeaderSet          = errors.New("empty header set by peer")
+	errPeersUnavailable        = errors.New("no peers available or all tried for download")
+	errInvalidAncestor         = errors.New("retrieved ancestor is invalid")
+	errInvalidChain            = errors.New("retrieved hash chain is invalid")
+	errInvalidBody             = errors.New("retrieved block body is invalid")
+	errInvalidReceipt          = errors.New("retrieved receipt is invalid")
+	errCancelStateFetch        = errors.New("state data download canceled (requested)")
+	errCancelContentProcessing = errors.New("content processing canceled (requested)")
+	errCanceled                = errors.New("syncing canceled (requested)")
+	errNoSyncActive            = errors.New("no sync active")
+	errTooOld                  = errors.New("peer's protocol version too old")
+	errNoAncestorFound         = errors.New("no common ancestor found")
+)
+
+type Downloader struct {
+	mode uint32         // Synchronisation mode defining the strategy used (per sync cycle), use d.getMode() to get the SyncMode
+	mux  *event.TypeMux // Event multiplexer to announce sync operation events
+
+	checkpoint uint64   // Checkpoint block number to enforce head against (e.g. fast sync)
+	genesis    uint64   // Genesis block number to limit sync to (e.g. light client CHT)
+	queue      *queue   // Scheduler for selecting the hashes to download
+	peers      *peerSet // Set of active peers from which download can proceed
+
+	stateDB    ethdb.Database  // Database to state sync into (and deduplicate via)
+	stateBloom *trie.SyncBloom // Bloom filter for fast trie node and contract code existence checks
+
+	// Statistics
+	syncStatsChainOrigin uint64         // Origin block number where syncing started at
+	syncStatsChainHeight uint64         // Highest block number known when syncing started
+	syncStatsState       stateSyncStats // State sync counters; Progress reports its processed and pending fields
+	syncStatsLock        sync.RWMutex   // Lock protecting the sync stats fields
+
+	lightchain LightChain // Header chain; New falls back to the full chain when this is passed as nil
+	blockchain BlockChain // Full chain; may be nil (Progress checks for that before using it)
+
+	// Callbacks
+	dropPeer peerDropFn // Drops a peer for misbehaving
+
+	// Status
+	synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
+	synchronising   int32                                   // 1 while a sync cycle is running; set and cleared atomically in synchronise
+	notified        int32                                   // 1 once the "Block synchronisation started" message has been logged
+	committed       int32                                   // NOTE(review): not referenced in the code visible here - confirm meaning
+	ancientLimit    uint64                                  // The maximum block number which can be regarded as ancient data.
+
+	// Channels
+	headerCh      chan dataPack        // Channel receiving inbound block headers
+	bodyCh        chan dataPack        // Channel receiving inbound block bodies
+	receiptCh     chan dataPack        // Channel receiving inbound receipts
+	bodyWakeCh    chan bool            // Channel to signal the block body fetcher of new tasks
+	receiptWakeCh chan bool            // Channel to signal the receipt fetcher of new tasks
+	headerProcCh  chan []*types.Header // Channel to feed the header processor new tasks
+
+	// State sync
+	pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root
+	pivotLock   sync.RWMutex  // Lock protecting pivot header reads from updates
+
+	snapSync       bool         // Whether to run state sync over the snap protocol
+	SnapSyncer     *snap.Syncer // TODO(karalabe): make private! hack for now
+	stateSyncStart chan *stateSync
+	trackStateReq  chan *stateReq
+	stateCh        chan dataPack // Channel receiving inbound node state data
+
+	// Cancellation and termination
+	cancelPeer string         // Identifier of the peer currently being used as the master (cancel on drop)
+	cancelCh   chan struct{}  // Channel to cancel mid-flight syncs
+	cancelLock sync.RWMutex   // Lock to protect the cancel channel and peer in delivers
+	cancelWg   sync.WaitGroup // Make sure all fetcher goroutines have exited.
+
+	quitCh   chan struct{} // Quit channel to signal termination
+	quitLock sync.Mutex    // Lock to prevent double closes
+
+	// Testing hooks
+	syncInitHook     func(uint64, uint64)  // Method to call upon initiating a new sync run
+	bodyFetchHook    func([]*types.Header) // Method to call upon starting a block body fetch
+	receiptFetchHook func([]*types.Header) // Method to call upon starting a receipt fetch
+	chainInsertHook  func([]*fetchResult)  // Method to call upon inserting a chain of blocks (possibly in multiple invocations)
+}
+
+// LightChain encapsulates functions required to synchronise a light chain.
+type LightChain interface {
+	// HasHeader verifies a header's presence in the local chain.
+	HasHeader(common.Hash, uint64) bool
+
+	// GetHeaderByHash retrieves a header from the local chain.
+	GetHeaderByHash(common.Hash) *types.Header
+
+	// CurrentHeader retrieves the head header from the local chain.
+	CurrentHeader() *types.Header
+
+	// GetTd returns the total difficulty of a local block.
+	GetTd(common.Hash, uint64) *big.Int
+
+	// InsertHeaderChain inserts a batch of headers into the local chain.
+	InsertHeaderChain([]*types.Header, int) (int, error)
+
+	// SetHead rewinds the local chain to a new head.
+	SetHead(uint64) error
+}
+
+// BlockChain encapsulates functions required to sync a (full or fast) blockchain.
+type BlockChain interface {
+	LightChain
+
+	// HasBlock verifies a block's presence in the local chain.
+	HasBlock(common.Hash, uint64) bool
+
+	// HasFastBlock verifies a fast block's presence in the local chain.
+	HasFastBlock(common.Hash, uint64) bool
+
+	// GetBlockByHash retrieves a block from the local chain.
+	GetBlockByHash(common.Hash) *types.Block
+
+	// CurrentBlock retrieves the head block from the local chain.
+	CurrentBlock() *types.Block
+
+	// CurrentFastBlock retrieves the head fast block from the local chain.
+	CurrentFastBlock() *types.Block
+
+	// FastSyncCommitHead directly commits the head block to a certain entity.
+	FastSyncCommitHead(common.Hash) error
+
+	// InsertChain inserts a batch of blocks into the local chain.
+	InsertChain(types.Blocks) (int, error)
+
+	// InsertReceiptChain inserts a batch of receipts into the local chain.
+	InsertReceiptChain(types.Blocks, []types.Receipts, uint64) (int, error)
+
+	// Snapshots returns the blockchain snapshot tree so it can be paused during sync.
+	Snapshots() *snapshot.Tree
+}
+
+// New creates a new downloader to fetch hashes and blocks from remote peers.
+func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn) *Downloader {
+	if lightchain == nil {
+		lightchain = chain // BlockChain embeds LightChain, so the full chain can stand in for the header chain
+	}
+	dl := &Downloader{
+		stateDB:        stateDb,
+		stateBloom:     stateBloom,
+		mux:            mux,
+		checkpoint:     checkpoint,
+		queue:          newQueue(blockCacheMaxItems, blockCacheInitialItems),
+		peers:          newPeerSet(),
+		blockchain:     chain,
+		lightchain:     lightchain,
+		dropPeer:       dropPeer,
+		headerCh:       make(chan dataPack, 1), // the delivery, wake and header-processing channels are buffered with capacity 1
+		bodyCh:         make(chan dataPack, 1),
+		receiptCh:      make(chan dataPack, 1),
+		bodyWakeCh:     make(chan bool, 1),
+		receiptWakeCh:  make(chan bool, 1),
+		headerProcCh:   make(chan []*types.Header, 1),
+		quitCh:         make(chan struct{}),
+		stateCh:        make(chan dataPack), // the state-sync channels, in contrast, are unbuffered
+		SnapSyncer:     snap.NewSyncer(stateDb),
+		stateSyncStart: make(chan *stateSync),
+		syncStatsState: stateSyncStats{
+			processed: rawdb.ReadFastTrieProgress(stateDb), // counter starts from the value read out of stateDb
+		},
+		trackStateReq: make(chan *stateReq),
+	}
+	go dl.stateFetcher() // started once per Downloader, not per sync cycle; its exit path is outside this chunk (presumably quitCh - confirm)
+	return dl
+}
+
+// Progress retrieves the synchronisation boundaries, specifically the origin
+// block where synchronisation started at (may have failed/suspended); the block
+// or header sync is currently at; and the latest known block which the sync targets.
+//
+// In addition, during the state download phase of fast synchronisation the number
+// of processed and the total number of known states are also returned. Otherwise
+// these are zero.
+func (d *Downloader) Progress() ethereum.SyncProgress {
+	// Lock the current stats and return the progress
+	d.syncStatsLock.RLock()
+	defer d.syncStatsLock.RUnlock()
+
+	current := uint64(0)
+	mode := d.getMode()
+	switch { // cases are tried in order: a non-nil blockchain in full/fast mode wins over the header chain
+	case d.blockchain != nil && mode == FullSync:
+		current = d.blockchain.CurrentBlock().NumberU64()
+	case d.blockchain != nil && mode == FastSync:
+		current = d.blockchain.CurrentFastBlock().NumberU64()
+	case d.lightchain != nil:
+		current = d.lightchain.CurrentHeader().Number.Uint64()
+	default: // neither chain usable: current stays 0 and the combination is only logged
+		log.Error("Unknown downloader chain/mode combo", "light", d.lightchain != nil, "full", d.blockchain != nil, "mode", mode)
+	}
+	return ethereum.SyncProgress{
+		StartingBlock: d.syncStatsChainOrigin,
+		CurrentBlock:  current,
+		HighestBlock:  d.syncStatsChainHeight,
+		PulledStates:  d.syncStatsState.processed,
+		KnownStates:   d.syncStatsState.processed + d.syncStatsState.pending, // known = already processed plus still pending
+	}
+}
+
+// Synchronising returns whether the downloader is currently retrieving blocks.
+func (d *Downloader) Synchronising() bool {
+	return atomic.LoadInt32(&d.synchronising) > 0 // flag is set to 1 by synchronise for the duration of a cycle
+}
+
+// RegisterPeer injects a new download peer into the set of block sources to be
+// used for fetching hashes and blocks from.
+func (d *Downloader) RegisterPeer(id string, version uint, peer Peer) error {
+	var logger log.Logger
+	if len(id) < 16 {
+		// Tests use short IDs, don't choke on them
+		logger = log.New("peer", id)
+	} else {
+		logger = log.New("peer", id[:8]) // long IDs are logged by their first 8 characters only
+	}
+	logger.Trace("Registering sync peer")
+	if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil {
+		logger.Error("Failed to register sync peer", "err", err)
+		return err // the peer set's error is returned unchanged
+	}
+	return nil
+}
+
+// RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer.
+func (d *Downloader) RegisterLightPeer(id string, version uint, peer LightPeer) error {
+	return d.RegisterPeer(id, version, &lightPeerWrapper{peer})
+}
+
+// UnregisterPeer removes a peer from the known list, preventing any action from
+// the specified peer. An effort is also made to return any pending fetches into
+// the queue.
+func (d *Downloader) UnregisterPeer(id string) error {
+	// Unregister the peer from the active peer set and revoke any fetch tasks
+	var logger log.Logger
+	if len(id) < 16 {
+		// Tests use short IDs, don't choke on them
+		logger = log.New("peer", id)
+	} else {
+		logger = log.New("peer", id[:8]) // same 8-character truncation as in RegisterPeer
+	}
+	logger.Trace("Unregistering sync peer")
+	if err := d.peers.Unregister(id); err != nil {
+		logger.Error("Failed to unregister sync peer", "err", err)
+		return err // on failure the queue is left untouched
+	}
+	d.queue.Revoke(id)
+
+	return nil
+}
+
+// Synchronise tries to sync up our local block chain with a remote peer, both
+// adding various sanity checks as well as wrapping it with various log entries.
+func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode SyncMode) error {
+	err := d.synchronise(id, head, td, mode)
+
+	switch err { // compared with ==, so only the bare sentinels (not wrapped ones) take this quiet path
+	case nil, errBusy, errCanceled:
+		return err
+	}
+	if errors.Is(err, errInvalidChain) || errors.Is(err, errBadPeer) || errors.Is(err, errTimeout) ||
+		errors.Is(err, errStallingPeer) || errors.Is(err, errUnsyncedPeer) || errors.Is(err, errEmptyHeaderSet) ||
+		errors.Is(err, errPeersUnavailable) || errors.Is(err, errTooOld) || errors.Is(err, errInvalidAncestor) {
+		log.Warn("Synchronisation failed, dropping peer", "peer", id, "err", err)
+		if d.dropPeer == nil {
+			// The dropPeer method is nil when `--copydb` is used for a local copy.
+			// Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
+			log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", id)
+		} else {
+			d.dropPeer(id)
+		}
+		return err
+	}
+	log.Warn("Synchronisation failed, retrying", "err", err) // any other error: the peer is kept and the error is still returned
+	return err
+}
+
+// synchronise runs one sync cycle against the peer registered under id. Only one cycle may run at a
+// time (errBusy otherwise) and an id missing from the peer set yields errUnknownPeer; no best-peer
+// selection for an empty id is visible in this function. This method is synchronous.
+func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode SyncMode) error {
+	// Mock out the synchronisation if testing
+	if d.synchroniseMock != nil {
+		return d.synchroniseMock(id, hash)
+	}
+	// Make sure only one goroutine is ever allowed past this point at once
+	if !atomic.CompareAndSwapInt32(&d.synchronising, 0, 1) {
+		return errBusy
+	}
+	defer atomic.StoreInt32(&d.synchronising, 0)
+
+	// Post a user notification of the sync (only once per session)
+	if atomic.CompareAndSwapInt32(&d.notified, 0, 1) {
+		log.Info("Block synchronisation started")
+	}
+	// If we are already full syncing, but have a fast-sync bloom filter lying
+	// around, make sure it doesn't use memory any more. This is a special case
+	// when the user attempts to fast sync a new empty network.
+	if mode == FullSync && d.stateBloom != nil {
+		d.stateBloom.Close()
+	}
+	// If snap sync was requested, create the snap scheduler and switch to fast
+	// sync mode. Long term we could drop fast sync or merge the two together,
+	// but until snap becomes prevalent, we should support both. TODO(karalabe).
+	if mode == SnapSync {
+		if !d.snapSync {
+			// Snap sync uses the snapshot namespace to store potentially flakey data until
+			// sync completely heals and finishes. Pause snapshot maintenance in the mean
+			// time to prevent access.
+			if snapshots := d.blockchain.Snapshots(); snapshots != nil { // Only nil in tests
+				snapshots.Disable()
+			}
+			log.Warn("Enabling snapshot sync prototype")
+			d.snapSync = true // never reset in this function, so the Disable call above runs at most once
+		}
+		mode = FastSync // from here on the cycle is handled as a fast sync
+	}
+	// Reset the queue, peer set and wake channels to clean any internal leftover state
+	d.queue.Reset(blockCacheMaxItems, blockCacheInitialItems)
+	d.peers.Reset()
+
+	for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
+		select { // non-blocking receive: removes at most one pending wake signal per channel
+		case <-ch:
+		default:
+		}
+	}
+	for _, ch := range []chan dataPack{d.headerCh, d.bodyCh, d.receiptCh} {
+		for empty := false; !empty; { // keep receiving until the channel has nothing buffered
+			select {
+			case <-ch:
+			default:
+				empty = true
+			}
+		}
+	}
+	for empty := false; !empty; { // same drain for the header processing channel
+		select {
+		case <-d.headerProcCh:
+		default:
+			empty = true
+		}
+	}
+	// Create cancel channel for aborting mid-flight and mark the master peer
+	d.cancelLock.Lock()
+	d.cancelCh = make(chan struct{})
+	d.cancelPeer = id
+	d.cancelLock.Unlock()
+
+	defer d.Cancel() // No matter what, we can't leave the cancel channel open
+
+	// Atomically set the requested sync mode
+	atomic.StoreUint32(&d.mode, uint32(mode))
+
+	// Retrieve the origin peer and initiate the downloading process
+	p := d.peers.Peer(id)
+	if p == nil {
+		return errUnknownPeer
+	}
+	return d.syncWithPeer(p, hash, td)
+}
+
+func (d *Downloader) getMode() SyncMode {
+	return SyncMode(atomic.LoadUint32(&d.mode)) // atomic read matching the atomic.StoreUint32 of d.mode in synchronise
+}
+
+// syncWithPeer starts a block synchronization based on the hash chain from the
+// specified peer and head hash.
+func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.Int) (err error) { + d.mux.Post(StartEvent{}) + defer func() { + // reset on error + if err != nil { + d.mux.Post(FailedEvent{err}) + } else { + latest := d.lightchain.CurrentHeader() + d.mux.Post(DoneEvent{latest}) + } + }() + if p.version < eth.ETH66 { + return fmt.Errorf("%w: advertized %d < required %d", errTooOld, p.version, eth.ETH66) + } + mode := d.getMode() + + log.Debug("Synchronising with the network", "peer", p.id, "eth", p.version, "head", hash, "td", td, "mode", mode) + defer func(start time.Time) { + log.Debug("Synchronisation terminated", "elapsed", common.PrettyDuration(time.Since(start))) + }(time.Now()) + + // Look up the sync boundaries: the common ancestor and the target block + latest, pivot, err := d.fetchHead(p) + if err != nil { + return err + } + if mode == FastSync && pivot == nil { + // If no pivot block was returned, the head is below the min full block + // threshold (i.e. new chain). In that case we won't really fast sync + // anyway, but still need a valid pivot block to avoid some code hitting + // nil panics on an access. 
+ pivot = d.blockchain.CurrentBlock().Header() + } + height := latest.Number.Uint64() + + origin, err := d.findAncestor(p, latest) + if err != nil { + return err + } + d.syncStatsLock.Lock() + if d.syncStatsChainHeight <= origin || d.syncStatsChainOrigin > origin { + d.syncStatsChainOrigin = origin + } + d.syncStatsChainHeight = height + d.syncStatsLock.Unlock() + + // Ensure our origin point is below any fast sync pivot point + if mode == FastSync { + if height <= uint64(fsMinFullBlocks) { + origin = 0 + } else { + pivotNumber := pivot.Number.Uint64() + if pivotNumber <= origin { + origin = pivotNumber - 1 + } + // Write out the pivot into the database so a rollback beyond it will + // reenable fast sync + rawdb.WriteLastPivotNumber(d.stateDB, pivotNumber) + } + } + d.committed = 1 + if mode == FastSync && pivot.Number.Uint64() != 0 { + d.committed = 0 + } + if mode == FastSync { + // Set the ancient data limitation. + // If we are running fast sync, all block data older than ancientLimit will be + // written to the ancient store. More recent data will be written to the active + // database and will wait for the freezer to migrate. + // + // If there is a checkpoint available, then calculate the ancientLimit through + // that. Otherwise calculate the ancient limit through the advertised height + // of the remote peer. + // + // The reason for picking checkpoint first is that a malicious peer can give us + // a fake (very high) height, forcing the ancient limit to also be very high. + // The peer would start to feed us valid blocks until head, resulting in all of + // the blocks might be written into the ancient store. A following mini-reorg + // could cause issues. 
+ if d.checkpoint != 0 && d.checkpoint > fullMaxForkAncestry+1 { + d.ancientLimit = d.checkpoint + } else if height > fullMaxForkAncestry+1 { + d.ancientLimit = height - fullMaxForkAncestry - 1 + } else { + d.ancientLimit = 0 + } + frozen, _ := d.stateDB.Ancients() // Ignore the error here since light client can also hit here. + + // If a part of blockchain data has already been written into active store, + // disable the ancient style insertion explicitly. + if origin >= frozen && frozen != 0 { + d.ancientLimit = 0 + log.Info("Disabling direct-ancient mode", "origin", origin, "ancient", frozen-1) + } else if d.ancientLimit > 0 { + log.Debug("Enabling direct-ancient mode", "ancient", d.ancientLimit) + } + // Rewind the ancient store and blockchain if reorg happens. + if origin+1 < frozen { + if err := d.lightchain.SetHead(origin + 1); err != nil { + return err + } + } + } + // Initiate the sync using a concurrent header and content retrieval algorithm + d.queue.Prepare(origin+1, mode) + if d.syncInitHook != nil { + d.syncInitHook(origin, height) + } + fetchers := []func() error{ + func() error { return d.fetchHeaders(p, origin+1) }, // Headers are always retrieved + func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync + func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync + func() error { return d.processHeaders(origin+1, td) }, + } + if mode == FastSync { + d.pivotLock.Lock() + d.pivotHeader = pivot + d.pivotLock.Unlock() + + fetchers = append(fetchers, func() error { return d.processFastSyncContent() }) + } else if mode == FullSync { + fetchers = append(fetchers, d.processFullSyncContent) + } + return d.spawnSync(fetchers) +} + +// spawnSync runs d.process and all given fetcher functions to completion in +// separate goroutines, returning the first error that appears. 
+func (d *Downloader) spawnSync(fetchers []func() error) error { + errc := make(chan error, len(fetchers)) + d.cancelWg.Add(len(fetchers)) + for _, fn := range fetchers { + fn := fn + go func() { defer d.cancelWg.Done(); errc <- fn() }() + } + // Wait for the first error, then terminate the others. + var err error + for i := 0; i < len(fetchers); i++ { + if i == len(fetchers)-1 { + // Close the queue when all fetchers have exited. + // This will cause the block processor to end when + // it has processed the queue. + d.queue.Close() + } + if err = <-errc; err != nil && err != errCanceled { + break + } + } + d.queue.Close() + d.Cancel() + return err +} + +// cancel aborts all of the operations and resets the queue. However, cancel does +// not wait for the running download goroutines to finish. This method should be +// used when cancelling the downloads from inside the downloader. +func (d *Downloader) cancel() { + // Close the current cancel channel + d.cancelLock.Lock() + defer d.cancelLock.Unlock() + + if d.cancelCh != nil { + select { + case <-d.cancelCh: + // Channel was already closed + default: + close(d.cancelCh) + } + } +} + +// Cancel aborts all of the operations and waits for all download goroutines to +// finish before returning. +func (d *Downloader) Cancel() { + d.cancel() + d.cancelWg.Wait() +} + +// Terminate interrupts the downloader, canceling all pending operations. +// The downloader cannot be reused after calling Terminate. +func (d *Downloader) Terminate() { + // Close the termination channel (make sure double close is allowed) + d.quitLock.Lock() + select { + case <-d.quitCh: + default: + close(d.quitCh) + } + if d.stateBloom != nil { + d.stateBloom.Close() + } + d.quitLock.Unlock() + + // Cancel any pending download requests + d.Cancel() +} + +// fetchHead retrieves the head header and prior pivot block (if available) from +// a remote peer. 
+func (d *Downloader) fetchHead(p *peerConnection) (head *types.Header, pivot *types.Header, err error) { + p.log.Debug("Retrieving remote chain head") + mode := d.getMode() + + // Request the advertised remote head block and wait for the response + latest, _ := p.peer.Head() + fetch := 1 + if mode == FastSync { + fetch = 2 // head + pivot headers + } + go p.peer.RequestHeadersByHash(latest, fetch, fsMinFullBlocks-1, true) + + ttl := d.peers.rates.TargetTimeout() + timeout := time.After(ttl) + for { + select { + case <-d.cancelCh: + return nil, nil, errCanceled + + case packet := <-d.headerCh: + // Discard anything not from the origin peer + if packet.PeerId() != p.id { + log.Debug("Received headers from incorrect peer", "peer", packet.PeerId()) + break + } + // Make sure the peer gave us at least one and at most the requested headers + headers := packet.(*headerPack).headers + if len(headers) == 0 || len(headers) > fetch { + return nil, nil, fmt.Errorf("%w: returned headers %d != requested %d", errBadPeer, len(headers), fetch) + } + // The first header needs to be the head, validate against the checkpoint + // and request. If only 1 header was returned, make sure there's no pivot + // or there was not one requested. + head := headers[0] + if (mode == FastSync || mode == LightSync) && head.Number.Uint64() < d.checkpoint { + return nil, nil, fmt.Errorf("%w: remote head %d below checkpoint %d", errUnsyncedPeer, head.Number, d.checkpoint) + } + if len(headers) == 1 { + if mode == FastSync && head.Number.Uint64() > uint64(fsMinFullBlocks) { + return nil, nil, fmt.Errorf("%w: no pivot included along head header", errBadPeer) + } + p.log.Debug("Remote head identified, no pivot", "number", head.Number, "hash", head.Hash()) + return head, nil, nil + } + // At this point we have 2 headers in total and the first is the + // validated head of the chain. 
Check the pivot number and return, + pivot := headers[1] + if pivot.Number.Uint64() != head.Number.Uint64()-uint64(fsMinFullBlocks) { + return nil, nil, fmt.Errorf("%w: remote pivot %d != requested %d", errInvalidChain, pivot.Number, head.Number.Uint64()-uint64(fsMinFullBlocks)) + } + return head, pivot, nil + + case <-timeout: + p.log.Debug("Waiting for head header timed out", "elapsed", ttl) + return nil, nil, errTimeout + + case <-d.bodyCh: + case <-d.receiptCh: + // Out of bounds delivery, ignore + } + } +} + +// calculateRequestSpan calculates what headers to request from a peer when trying to determine the +// common ancestor. +// It returns parameters to be used for peer.RequestHeadersByNumber: +// from - starting block number +// count - number of headers to request +// skip - number of headers to skip +// and also returns 'max', the last block which is expected to be returned by the remote peers, +// given the (from,count,skip) +func calculateRequestSpan(remoteHeight, localHeight uint64) (int64, int, int, uint64) { + var ( + from int + count int + MaxCount = MaxHeaderFetch / 16 + ) + // requestHead is the highest block that we will ask for. 
If requestHead is not offset, + // the highest block that we will get is 16 blocks back from head, which means we + // will fetch 14 or 15 blocks unnecessarily in the case the height difference + // between us and the peer is 1-2 blocks, which is most common + requestHead := int(remoteHeight) - 1 + if requestHead < 0 { + requestHead = 0 + } + // requestBottom is the lowest block we want included in the query + // Ideally, we want to include the one just below our own head + requestBottom := int(localHeight - 1) + if requestBottom < 0 { + requestBottom = 0 + } + totalSpan := requestHead - requestBottom + span := 1 + totalSpan/MaxCount + if span < 2 { + span = 2 + } + if span > 16 { + span = 16 + } + + count = 1 + totalSpan/span + if count > MaxCount { + count = MaxCount + } + if count < 2 { + count = 2 + } + from = requestHead - (count-1)*span + if from < 0 { + from = 0 + } + max := from + (count-1)*span + return int64(from), count, span - 1, uint64(max) +} + +// findAncestor tries to locate the common ancestor link of the local chain and +// a remote peers blockchain. In the general case when our node was in sync and +// on the correct chain, checking the top N links should already get us a match. +// In the rare scenario when we ended up on a long reorganisation (i.e. none of +// the head links match), we do a binary search to find the common ancestor. 
func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header) (uint64, error) {
	// Figure out the valid ancestor range to prevent rewrite attacks
	var (
		floor        = int64(-1) // -1 means "no lower bound"
		localHeight  uint64
		remoteHeight = remoteHeader.Number.Uint64()
	)
	// The local height is taken from whichever chain head the sync mode advances
	mode := d.getMode()
	switch mode {
	case FullSync:
		localHeight = d.blockchain.CurrentBlock().NumberU64()
	case FastSync:
		localHeight = d.blockchain.CurrentFastBlock().NumberU64()
	default:
		localHeight = d.lightchain.CurrentHeader().Number.Uint64()
	}
	p.log.Debug("Looking for common ancestor", "local", localHeight, "remote", remoteHeight)

	// Recap floor value for binary search
	maxForkAncestry := fullMaxForkAncestry
	if d.getMode() == LightSync {
		maxForkAncestry = lightMaxForkAncestry
	}
	if localHeight >= maxForkAncestry {
		// We're above the max reorg threshold, find the earliest fork point
		floor = int64(localHeight - maxForkAncestry)
	}
	// If we're doing a light sync, ensure the floor doesn't go below the CHT, as
	// all headers before that point will be missing.
	if mode == LightSync {
		// If we don't know the current CHT position, find it
		if d.genesis == 0 {
			// Walk parent links back from the current header until either the
			// floor is reached or a parent is missing locally; the last header
			// seen becomes d.genesis.
			header := d.lightchain.CurrentHeader()
			for header != nil {
				d.genesis = header.Number.Uint64()
				if floor >= int64(d.genesis)-1 {
					break
				}
				header = d.lightchain.GetHeaderByHash(header.ParentHash)
			}
		}
		// We already know the "genesis" block number, cap floor to that
		if floor < int64(d.genesis)-1 {
			floor = int64(d.genesis) - 1
		}
	}

	// First attempt: a single spanned request near the chain heads
	ancestor, err := d.findAncestorSpanSearch(p, mode, remoteHeight, localHeight, floor)
	if err == nil {
		return ancestor, nil
	}
	// The returned error was not nil.
	// If the error returned does not reflect that a common ancestor was not found, return it.
	// If the error reflects that a common ancestor was not found, continue to binary search,
	// where the error value will be reassigned.
	if !errors.Is(err, errNoAncestorFound) {
		return 0, err
	}

	ancestor, err = d.findAncestorBinarySearch(p, mode, remoteHeight, floor)
	if err != nil {
		return 0, err
	}
	return ancestor, nil
}

// findAncestorSpanSearch looks for the common ancestor with a single header
// request whose shape comes from calculateRequestSpan. The reply's numbering is
// validated against the request, then the headers are scanned from highest to
// lowest for the first one that is already known locally (block, fast block or
// header, depending on mode).
//
// It returns errNoAncestorFound if none of the returned headers is known,
// errInvalidAncestor if the match is at or below floor, errEmptyHeaderSet on an
// empty reply, errTimeout if the origin peer does not answer in time and
// errCanceled if the sync is cancelled meanwhile.
func (d *Downloader) findAncestorSpanSearch(p *peerConnection, mode SyncMode, remoteHeight, localHeight uint64, floor int64) (commonAncestor uint64, err error) {
	from, count, skip, max := calculateRequestSpan(remoteHeight, localHeight)

	p.log.Trace("Span searching for common ancestor", "count", count, "from", from, "skip", skip)
	go p.peer.RequestHeadersByNumber(uint64(from), count, skip, false)

	// Wait for the remote response to the head fetch
	number, hash := uint64(0), common.Hash{}

	ttl := d.peers.rates.TargetTimeout()
	timeout := time.After(ttl)

	for finished := false; !finished; {
		select {
		case <-d.cancelCh:
			return 0, errCanceled

		case packet := <-d.headerCh:
			// Discard anything not from the origin peer
			if packet.PeerId() != p.id {
				log.Debug("Received headers from incorrect peer", "peer", packet.PeerId())
				break // leaves the select only; the loop keeps waiting
			}
			// Make sure the peer actually gave something valid
			headers := packet.(*headerPack).headers
			if len(headers) == 0 {
				p.log.Warn("Empty head header set")
				return 0, errEmptyHeaderSet
			}
			// Make sure the peer's reply conforms to the request
			for i, header := range headers {
				expectNumber := from + int64(i)*int64(skip+1)
				if number := header.Number.Int64(); number != expectNumber {
					p.log.Warn("Head headers broke chain ordering", "index", i, "requested", expectNumber, "received", number)
					return 0, fmt.Errorf("%w: %v", errInvalidChain, errors.New("head headers broke chain ordering"))
				}
			}
			// Check if a common ancestor was found
			finished = true
			for i := len(headers) - 1; i >= 0; i-- {
				// Skip any headers that underflow/overflow our requested set
				if headers[i].Number.Int64() < from || headers[i].Number.Uint64() > max {
					continue
				}
				// Otherwise check if we already know the header or not
				h := headers[i].Hash()
				n := headers[i].Number.Uint64()

				var known bool
				switch mode {
				case FullSync:
					known = d.blockchain.HasBlock(h, n)
				case FastSync:
					known = d.blockchain.HasFastBlock(h, n)
				default:
					known = d.lightchain.HasHeader(h, n)
				}
				if known {
					// Highest known header wins, stop scanning
					number, hash = n, h
					break
				}
			}

		case <-timeout:
			p.log.Debug("Waiting for head header timed out", "elapsed", ttl)
			return 0, errTimeout

		case <-d.bodyCh:
		case <-d.receiptCh:
			// Out of bounds delivery, ignore
		}
	}
	// If the head fetch already found an ancestor, return
	if hash != (common.Hash{}) {
		if int64(number) <= floor {
			p.log.Warn("Ancestor below allowance", "number", number, "hash", hash, "allowance", floor)
			return 0, errInvalidAncestor
		}
		p.log.Debug("Found common ancestor", "number", number, "hash", hash)
		return number, nil
	}
	return 0, errNoAncestorFound
}

// findAncestorBinarySearch locates the common ancestor by bisecting the block
// number interval (start, end), where start is floor (or 0 if floor is not
// positive) and end is remoteHeight. Each step requests the single header at
// the midpoint from the origin peer: if it is known locally the lower bound
// moves up to it, otherwise the upper bound moves down.
//
// It returns errInvalidAncestor if the final lower bound is at or below floor,
// errBadPeer if a reply does not contain exactly one header or the stored
// header's number differs from the requested one, errTimeout if a reply does
// not arrive in time and errCanceled if the sync is cancelled meanwhile.
func (d *Downloader) findAncestorBinarySearch(p *peerConnection, mode SyncMode, remoteHeight uint64, floor int64) (commonAncestor uint64, err error) {
	hash := common.Hash{}

	// Ancestor not found, we need to binary search over our chain
	start, end := uint64(0), remoteHeight
	if floor > 0 {
		start = uint64(floor)
	}
	p.log.Trace("Binary searching for common ancestor", "start", start, "end", end)

	for start+1 < end {
		// Split our chain interval in two, and request the hash to cross check
		check := (start + end) / 2

		// Every probe gets its own fresh timeout
		ttl := d.peers.rates.TargetTimeout()
		timeout := time.After(ttl)

		go p.peer.RequestHeadersByNumber(check, 1, 0, false)

		// Wait until a reply arrives to this request
		for arrived := false; !arrived; {
			select {
			case <-d.cancelCh:
				return 0, errCanceled

			case packet := <-d.headerCh:
				// Discard anything not from the origin peer
				if packet.PeerId() != p.id {
					log.Debug("Received headers from incorrect peer", "peer", packet.PeerId())
					break // leaves the select only; keep waiting for the origin peer
				}
				// Make sure the peer actually gave something valid
				headers := packet.(*headerPack).headers
				if len(headers) != 1 {
					// Note: this branch is also taken for an empty reply
					p.log.Warn("Multiple headers for single request", "headers", len(headers))
					return 0, fmt.Errorf("%w: multiple headers (%d) for single request", errBadPeer, len(headers))
				}
				arrived = true

				// Modify the search interval based on the response
				h := headers[0].Hash()
				n := headers[0].Number.Uint64()

				var known bool
				switch mode {
				case FullSync:
					known = d.blockchain.HasBlock(h, n)
				case FastSync:
					known = d.blockchain.HasFastBlock(h, n)
				default:
					known = d.lightchain.HasHeader(h, n)
				}
				if !known {
					// Unknown header: the ancestor is below the midpoint
					end = check
					break
				}
				header := d.lightchain.GetHeaderByHash(h) // Independent of sync mode, header surely exists
				if header.Number.Uint64() != check {
					p.log.Warn("Received non requested header", "number", header.Number, "hash", header.Hash(), "request", check)
					return 0, fmt.Errorf("%w: non-requested header (%d)", errBadPeer, header.Number)
				}
				// Known header: the ancestor is at or above the midpoint
				start = check
				hash = h

			case <-timeout:
				p.log.Debug("Waiting for search header timed out", "elapsed", ttl)
				return 0, errTimeout

			case <-d.bodyCh:
			case <-d.receiptCh:
				// Out of bounds delivery, ignore
			}
		}
	}
	// Ensure valid ancestry and return
	if int64(start) <= floor {
		p.log.Warn("Ancestor below allowance", "number", start, "hash", hash, "allowance", floor)
		return 0, errInvalidAncestor
	}
	p.log.Debug("Found common ancestor", "number", start, "hash", hash)
	return start, nil
}

// fetchHeaders keeps retrieving headers concurrently from the number
// requested, until no more are returned, potentially throttling on the way. To
// facilitate concurrency but still protect against malicious nodes sending bad
// headers, we construct a header chain skeleton using the "origin" peer we are
// syncing with, and fill in the missing headers using anyone else. Headers from
// other peers are only accepted if they map cleanly to the skeleton. If no one
// can fill in the skeleton - not even the origin peer - it's assumed invalid and
// the origin is dropped.
func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
	p.log.Debug("Directing header downloads", "origin", from)
	defer p.log.Debug("Header download terminated")

	// Create a timeout timer, and the associated header fetcher
	skeleton := true            // Skeleton assembly phase or finishing up
	pivoting := false           // Whether the next request is pivot verification
	request := time.Now()       // time of the last skeleton fetch request
	timeout := time.NewTimer(0) // timer to dump a non-responsive active peer
	<-timeout.C                 // timeout channel should be initially empty
	defer timeout.Stop()

	var ttl time.Duration
	// getHeaders re-arms the timeout and fires the next header request at the
	// origin peer: a sparse skeleton request while in the skeleton phase, a
	// contiguous batch otherwise.
	getHeaders := func(from uint64) {
		request = time.Now()

		ttl = d.peers.rates.TargetTimeout()
		timeout.Reset(ttl)

		if skeleton {
			p.log.Trace("Fetching skeleton headers", "count", MaxHeaderFetch, "from", from)
			go p.peer.RequestHeadersByNumber(from+uint64(MaxHeaderFetch)-1, MaxSkeletonSize, MaxHeaderFetch-1, false)
		} else {
			p.log.Trace("Fetching full headers", "count", MaxHeaderFetch, "from", from)
			go p.peer.RequestHeadersByNumber(from, MaxHeaderFetch, 0, false)
		}
	}
	// getNextPivot re-arms the timeout and asks the origin peer for two headers
	// above the current pivot: the candidate next pivot and a confirmer. The
	// reply is handled in the `pivoting` branch below.
	getNextPivot := func() {
		pivoting = true
		request = time.Now()

		ttl = d.peers.rates.TargetTimeout()
		timeout.Reset(ttl)

		d.pivotLock.RLock()
		pivot := d.pivotHeader.Number.Uint64()
		d.pivotLock.RUnlock()

		p.log.Trace("Fetching next pivot header", "number", pivot+uint64(fsMinFullBlocks))
		go p.peer.RequestHeadersByNumber(pivot+uint64(fsMinFullBlocks), 2, fsMinFullBlocks-9, false) // move +64 when it's 2x64-8 deep
	}
	// Start pulling the header chain skeleton until all is done
	ancestor := from
	getHeaders(from)

	mode := d.getMode()
	for {
		select {
		case <-d.cancelCh:
			return errCanceled

		case packet := <-d.headerCh:
			// Make sure the active peer is giving us the skeleton headers
			if packet.PeerId() != p.id {
				log.Debug("Received skeleton from incorrect peer", "peer", packet.PeerId())
				break // leaves the select only; the timeout stays armed
			}
			headerReqTimer.UpdateSince(request)
			timeout.Stop()

			// If the pivot is being checked, move if it became stale and run the real retrieval
			var pivot uint64

			d.pivotLock.RLock()
			if d.pivotHeader != nil {
				pivot = d.pivotHeader.Number.Uint64()
			}
			d.pivotLock.RUnlock()

			if pivoting {
				// Only a reply with exactly two headers moves the pivot; any
				// other reply leaves it in place and the regular retrieval resumes.
				if packet.Items() == 2 {
					// Retrieve the headers and do some sanity checks, just in case
					headers := packet.(*headerPack).headers

					if have, want := headers[0].Number.Uint64(), pivot+uint64(fsMinFullBlocks); have != want {
						log.Warn("Peer sent invalid next pivot", "have", have, "want", want)
						return fmt.Errorf("%w: next pivot number %d != requested %d", errInvalidChain, have, want)
					}
					if have, want := headers[1].Number.Uint64(), pivot+2*uint64(fsMinFullBlocks)-8; have != want {
						log.Warn("Peer sent invalid pivot confirmer", "have", have, "want", want)
						return fmt.Errorf("%w: next pivot confirmer number %d != requested %d", errInvalidChain, have, want)
					}
					log.Warn("Pivot seemingly stale, moving", "old", pivot, "new", headers[0].Number)
					pivot = headers[0].Number.Uint64()

					d.pivotLock.Lock()
					d.pivotHeader = headers[0]
					d.pivotLock.Unlock()

					// Write out the pivot into the database so a rollback beyond
					// it will reenable fast sync and update the state root that
					// the state syncer will be downloading.
					rawdb.WriteLastPivotNumber(d.stateDB, pivot)
				}
				pivoting = false
				getHeaders(from)
				continue
			}
			// If the skeleton's finished, pull any remaining head headers directly from the origin
			if skeleton && packet.Items() == 0 {
				skeleton = false
				getHeaders(from)
				continue
			}
			// If no more headers are inbound, notify the content fetchers and return
			if packet.Items() == 0 {
				// Don't abort header fetches while the pivot is downloading
				if atomic.LoadInt32(&d.committed) == 0 && pivot <= from {
					p.log.Debug("No headers, waiting for pivot commit")
					select {
					case <-time.After(fsHeaderContCheck):
						getHeaders(from)
						continue
					case <-d.cancelCh:
						return errCanceled
					}
				}
				// Pivot done (or not in fast sync) and no more headers, terminate the process
				p.log.Debug("No more headers available")
				select {
				case d.headerProcCh <- nil: // a nil batch is the end-of-headers signal
					return nil
				case <-d.cancelCh:
					return errCanceled
				}
			}
			headers := packet.(*headerPack).headers

			// If we received a skeleton batch, resolve internals concurrently
			if skeleton {
				filled, proced, err := d.fillHeaderSkeleton(from, headers)
				if err != nil {
					p.log.Debug("Skeleton chain invalid", "err", err)
					return fmt.Errorf("%w: %v", errInvalidChain, err)
				}
				// The first `proced` headers were already forwarded by the fill
				headers = filled[proced:]
				from += uint64(proced)
			} else {
				// If we're closing in on the chain head, but haven't yet reached it, delay
				// the last few headers so mini reorgs on the head don't cause invalid hash
				// chain errors.
				if n := len(headers); n > 0 {
					// Retrieve the current head we're at
					var head uint64
					if mode == LightSync {
						head = d.lightchain.CurrentHeader().Number.Uint64()
					} else {
						head = d.blockchain.CurrentFastBlock().NumberU64()
						if full := d.blockchain.CurrentBlock().NumberU64(); head < full {
							head = full
						}
					}
					// If the head is below the common ancestor, we're actually deduplicating
					// already existing chain segments, so use the ancestor as the fake head.
					// Otherwise we might end up delaying header deliveries pointlessly.
					if head < ancestor {
						head = ancestor
					}
					// If the head is way older than this batch, delay the last few headers
					if head+uint64(reorgProtThreshold) < headers[n-1].Number.Uint64() {
						delay := reorgProtHeaderDelay
						if delay > n {
							delay = n // never cut more headers than were delivered
						}
						headers = headers[:n-delay]
					}
				}
			}
			// Insert all the new headers and fetch the next batch
			if len(headers) > 0 {
				p.log.Trace("Scheduling new headers", "count", len(headers), "from", from)
				select {
				case d.headerProcCh <- headers:
				case <-d.cancelCh:
					return errCanceled
				}
				from += uint64(len(headers))

				// If we're still skeleton filling fast sync, check pivot staleness
				// before continuing to the next skeleton filling
				if skeleton && pivot > 0 {
					getNextPivot()
				} else {
					getHeaders(from)
				}
			} else {
				// No headers delivered, or all of them being delayed, sleep a bit and retry
				p.log.Trace("All headers delayed, waiting")
				select {
				case <-time.After(fsHeaderContCheck):
					getHeaders(from)
					continue
				case <-d.cancelCh:
					return errCanceled
				}
			}

		case <-timeout.C:
			if d.dropPeer == nil {
				// The dropPeer method is nil when `--copydb` is used for a local copy.
				// Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
				p.log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", p.id)
				break // leaves the select only; the timer is not re-armed here
			}
			// Header retrieval timed out, consider the peer bad and drop
			p.log.Debug("Header request timed out", "elapsed", ttl)
			headerTimeoutMeter.Mark(1)
			d.dropPeer(p.id)

			// Finish the sync gracefully instead of dumping the gathered data though
			for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
				select {
				case ch <- false:
				case <-d.cancelCh:
				}
			}
			select {
			case d.headerProcCh <- nil:
			case <-d.cancelCh:
			}
			return fmt.Errorf("%w: header request timed out", errBadPeer)
		}
	}
}

// fillHeaderSkeleton concurrently retrieves headers from all our available peers
// and maps them to the provided skeleton header chain.
//
// Any partial results from the beginning of the skeleton is (if possible) forwarded
// immediately to the header processor to keep the rest of the pipeline full even
// in the case of header stalls.
//
// The method returns the entire filled skeleton and also the number of headers
// already forwarded for processing.
+func (d *Downloader) fillHeaderSkeleton(from uint64, skeleton []*types.Header) ([]*types.Header, int, error) { + log.Debug("Filling up skeleton", "from", from) + d.queue.ScheduleSkeleton(from, skeleton) + + var ( + deliver = func(packet dataPack) (int, error) { + pack := packet.(*headerPack) + return d.queue.DeliverHeaders(pack.peerID, pack.headers, d.headerProcCh) + } + expire = func() map[string]int { return d.queue.ExpireHeaders(d.peers.rates.TargetTimeout()) } + reserve = func(p *peerConnection, count int) (*fetchRequest, bool, bool) { + return d.queue.ReserveHeaders(p, count), false, false + } + fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchHeaders(req.From, MaxHeaderFetch) } + capacity = func(p *peerConnection) int { return p.HeaderCapacity(d.peers.rates.TargetRoundTrip()) } + setIdle = func(p *peerConnection, accepted int, deliveryTime time.Time) { + p.SetHeadersIdle(accepted, deliveryTime) + } + ) + err := d.fetchParts(d.headerCh, deliver, d.queue.headerContCh, expire, + d.queue.PendingHeaders, d.queue.InFlightHeaders, reserve, + nil, fetch, d.queue.CancelHeaders, capacity, d.peers.HeaderIdlePeers, setIdle, "headers") + + log.Debug("Skeleton fill terminated", "err", err) + + filled, proced := d.queue.RetrieveHeaders() + return filled, proced, err +} + +// fetchBodies iteratively downloads the scheduled block bodies, taking any +// available peers, reserving a chunk of blocks for each, waiting for delivery +// and also periodically checking for timeouts. 
+func (d *Downloader) fetchBodies(from uint64) error { + log.Debug("Downloading block bodies", "origin", from) + + var ( + deliver = func(packet dataPack) (int, error) { + pack := packet.(*bodyPack) + return d.queue.DeliverBodies(pack.peerID, pack.transactions, pack.uncles) + } + expire = func() map[string]int { return d.queue.ExpireBodies(d.peers.rates.TargetTimeout()) } + fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchBodies(req) } + capacity = func(p *peerConnection) int { return p.BlockCapacity(d.peers.rates.TargetRoundTrip()) } + setIdle = func(p *peerConnection, accepted int, deliveryTime time.Time) { p.SetBodiesIdle(accepted, deliveryTime) } + ) + err := d.fetchParts(d.bodyCh, deliver, d.bodyWakeCh, expire, + d.queue.PendingBlocks, d.queue.InFlightBlocks, d.queue.ReserveBodies, + d.bodyFetchHook, fetch, d.queue.CancelBodies, capacity, d.peers.BodyIdlePeers, setIdle, "bodies") + + log.Debug("Block body download terminated", "err", err) + return err +} + +// fetchReceipts iteratively downloads the scheduled block receipts, taking any +// available peers, reserving a chunk of receipts for each, waiting for delivery +// and also periodically checking for timeouts. 
+func (d *Downloader) fetchReceipts(from uint64) error { + log.Debug("Downloading transaction receipts", "origin", from) + + var ( + deliver = func(packet dataPack) (int, error) { + pack := packet.(*receiptPack) + return d.queue.DeliverReceipts(pack.peerID, pack.receipts) + } + expire = func() map[string]int { return d.queue.ExpireReceipts(d.peers.rates.TargetTimeout()) } + fetch = func(p *peerConnection, req *fetchRequest) error { return p.FetchReceipts(req) } + capacity = func(p *peerConnection) int { return p.ReceiptCapacity(d.peers.rates.TargetRoundTrip()) } + setIdle = func(p *peerConnection, accepted int, deliveryTime time.Time) { + p.SetReceiptsIdle(accepted, deliveryTime) + } + ) + err := d.fetchParts(d.receiptCh, deliver, d.receiptWakeCh, expire, + d.queue.PendingReceipts, d.queue.InFlightReceipts, d.queue.ReserveReceipts, + d.receiptFetchHook, fetch, d.queue.CancelReceipts, capacity, d.peers.ReceiptIdlePeers, setIdle, "receipts") + + log.Debug("Transaction receipt download terminated", "err", err) + return err +} + +// fetchParts iteratively downloads scheduled block parts, taking any available +// peers, reserving a chunk of fetch requests for each, waiting for delivery and +// also periodically checking for timeouts. +// +// As the scheduling/timeout logic mostly is the same for all downloaded data +// types, this method is used by each for data gathering and is instrumented with +// various callbacks to handle the slight differences between processing them. 
+//
+// The instrumentation parameters:
+//  - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers)
+//  - deliver:    processing callback to deliver data packets into type specific download queues (usually within `queue`)
+//  - wakeCh:     notification channel for waking the fetcher when new tasks are available (or sync completed)
+//  - expire:     task callback method to abort requests that took too long and return the faulty peers (traffic shaping)
+//  - pending:    task callback for the number of requests still needing download (detect completion/non-completability)
+//  - inFlight:   task callback for the number of in-progress requests (wait for all active downloads to finish)
+//  - reserve:    task callback to reserve new download tasks to a particular peer; its two boolean results report
+//                whether progress was made and whether the processing queue is full and throttling should kick in
+//  - fetchHook:  tester callback to notify of new tasks being initiated (allows testing the scheduling logic)
+//  - fetch:      network callback to actually send a particular download request to a physical remote peer
+//  - cancel:     task callback to abort an in-flight download request and allow rescheduling it (in case of lost
+//                peer); note that the body below never calls it
+//  - capacity:   network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping)
+//  - idle:       network callback to retrieve the currently (type specific) idle peers that can be assigned tasks
+//  - setIdle:    network callback to set a peer back to idle and update its estimated capacity (traffic shaping)
+//  - kind:       textual label of the type being downloaded to display in log messages
+//
+// The method returns nil once nothing is pending or in flight and the header
+// fetcher signalled completion through wakeCh, errCanceled when the sync is
+// cancelled, errNoPeers/errPeersUnavailable when no peer can serve the pending
+// tasks, errTimeout when the master peer stalls, or the delivery error itself
+// if it wraps errInvalidChain.
+func (d *Downloader) fetchParts(deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool,
+	expire func() map[string]int, pending func() int, inFlight func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, bool),
+	fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int,
+	idle func() ([]*peerConnection, int), setIdle func(*peerConnection, int, time.Time), kind string) error {
+
+	// Create a ticker to detect expired retrieval tasks
+	ticker := time.NewTicker(100 * time.Millisecond)
+	defer ticker.Stop()
+
+	// Single-slot channel used to coalesce "something changed" notifications:
+	// every event source does a non-blocking send, so at most one update is queued.
+	update := make(chan struct{}, 1)
+
+	// Prepare the queue and fetch block parts until the block header fetcher's done
+	finished := false
+	for {
+		select {
+		case <-d.cancelCh:
+			return errCanceled
+
+		case packet := <-deliveryCh:
+			deliveryTime := time.Now()
+			// If the peer was previously banned and failed to deliver its pack
+			// in a reasonable time frame, ignore its message.
+			if peer := d.peers.Peer(packet.PeerId()); peer != nil {
+				// Deliver the received chunk of data and check chain validity
+				accepted, err := deliver(packet)
+				if errors.Is(err, errInvalidChain) {
+					return err
+				}
+				// Unless a peer delivered something completely else than requested (usually
+				// caused by a timed out request which came through in the end), set it to
+				// idle. If the delivery's stale, the peer should have already been idled.
+				if !errors.Is(err, errStaleDelivery) {
+					setIdle(peer, accepted, deliveryTime)
+				}
+				// Issue a log to the user to see what's going on
+				switch {
+				case err == nil && packet.Items() == 0:
+					peer.log.Trace("Requested data not delivered", "type", kind)
+				case err == nil:
+					peer.log.Trace("Delivered new batch of data", "type", kind, "count", packet.Stats())
+				default:
+					peer.log.Debug("Failed to deliver retrieved data", "type", kind, "err", err)
+				}
+			}
+			// Blocks assembled, try to update the progress
+			select {
+			case update <- struct{}{}:
+			default:
+			}
+
+		case cont := <-wakeCh:
+			// The header fetcher sent a continuation flag, check if it's done
+			if !cont {
+				finished = true
+			}
+			// Headers arrive, try to update the progress
+			select {
+			case update <- struct{}{}:
+			default:
+			}
+
+		case <-ticker.C:
+			// Sanity check update the progress
+			select {
+			case update <- struct{}{}:
+			default:
+			}
+
+		case <-update:
+			// Short circuit if we lost all our peers
+			if d.peers.Len() == 0 {
+				return errNoPeers
+			}
+			// Check for fetch request timeouts and demote the responsible peers
+			for pid, fails := range expire() {
+				if peer := d.peers.Peer(pid); peer != nil {
+					// If a lot of retrieval elements expired, we might have overestimated the remote peer or perhaps
+					// ourselves. Only reset to minimal throughput but don't drop just yet. If even the minimal request
+					// times out, we need to get rid of the peer for the sake of the sync.
+					//
+					// The reason the minimum threshold is 2 is because the downloader tries to estimate the bandwidth
+					// and latency of a peer separately, which requires pushing the measured capacity a bit and seeing
+					// how response times react, so it always requests one more than the minimum (i.e. min 2).
+					if fails > 2 {
+						peer.log.Trace("Data delivery timed out", "type", kind)
+						setIdle(peer, 0, time.Now())
+					} else {
+						peer.log.Debug("Stalling delivery, dropping", "type", kind)
+
+						if d.dropPeer == nil {
+							// The dropPeer method is nil when `--copydb` is used for a local copy.
+							// Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
+							peer.log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", pid)
+						} else {
+							d.dropPeer(pid)
+
+							// If this peer was the master peer, abort sync immediately
+							d.cancelLock.RLock()
+							master := pid == d.cancelPeer
+							d.cancelLock.RUnlock()
+
+							if master {
+								d.cancel()
+								return errTimeout
+							}
+						}
+					}
+				}
+			}
+			// If there's nothing more to fetch, wait or terminate
+			if pending() == 0 {
+				if !inFlight() && finished {
+					log.Debug("Data fetching completed", "type", kind)
+					return nil
+				}
+				// Leaves the select statement only; the outer loop keeps waiting for events
+				break
+			}
+			// Send a download request to all idle peers, until throttled
+			progressed, throttled, running := false, false, inFlight()
+			idles, total := idle()
+			pendCount := pending()
+			for _, peer := range idles {
+				// Short circuit if throttling activated
+				if throttled {
+					break
+				}
+				// Short circuit if there is no more available task.
+				if pendCount = pending(); pendCount == 0 {
+					break
+				}
+				// Reserve a chunk of fetches for a peer. A nil can mean either that
+				// no more headers are available, or that the peer is known not to
+				// have them.
+				request, progress, throttle := reserve(peer, capacity(peer))
+				if progress {
+					progressed = true
+				}
+				if throttle {
+					throttled = true
+					throttleCounter.Inc(1)
+				}
+				if request == nil {
+					continue
+				}
+				if request.From > 0 {
+					peer.log.Trace("Requesting new batch of data", "type", kind, "from", request.From)
+				} else {
+					peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
+				}
+				// Fetch the chunk and make sure any errors return the hashes to the queue
+				if fetchHook != nil {
+					fetchHook(request.Headers)
+				}
+				if err := fetch(peer, request); err != nil {
+					// Although we could try and make an attempt to fix this, this error really
+					// means that we've double allocated a fetch task to a peer. If that is the
+					// case, the internal state of the downloader and the queue is very wrong so
+					// better hard crash and note the error instead of silently accumulating into
+					// a much bigger issue.
+					panic(fmt.Sprintf("%v: %s fetch assignment failed", peer, kind))
+				}
+				running = true
+			}
+			// Make sure that we have peers available for fetching. If all peers have been tried
+			// and all failed throw an error
+			if !progressed && !throttled && !running && len(idles) == total && pendCount > 0 {
+				return errPeersUnavailable
+			}
+		}
+	}
+}
+
+// processHeaders takes batches of retrieved headers from an input channel and
+// keeps processing and scheduling them into the header chain and downloader's
+// queue until the stream ends or a failure occurs.
+func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
+	// Keep a count of uncertain headers to roll back
+	var (
+		rollback    uint64 // Zero means no rollback (fine as you can't unroll the genesis)
+		rollbackErr error  // Reason recorded for the rollback, only used in the log line below
+		mode        = d.getMode()
+	)
+	// On any exit with a non-zero rollback marker, rewind the light chain to the
+	// parent of the first uncertain header and log the before/after heads.
+	defer func() {
+		if rollback > 0 {
+			lastHeader, lastFastBlock, lastBlock := d.lightchain.CurrentHeader().Number, common.Big0, common.Big0
+			if mode != LightSync {
+				lastFastBlock = d.blockchain.CurrentFastBlock().Number()
+				lastBlock = d.blockchain.CurrentBlock().Number()
+			}
+			if err := d.lightchain.SetHead(rollback - 1); err != nil { // -1 to target the parent of the first uncertain block
+				// We're already unwinding the stack, only print the error to make it more visible
+				log.Error("Failed to roll back chain segment", "head", rollback-1, "err", err)
+			}
+			curFastBlock, curBlock := common.Big0, common.Big0
+			if mode != LightSync {
+				curFastBlock = d.blockchain.CurrentFastBlock().Number()
+				curBlock = d.blockchain.CurrentBlock().Number()
+			}
+			log.Warn("Rolled back chain segment",
+				"header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number),
+				"fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock),
+				"block", fmt.Sprintf("%d->%d", lastBlock, curBlock), "reason", rollbackErr)
+		}
+	}()
+	// Wait for batches of headers to process
+	gotHeaders := false
+
+	for {
+		select {
+		case <-d.cancelCh:
+			rollbackErr = errCanceled
+			return errCanceled
+
+		case headers := <-d.headerProcCh:
+			// Terminate header processing if we synced up
+			if len(headers) == 0 {
+				// Notify everyone that headers are fully processed
+				for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
+					select {
+					case ch <- false:
+					case <-d.cancelCh:
+					}
+				}
+				// If no headers were retrieved at all, the peer violated its TD promise that it had a
+				// better chain compared to ours. The only exception is if its promised blocks were
+				// already imported by other means (e.g. fetcher):
+				//
+				// R <remote peer>, L <local node>: Both at block 10
+				// R: Mine block 11, and propagate it to L
+				// L: Queue block 11 for import
+				// L: Notice that R's head and TD increased compared to ours, start sync
+				// L: Import of block 11 finishes
+				// L: Sync begins, and finds common ancestor at 11
+				// L: Request new headers up from 11 (R's TD was higher, it must have something)
+				// R: Nothing to give
+				if mode != LightSync {
+					head := d.blockchain.CurrentBlock()
+					if !gotHeaders && td.Cmp(d.blockchain.GetTd(head.Hash(), head.NumberU64())) > 0 {
+						return errStallingPeer
+					}
+				}
+				// If fast or light syncing, ensure promised headers are indeed delivered. This is
+				// needed to detect scenarios where an attacker feeds a bad pivot and then bails out
+				// of delivering the post-pivot blocks that would flag the invalid content.
+				//
+				// This check cannot be executed "as is" for full imports, since blocks may still be
+				// queued for processing when the header download completes. However, as long as the
+				// peer gave us something useful, we're already happy/progressed (above check).
+				if mode == FastSync || mode == LightSync {
+					head := d.lightchain.CurrentHeader()
+					if td.Cmp(d.lightchain.GetTd(head.Hash(), head.Number.Uint64())) > 0 {
+						return errStallingPeer
+					}
+				}
+				// Disable any rollback and return
+				rollback = 0
+				return nil
+			}
+			// Otherwise split the chunk of headers into batches and process them
+			gotHeaders = true
+			for len(headers) > 0 {
+				// Terminate if something failed in between processing chunks
+				select {
+				case <-d.cancelCh:
+					rollbackErr = errCanceled
+					return errCanceled
+				default:
+				}
+				// Select the next chunk of headers to import
+				limit := maxHeadersProcess
+				if limit > len(headers) {
+					limit = len(headers)
+				}
+				chunk := headers[:limit]
+
+				// In case of header only syncing, validate the chunk immediately
+				if mode == FastSync || mode == LightSync {
+					// If we're importing pure headers, verify based on their recentness
+					var pivot uint64
+
+					d.pivotLock.RLock()
+					if d.pivotHeader != nil {
+						pivot = d.pivotHeader.Number.Uint64()
+					}
+					d.pivotLock.RUnlock()
+
+					// Headers close to (or past) the pivot are verified one by one
+					frequency := fsHeaderCheckFrequency
+					if chunk[len(chunk)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot {
+						frequency = 1
+					}
+					if n, err := d.lightchain.InsertHeaderChain(chunk, frequency); err != nil {
+						rollbackErr = err
+
+						// If some headers were inserted, track them as uncertain
+						if (mode == FastSync || frequency > 1) && n > 0 && rollback == 0 {
+							rollback = chunk[0].Number.Uint64()
+						}
+						log.Warn("Invalid header encountered", "number", chunk[n].Number, "hash", chunk[n].Hash(), "parent", chunk[n].ParentHash, "err", err)
+						return fmt.Errorf("%w: %v", errInvalidChain, err)
+					}
+					// All verifications passed, track all headers within the allotted limits
+					if mode == FastSync {
+						head := chunk[len(chunk)-1].Number.Uint64()
+						if head-rollback > uint64(fsHeaderSafetyNet) {
+							rollback = head - uint64(fsHeaderSafetyNet)
+						} else {
+							rollback = 1
+						}
+					}
+				}
+				// Unless we're doing light chains, schedule the headers for associated content retrieval
+				if mode == FullSync || mode == FastSync {
+					// If we've reached the allowed number of pending headers, stall a bit
+					for d.queue.PendingBlocks() >= maxQueuedHeaders || d.queue.PendingReceipts() >= maxQueuedHeaders {
+						select {
+						case <-d.cancelCh:
+							rollbackErr = errCanceled
+							return errCanceled
+						case <-time.After(time.Second):
+						}
+					}
+					// Otherwise insert the headers for content retrieval
+					inserts := d.queue.Schedule(chunk, origin)
+					if len(inserts) != len(chunk) {
+						rollbackErr = fmt.Errorf("stale headers: len inserts %v len(chunk) %v", len(inserts), len(chunk))
+						return fmt.Errorf("%w: stale headers", errBadPeer)
+					}
+				}
+				headers = headers[limit:]
+				origin += uint64(limit)
+			}
+			// Update the highest block number we know if a higher one is found.
+			d.syncStatsLock.Lock()
+			if d.syncStatsChainHeight < origin {
+				d.syncStatsChainHeight = origin - 1
+			}
+			d.syncStatsLock.Unlock()
+
+			// Signal the content downloaders of the availability of new tasks
+			for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh} {
+				select {
+				case ch <- true:
+				default:
+				}
+			}
+		}
+	}
+}
+
+// processFullSyncContent takes fetch results from the queue and imports them into the chain.
+func (d *Downloader) processFullSyncContent() error {
+	// Keep draining the queue (blocking for each batch) until it yields nothing.
+	for batch := d.queue.Results(true); len(batch) > 0; batch = d.queue.Results(true) {
+		if hook := d.chainInsertHook; hook != nil {
+			hook(batch)
+		}
+		if err := d.importBlockResults(batch); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// importBlockResults assembles full blocks out of the fetch results and inserts
+// them into the blockchain. An empty batch is a no-op, a closed quit channel
+// aborts with errCancelContentProcessing, and any insertion failure is reported
+// wrapped in errInvalidChain.
+func (d *Downloader) importBlockResults(results []*fetchResult) error {
+	count := len(results)
+	if count == 0 {
+		return nil
+	}
+	// Check for any early termination requests
+	select {
+	case <-d.quitCh:
+		return errCancelContentProcessing
+	default:
+	}
+	// Log the boundaries of the batch about to be imported
+	head, tail := results[0].Header, results[count-1].Header
+	log.Debug("Inserting downloaded chain", "items", count,
+		"firstnum", head.Number, "firsthash", head.Hash(),
+		"lastnum", tail.Number, "lasthash", tail.Hash(),
+	)
+	blocks := make([]*types.Block, 0, count)
+	for _, res := range results {
+		blocks = append(blocks, types.NewBlockWithHeader(res.Header).WithBody(res.Transactions, res.Uncles))
+	}
+	index, err := d.blockchain.InsertChain(blocks)
+	if err == nil {
+		return nil
+	}
+	if index < count {
+		log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+	} else {
+		// The InsertChain method in blockchain.go will sometimes return an out-of-bounds index,
+		// when it needs to preprocess blocks to import a sidechain.
+		// The importer will put together a new list of blocks to import, which is a superset
+		// of the blocks delivered from the downloader, and the indexing will be off.
+		log.Debug("Downloaded item processing failed on sidechain import", "index", index, "err", err)
+	}
+	return fmt.Errorf("%w: %v", errInvalidChain, err)
+}
+
+// processFastSyncContent takes fetch results from the queue and writes them to the
+// database. It also controls the synchronisation of state nodes of the pivot block.
+func (d *Downloader) processFastSyncContent() error {
+	// Start syncing state of the reported head block. This should get us most of
+	// the state of the pivot block.
+	d.pivotLock.RLock()
+	sync := d.syncState(d.pivotHeader.Root)
+	d.pivotLock.RUnlock()
+
+	defer func() {
+		// The `sync` object is replaced every time the pivot moves. We need to
+		// defer close the very last active one, hence the lazy evaluation vs.
+		// calling defer sync.Cancel() !!!
+		sync.Cancel()
+	}()
+
+	// closeOnErr waits for a state sync to end and, unless it ended through one
+	// of the known cancellation errors, closes the queue to unblock Results.
+	closeOnErr := func(s *stateSync) {
+		if err := s.Wait(); err != nil && err != errCancelStateFetch && err != errCanceled && err != snap.ErrCancelled {
+			d.queue.Close() // wake up Results
+		}
+	}
+	go closeOnErr(sync)
+
+	// To cater for moving pivot points, track the pivot block and subsequently
+	// accumulated download results separately.
+	var (
+		oldPivot *fetchResult   // Locked in pivot block, might change eventually
+		oldTail  []*fetchResult // Downloaded content after the pivot
+	)
+	for {
+		// Wait for the next batch of downloaded data to be available, and if the pivot
+		// block became stale, move the goalpost
+		results := d.queue.Results(oldPivot == nil) // Block if we're not monitoring pivot staleness
+		if len(results) == 0 {
+			// If pivot sync is done, stop
+			if oldPivot == nil {
+				return sync.Cancel()
+			}
+			// If sync failed, stop
+			select {
+			case <-d.cancelCh:
+				sync.Cancel()
+				return errCanceled
+			default:
+			}
+		}
+		if d.chainInsertHook != nil {
+			d.chainInsertHook(results)
+		}
+		// If we haven't downloaded the pivot block yet, check pivot staleness
+		// notifications from the header downloader
+		d.pivotLock.RLock()
+		pivot := d.pivotHeader
+		d.pivotLock.RUnlock()
+
+		if oldPivot == nil {
+			if pivot.Root != sync.root {
+				sync.Cancel()
+				sync = d.syncState(pivot.Root)
+
+				go closeOnErr(sync)
+			}
+		} else {
+			results = append(append([]*fetchResult{oldPivot}, oldTail...), results...)
+		}
+		// Split around the pivot block and process the two sides via fast/full sync
+		if atomic.LoadInt32(&d.committed) == 0 {
+			latest := results[len(results)-1].Header
+			// If the height is above the pivot block by 2 sets, it means the pivot
+			// became stale in the network and it was garbage collected, move to a
+			// new pivot.
+			//
+			// Note, we have `reorgProtHeaderDelay` number of blocks withheld. Those
+			// need to be taken into account, otherwise we're detecting the pivot move
+			// late and will drop peers due to unavailable state!!!
+			if height := latest.Number.Uint64(); height >= pivot.Number.Uint64()+2*uint64(fsMinFullBlocks)-uint64(reorgProtHeaderDelay) {
+				log.Warn("Pivot became stale, moving", "old", pivot.Number.Uint64(), "new", height-uint64(fsMinFullBlocks)+uint64(reorgProtHeaderDelay))
+				pivot = results[len(results)-1-fsMinFullBlocks+reorgProtHeaderDelay].Header // must exist as lower old pivot is uncommitted
+
+				d.pivotLock.Lock()
+				d.pivotHeader = pivot
+				d.pivotLock.Unlock()
+
+				// Write out the pivot into the database so a rollback beyond it will
+				// reenable fast sync
+				rawdb.WriteLastPivotNumber(d.stateDB, pivot.Number.Uint64())
+			}
+		}
+		P, beforeP, afterP := splitAroundPivot(pivot.Number.Uint64(), results)
+		if err := d.commitFastSyncData(beforeP, sync); err != nil {
+			return err
+		}
+		if P != nil {
+			// If new pivot block found, cancel old state retrieval and restart
+			if oldPivot != P {
+				sync.Cancel()
+				sync = d.syncState(P.Header.Root)
+
+				go closeOnErr(sync)
+				oldPivot = P
+			}
+			// Wait for completion, occasionally checking for pivot staleness
+			select {
+			case <-sync.done:
+				if sync.err != nil {
+					return sync.err
+				}
+				if err := d.commitPivotBlock(P); err != nil {
+					return err
+				}
+				oldPivot = nil
+
+			case <-time.After(time.Second):
+				oldTail = afterP
+				continue
+			}
+		}
+		// Fast sync done, pivot commit done, full import
+		if err := d.importBlockResults(afterP); err != nil {
+			return err
+		}
+	}
+}
+
+// splitAroundPivot partitions results by block number relative to pivot: p is
+// the result whose number equals pivot (nil if absent), before holds the lower
+// numbered results and after the higher numbered ones. If the last result is
+// still below the pivot, the whole input slice is returned as before.
+func splitAroundPivot(pivot uint64, results []*fetchResult) (p *fetchResult, before, after []*fetchResult) {
+	if len(results) == 0 {
+		return nil, nil, nil
+	}
+	if lastNum := results[len(results)-1].Header.Number.Uint64(); lastNum < pivot {
+		// the pivot is somewhere in the future
+		return nil, results, nil
+	}
+	// This can also be optimized, but only happens very seldom
+	for _, result := range results {
+		num := result.Header.Number.Uint64()
+		switch {
+		case num < pivot:
+			before = append(before, result)
+		case num == pivot:
+			p = result
+		default:
+			after = append(after, result)
+		}
+	}
+	return p, before, after
+}
+
+// commitFastSyncData writes a batch of pre-pivot results (blocks together with
+// their receipts) into the chain via InsertReceiptChain.
+//
+// An empty batch is a no-op. The call aborts with errCancelContentProcessing if
+// the quit channel is closed, returns the state sync's error if that has already
+// finished with a failure, and wraps any insertion failure in errInvalidChain.
+func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *stateSync) error {
+	// Check for any early termination requests
+	if len(results) == 0 {
+		return nil
+	}
+	select {
+	case <-d.quitCh:
+		return errCancelContentProcessing
+	case <-stateSync.done:
+		if err := stateSync.Wait(); err != nil {
+			return err
+		}
+	default:
+	}
+	// Retrieve a batch of results to import
+	first, last := results[0].Header, results[len(results)-1].Header
+	log.Debug("Inserting fast-sync blocks", "items", len(results),
+		"firstnum", first.Number, "firsthash", first.Hash(),
+		"lastnum", last.Number, "lasthash", last.Hash(),
+	)
+	blocks := make([]*types.Block, len(results))
+	receipts := make([]types.Receipts, len(results))
+	for i, result := range results {
+		blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+		receipts[i] = result.Receipts
+	}
+	if index, err := d.blockchain.InsertReceiptChain(blocks, receipts, d.ancientLimit); err != nil {
+		// Guard the lookup the same way importBlockResults does, so that a
+		// failure report can never turn into an out-of-range panic.
+		if index >= 0 && index < len(results) {
+			log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+		} else {
+			log.Debug("Downloaded item processing failed", "index", index, "err", err)
+		}
+		return fmt.Errorf("%w: %v", errInvalidChain, err)
+	}
+	return nil
+}
+
+// commitPivotBlock inserts the pivot block with its receipts, promotes it to
+// the chain head through FastSyncCommitHead and flags the downloader as
+// committed. Any error from the two chain calls is returned unwrapped.
+func (d *Downloader) commitPivotBlock(result *fetchResult) error {
+	block := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+	log.Debug("Committing fast sync pivot as new head", "number", block.Number(), "hash", block.Hash())
+
+	// Commit the pivot block as the new head, will require full sync from here on
+	if _, err := d.blockchain.InsertReceiptChain([]*types.Block{block}, []types.Receipts{result.Receipts}, d.ancientLimit); err != nil {
+		return err
+	}
+	if err := d.blockchain.FastSyncCommitHead(block.Hash()); err != nil {
+		return err
+	}
+	atomic.StoreInt32(&d.committed, 1)
+
+	// If we had a bloom filter for the state sync, deallocate it now. Note, we only
+	// deallocate internally, but keep the empty wrapper. This ensures that if we do
+	// a rollback after committing the pivot and restarting fast sync, we don't end
+	// up using a nil bloom. Empty bloom is fine, it just returns that it does not
+	// have the info we need, so reach down to the database instead.
+	if d.stateBloom != nil {
+		d.stateBloom.Close()
+	}
+	return nil
+}
+
+// DeliverHeaders injects a new batch of block headers received from a remote
+// node into the download schedule.
+func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) error {
+	return d.deliver(d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter)
+}
+
+// DeliverBodies injects a new batch of block bodies received from a remote node.
+func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) error {
+	return d.deliver(d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter)
+}
+
+// DeliverReceipts injects a new batch of receipts received from a remote node.
+func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) error {
+	return d.deliver(d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter)
+}
+
+// DeliverNodeData injects a new batch of node state data received from a remote node.
+func (d *Downloader) DeliverNodeData(id string, data [][]byte) error {
+	return d.deliver(d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter)
+}
+
+// DeliverSnapPacket is invoked from a peer's message handler when it transmits a
+// data packet for the local node to consume. The packet is dispatched on its
+// concrete type to the matching SnapSyncer callback; an unknown packet type
+// results in an error.
+func (d *Downloader) DeliverSnapPacket(peer *snap.Peer, packet snap.Packet) error {
+	switch packet := packet.(type) {
+	case *snap.AccountRangePacket:
+		// Unpacking an account range can fail, in which case nothing is delivered
+		hashes, accounts, err := packet.Unpack()
+		if err != nil {
+			return err
+		}
+		return d.SnapSyncer.OnAccounts(peer, packet.ID, hashes, accounts, packet.Proof)
+
+	case *snap.StorageRangesPacket:
+		hashset, slotset := packet.Unpack()
+		return d.SnapSyncer.OnStorage(peer, packet.ID, hashset, slotset, packet.Proof)
+
+	case *snap.ByteCodesPacket:
+		return d.SnapSyncer.OnByteCodes(peer, packet.ID, packet.Codes)
+
+	case *snap.TrieNodesPacket:
+		return d.SnapSyncer.OnTrieNodes(peer, packet.ID, packet.Nodes)
+
+	default:
+		return fmt.Errorf("unexpected snap packet type: %T", packet)
+	}
+}
+
+// deliver injects a new batch of data received from a remote node.
+//
+// The packet is pushed into destCh unless no sync is active (the cancel channel
+// is nil) or the sync gets cancelled while the send is pending; both cases
+// return errNoSyncActive. inMeter is marked with the packet's item count on
+// every call, dropMeter additionally whenever an error is returned.
+func (d *Downloader) deliver(destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) {
+	// Update the delivery metrics for both good and failed deliveries
+	inMeter.Mark(int64(packet.Items()))
+	defer func() {
+		// err is the named result, so this observes the final return value
+		if err != nil {
+			dropMeter.Mark(int64(packet.Items()))
+		}
+	}()
+	// Deliver or abort if the sync is canceled while queuing
+	d.cancelLock.RLock()
+	cancel := d.cancelCh
+	d.cancelLock.RUnlock()
+	if cancel == nil {
+		return errNoSyncActive
+	}
+	select {
+	case destCh <- packet:
+		return nil
+	case <-cancel:
+		return errNoSyncActive
+	}
+}
diff --git a/les/downloader/downloader_test.go b/les/downloader/downloader_test.go
new file mode 100644
index 0000000000000..17cd3630c98e0
--- /dev/null
+++ b/les/downloader/downloader_test.go
@@ -0,0 +1,1622 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package downloader
+
+import (
+	"errors"
+	"fmt"
+	"math/big"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/ethereum/go-ethereum"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/state/snapshot"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/eth/protocols/eth"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
+// Reduce some of the parameters to make the tester faster. The overrides are
+// applied to package-level variables, so they affect every test in the package.
+func init() {
+	fullMaxForkAncestry = 10000
+	lightMaxForkAncestry = 10000
+	blockCacheMaxItems = 1024
+	fsHeaderContCheck = 500 * time.Millisecond
+}
+
+// downloadTester is a test simulator for mocking out local block chain.
+type downloadTester struct {
+	downloader *Downloader // Downloader under test, wired to this tester as its chain backend
+
+	genesis *types.Block   // Genesis blocks used by the tester and peers
+	stateDb ethdb.Database // Database used by the tester for syncing from peers
+	peerDb  ethdb.Database // Database of the peers containing all data
+	peers   map[string]*downloadTesterPeer // Simulated remote peers, keyed by peer id
+
+	ownHashes   []common.Hash                  // Hash chain belonging to the tester
+	ownHeaders  map[common.Hash]*types.Header  // Headers belonging to the tester
+	ownBlocks   map[common.Hash]*types.Block   // Blocks belonging to the tester
+	ownReceipts map[common.Hash]types.Receipts // Receipts belonging to the tester
+	ownChainTd  map[common.Hash]*big.Int       // Total difficulties of the blocks in the local chain
+
+	ancientHeaders  map[common.Hash]*types.Header  // Ancient headers belonging to the tester
+	ancientBlocks   map[common.Hash]*types.Block   // Ancient blocks belonging to the tester
+	ancientReceipts map[common.Hash]types.Receipts // Ancient receipts belonging to the tester
+	ancientChainTd  map[common.Hash]*big.Int       // Ancient total difficulties of the blocks in the local chain
+
+	lock sync.RWMutex // Taken by the accessor methods around the peer set and the chain maps above
+}
+
+// newTester creates a new downloader test mocker.
+func newTester() *downloadTester {
+	genesisHash := testGenesis.Hash()
+
+	// Back the tester with an in-memory database seeded with a marker entry
+	// for the genesis state root.
+	db := rawdb.NewMemoryDatabase()
+	db.Put(testGenesis.Root().Bytes(), []byte{0x00})
+
+	tester := &downloadTester{
+		genesis: testGenesis,
+		stateDb: db,
+		peerDb:  testDB,
+		peers:   make(map[string]*downloadTesterPeer),
+
+		// The live chain starts out with the genesis block only
+		ownHashes:   []common.Hash{genesisHash},
+		ownHeaders:  map[common.Hash]*types.Header{genesisHash: testGenesis.Header()},
+		ownBlocks:   map[common.Hash]*types.Block{genesisHash: testGenesis},
+		ownReceipts: map[common.Hash]types.Receipts{genesisHash: nil},
+		ownChainTd:  map[common.Hash]*big.Int{genesisHash: testGenesis.Difficulty()},
+
+		// Initialize ancient store with test genesis block
+		ancientHeaders:  map[common.Hash]*types.Header{genesisHash: testGenesis.Header()},
+		ancientBlocks:   map[common.Hash]*types.Block{genesisHash: testGenesis},
+		ancientReceipts: map[common.Hash]types.Receipts{genesisHash: nil},
+		ancientChainTd:  map[common.Hash]*big.Int{genesisHash: testGenesis.Difficulty()},
+	}
+	// The tester itself acts as the chain backend of the downloader under test
+	bloom := trie.NewSyncBloom(1, db)
+	tester.downloader = New(0, db, bloom, new(event.TypeMux), tester, nil, tester.dropPeer)
+	return tester
+}
+
+// terminate shuts down the wrapped downloader, aborting whatever it is doing
+// and releasing the resources it holds.
+func (dl *downloadTester) terminate() {
+	dl.downloader.Terminate()
+}
+
+// sync starts synchronizing with a remote peer, blocking until it completes.
+func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error {
+	// Resolve the peer's head (and its TD, unless the caller dictated one)
+	// while holding the read lock.
+	dl.lock.RLock()
+	head := dl.peers[id].chain.headBlock().Hash()
+	if td == nil {
+		td = dl.peers[id].chain.td(head)
+	}
+	dl.lock.RUnlock()
+
+	// Synchronise with the chosen peer and ensure proper cleanup afterwards
+	err := dl.downloader.synchronise(id, head, td, mode)
+
+	select {
+	case <-dl.downloader.cancelCh:
+		// Ok, downloader fully cancelled after sync cycle
+		return err
+	default:
+	}
+	// Downloader is still accepting packets, can block a peer up
+	panic("downloader active post sync cycle") // panic will be caught by tester
+}
+
+// HasHeader reports whether the tester's chain knows a header with the given
+// hash; the block number is ignored.
+func (dl *downloadTester) HasHeader(hash common.Hash, number uint64) bool {
+	header := dl.GetHeaderByHash(hash)
+	return header != nil
+}
+
+// HasBlock reports whether the tester's chain knows a block with the given
+// hash; the block number is ignored.
+func (dl *downloadTester) HasBlock(hash common.Hash, number uint64) bool {
+	block := dl.GetBlockByHash(hash)
+	return block != nil
+}
+
+// HasFastBlock reports whether receipts are recorded for the given hash, in
+// either the ancient or the live store; the block number is ignored.
+func (dl *downloadTester) HasFastBlock(hash common.Hash, number uint64) bool {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	_, inAncients := dl.ancientReceipts[hash]
+	_, inOwn := dl.ownReceipts[hash]
+	return inAncients || inOwn
+}
+
+// GetHeaderByHash retrieves a header from the testers canonical chain.
+func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+	return dl.getHeaderByHash(hash)
+}
+
+// getHeaderByHash returns the header if found either within ancients or own
+// blocks, nil otherwise.
+// This method assumes that the caller holds at least the read-lock (dl.lock).
+func (dl *downloadTester) getHeaderByHash(hash common.Hash) *types.Header {
+	if header := dl.ancientHeaders[hash]; header != nil {
+		return header
+	}
+	return dl.ownHeaders[hash]
+}
+
+// GetBlockByHash retrieves a block from the testers canonical chain, looking
+// into the ancient store first and the live store second. It returns nil if
+// the hash is unknown to both.
+func (dl *downloadTester) GetBlockByHash(hash common.Hash) *types.Block {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	if block := dl.ancientBlocks[hash]; block != nil {
+		return block
+	}
+	return dl.ownBlocks[hash]
+}
+
+// CurrentHeader retrieves the current head header from the canonical chain.
+// The hash chain is walked from its tip backwards and the first hash with a
+// known header wins; the genesis header is the fallback.
+func (dl *downloadTester) CurrentHeader() *types.Header {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	for i := len(dl.ownHashes) - 1; i >= 0; i-- {
+		hash := dl.ownHashes[i]
+		if header := dl.ancientHeaders[hash]; header != nil {
+			return header
+		}
+		if header := dl.ownHeaders[hash]; header != nil {
+			return header
+		}
+	}
+	return dl.genesis.Header()
+}
+
+// CurrentBlock retrieves the current head block from the canonical chain.
+// The hash chain is walked from its tip backwards: an ancient block is returned
+// as soon as one is found, whereas a live block only qualifies if its state
+// root is present in the state database. The genesis block is the fallback.
+func (dl *downloadTester) CurrentBlock() *types.Block {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	for i := len(dl.ownHashes) - 1; i >= 0; i-- {
+		hash := dl.ownHashes[i]
+		// Ancient blocks are returned regardless of state availability. The
+		// previous state lookup on this path was redundant, as both of its
+		// outcomes returned the block.
+		if block := dl.ancientBlocks[hash]; block != nil {
+			return block
+		}
+		if block := dl.ownBlocks[hash]; block != nil {
+			if _, err := dl.stateDb.Get(block.Root().Bytes()); err == nil {
+				return block
+			}
+		}
+	}
+	return dl.genesis
+}
+
+// CurrentFastBlock retrieves the current head fast-sync block from the canonical chain.
+func (dl *downloadTester) CurrentFastBlock() *types.Block {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// Scan the recorded hashes newest first; any stored block will do, state is
+	// not required here
+	for idx := len(dl.ownHashes); idx > 0; idx-- {
+		hash := dl.ownHashes[idx-1]
+		if ancient := dl.ancientBlocks[hash]; ancient != nil {
+			return ancient
+		}
+		if recent := dl.ownBlocks[hash]; recent != nil {
+			return recent
+		}
+	}
+	return dl.genesis
+}
+
+// FastSyncCommitHead manually sets the head block to a given hash. The tester
+// only verifies that a secure trie can be opened at the block's state root.
+func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error {
+	block := dl.GetBlockByHash(hash)
+	if block == nil {
+		return fmt.Errorf("non existent block: %x", hash[:4])
+	}
+	// For now only check that the state trie is correct
+	_, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb))
+	return err
+}
+
+// GetTd retrieves the block's total difficulty from the canonical chain. The
+// block number is not consulted.
+func (dl *downloadTester) GetTd(hash common.Hash, number uint64) *big.Int {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	return dl.getTd(hash)
+}
+
+// getTd retrieves the block's total difficulty if found either within
+// ancients or own blocks, and nil otherwise.
+// This method assumes that the caller holds at least the read-lock (dl.lock)
+func (dl *downloadTester) getTd(hash common.Hash) *big.Int {
+	td := dl.ancientChainTd[hash]
+	if td == nil {
+		td = dl.ownChainTd[hash]
+	}
+	return td
+}
+
+// InsertHeaderChain injects a new batch of headers into the simulated chain.
+// It returns the number of headers processed, or the index of the offending
+// header together with an error. Nothing is inserted unless the whole batch is
+// contiguous and attaches to a known parent. checkFreq is ignored by the tester.
+func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq int) (i int, err error) {
+	// An empty batch is a no-op; this also keeps the headers[0] accesses below
+	// in range
+	if len(headers) == 0 {
+		return 0, nil
+	}
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	// Do a quick check, as the blockchain.InsertHeaderChain doesn't insert anything in case of errors
+	if dl.getHeaderByHash(headers[0].ParentHash) == nil {
+		return 0, fmt.Errorf("InsertHeaderChain: unknown parent at first position, parent of number %d", headers[0].Number)
+	}
+	// Hash every header exactly once while verifying that the batch is contiguous
+	hashes := make([]common.Hash, len(headers))
+	hashes[0] = headers[0].Hash()
+	for pos := 1; pos < len(headers); pos++ {
+		if headers[pos].ParentHash != hashes[pos-1] {
+			return pos, fmt.Errorf("non-contiguous import at position %d", pos)
+		}
+		hashes[pos] = headers[pos].Hash()
+	}
+	// Do a full insert if pre-checks passed
+	for pos, header := range headers {
+		hash := hashes[pos]
+		if dl.getHeaderByHash(hash) != nil {
+			continue // already known, nothing to record
+		}
+		if dl.getHeaderByHash(header.ParentHash) == nil {
+			// This _should_ be impossible, due to precheck and induction
+			return pos, fmt.Errorf("InsertHeaderChain: unknown parent at position %d", pos)
+		}
+		dl.ownHashes = append(dl.ownHashes, hash)
+		dl.ownHeaders[hash] = header
+
+		td := dl.getTd(header.ParentHash)
+		dl.ownChainTd[hash] = new(big.Int).Add(td, header.Difficulty)
+	}
+	return len(headers), nil
+}
+
+// InsertChain injects a new batch of blocks into the simulated chain.
+// Every block must have its parent among the own blocks and the parent's state
+// root present in the state database. It returns the number of blocks
+// processed, or the index of the offending block together with an error.
+func (dl *downloadTester) InsertChain(blocks types.Blocks) (i int, err error) {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	for pos, block := range blocks {
+		parent, ok := dl.ownBlocks[block.ParentHash()]
+		if !ok {
+			return pos, fmt.Errorf("InsertChain: unknown parent at position %d / %d", pos, len(blocks))
+		}
+		if _, err := dl.stateDb.Get(parent.Root().Bytes()); err != nil {
+			return pos, fmt.Errorf("InsertChain: unknown parent state %x: %v", parent.Root(), err)
+		}
+		// Mark the block's state as available before recording anything else, so
+		// a failing database write cannot leave a half-inserted block behind
+		if err := dl.stateDb.Put(block.Root().Bytes(), []byte{0x00}); err != nil {
+			return pos, fmt.Errorf("InsertChain: failed to store state %x: %v", block.Root(), err)
+		}
+		hash := block.Hash()
+		if dl.getHeaderByHash(hash) == nil {
+			dl.ownHashes = append(dl.ownHashes, hash)
+			dl.ownHeaders[hash] = block.Header()
+		}
+		dl.ownBlocks[hash] = block
+		dl.ownReceipts[hash] = make(types.Receipts, 0)
+
+		td := dl.getTd(block.ParentHash())
+		dl.ownChainTd[hash] = new(big.Int).Add(td, block.Difficulty())
+	}
+	return len(blocks), nil
+}
+
+// InsertReceiptChain injects a new batch of receipts into the simulated chain.
+// Blocks numbered at or below ancientLimit are stored in the ancient sets and
+// their header and TD are migrated out of the own sets; all others go into the
+// own block and receipt sets. Only the first min(len(blocks), len(receipts))
+// entries are processed. On failure the index of the offending block is
+// returned together with an error.
+func (dl *downloadTester) InsertReceiptChain(blocks types.Blocks, receipts []types.Receipts, ancientLimit uint64) (i int, err error) {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	for i := 0; i < len(blocks) && i < len(receipts); i++ {
+		block, hash, parent := blocks[i], blocks[i].Hash(), blocks[i].ParentHash()
+
+		// The header must have been imported beforehand
+		if _, ok := dl.ownHeaders[hash]; !ok {
+			return i, errors.New("unknown owner")
+		}
+		if _, ok := dl.ancientBlocks[parent]; !ok {
+			if _, ok := dl.ownBlocks[parent]; !ok {
+				return i, errors.New("InsertReceiptChain: unknown parent")
+			}
+		}
+		if block.NumberU64() <= ancientLimit {
+			// Resolve the parent TD through the shared helper: the parent may be
+			// tracked in the own set only, in which case a direct lookup in
+			// ancientChainTd would hand a nil *big.Int to Add below
+			parentTd := dl.getTd(parent)
+			if parentTd == nil {
+				return i, errors.New("InsertReceiptChain: unknown parent total difficulty")
+			}
+			dl.ancientBlocks[hash] = block
+			dl.ancientReceipts[hash] = receipts[i]
+
+			// Migrate from active db to ancient db
+			dl.ancientHeaders[hash] = block.Header()
+			dl.ancientChainTd[hash] = new(big.Int).Add(parentTd, block.Difficulty())
+			delete(dl.ownHeaders, hash)
+			delete(dl.ownChainTd, hash)
+		} else {
+			dl.ownBlocks[hash] = block
+			dl.ownReceipts[hash] = receipts[i]
+		}
+	}
+	return len(blocks), nil
+}
+
+// SetHead rewinds the local chain to a new head.
+func (dl *downloadTester) SetHead(head uint64) error {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	// Look the requested number up among the own headers first and the ancient
+	// ones afterwards; a later match overrides an earlier one
+	var target common.Hash
+	for hash, header := range dl.ownHeaders {
+		if header.Number.Uint64() == head {
+			target = hash
+		}
+	}
+	for hash, header := range dl.ancientHeaders {
+		if header.Number.Uint64() == head {
+			target = hash
+		}
+	}
+	if target == (common.Hash{}) {
+		return fmt.Errorf("unknown head to set: %d", head)
+	}
+	// Locate the new head within the ordered hash list (position 0 if absent)
+	keep := 0
+	for pos, hash := range dl.ownHashes {
+		if hash == target {
+			keep = pos
+			break
+		}
+	}
+	// Drop everything recorded after the new head, from own and ancient sets alike
+	for pos := keep + 1; pos < len(dl.ownHashes); pos++ {
+		stale := dl.ownHashes[pos]
+
+		delete(dl.ownChainTd, stale)
+		delete(dl.ownHeaders, stale)
+		delete(dl.ownReceipts, stale)
+		delete(dl.ownBlocks, stale)
+
+		delete(dl.ancientChainTd, stale)
+		delete(dl.ancientHeaders, stale)
+		delete(dl.ancientReceipts, stale)
+		delete(dl.ancientBlocks, stale)
+	}
+	dl.ownHashes = dl.ownHashes[:keep+1]
+	return nil
+}
+
+// Rollback removes some recently added elements from the chain. The tester
+// implements it as a noop.
+func (dl *downloadTester) Rollback(hashes []common.Hash) {
+}
+
+// newPeer registers a new block download source into the downloader, serving
+// data out of the given test chain.
+func (dl *downloadTester) newPeer(id string, version uint, chain *testChain) error {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	dl.peers[id] = &downloadTesterPeer{dl: dl, id: id, chain: chain}
+	return dl.downloader.RegisterPeer(id, version, dl.peers[id])
+}
+
+// dropPeer simulates a hard peer removal from the connection pool: the peer is
+// forgotten by the tester and unregistered from the downloader.
+func (dl *downloadTester) dropPeer(id string) {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	delete(dl.peers, id)
+	dl.downloader.UnregisterPeer(id)
+}
+
+// Snapshots implements the BlockChain interface for the downloader, but is a noop.
+func (dl *downloadTester) Snapshots() *snapshot.Tree {
+	return nil
+}
+
+// downloadTesterPeer is a simulated remote peer that answers the downloader's
+// requests out of a test chain.
+type downloadTesterPeer struct {
+	dl            *downloadTester      // Tester whose downloader receives the deliveries
+	id            string               // Identifier the deliveries are attributed to
+	chain         *testChain           // Chain the responses are served from
+	missingStates map[common.Hash]bool // State entries that fast sync should not return
+}
+
+// Head returns the hash of the peer's current head block and the total
+// difficulty the peer's chain reports for it.
+func (dlp *downloadTesterPeer) Head() (common.Hash, *big.Int) {
+	hash := dlp.chain.headBlock().Hash()
+	return hash, dlp.chain.td(hash)
+}
+
+// RequestHeadersByHash looks up a batch of headers starting at a hashed origin
+// in the peer's chain and delivers it to the tester's downloader from a
+// separate goroutine. It always returns nil.
+func (dlp *downloadTesterPeer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error {
+	headers := dlp.chain.headersByHash(origin, amount, skip, reverse)
+	go dlp.dl.downloader.DeliverHeaders(dlp.id, headers)
+	return nil
+}
+
+// RequestHeadersByNumber looks up a batch of headers starting at a numbered
+// origin in the peer's chain and delivers it to the tester's downloader from a
+// separate goroutine. It always returns nil.
+func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error {
+	headers := dlp.chain.headersByNumber(origin, amount, skip, reverse)
+	go dlp.dl.downloader.DeliverHeaders(dlp.id, headers)
+	return nil
+}
+
+// RequestBodies looks up the bodies of the requested blocks in the peer's
+// chain and delivers them to the tester's downloader from a separate
+// goroutine. It always returns nil.
+func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash) error {
+	transactions, uncles := dlp.chain.bodies(hashes)
+	go dlp.dl.downloader.DeliverBodies(dlp.id, transactions, uncles)
+	return nil
+}
+
+// RequestReceipts looks up the receipts of the requested blocks in the peer's
+// chain and delivers them to the tester's downloader from a separate
+// goroutine. It always returns nil.
+func (dlp *downloadTesterPeer) RequestReceipts(hashes []common.Hash) error {
+	batch := dlp.chain.receipts(hashes)
+	go dlp.dl.downloader.DeliverReceipts(dlp.id, batch)
+	return nil
+}
+
+// RequestNodeData collects the requested state entries from the tester's peer
+// database, leaving out unknown ones and those flagged in missingStates, and
+// delivers the result to the tester's downloader from a separate goroutine. It
+// always returns nil.
+func (dlp *downloadTesterPeer) RequestNodeData(hashes []common.Hash) error {
+	dlp.dl.lock.RLock()
+	defer dlp.dl.lock.RUnlock()
+
+	entries := make([][]byte, 0, len(hashes))
+	for _, hash := range hashes {
+		blob, err := dlp.dl.peerDb.Get(hash.Bytes())
+		if err != nil || dlp.missingStates[hash] {
+			continue
+		}
+		entries = append(entries, blob)
+	}
+	go dlp.dl.downloader.DeliverNodeData(dlp.id, entries)
+	return nil
+}
+
+// assertOwnChain checks if the local chain contains the correct number of items
+// of the various chain components. It is the single-chain case of
+// assertOwnForkedChain.
+func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
+	// Mark this method as a helper to report errors at callsite, not in here
+	t.Helper()
+
+	lengths := []int{length}
+	assertOwnForkedChain(t, tester, 1, lengths)
+}
+
+// assertOwnForkedChain checks if the local forked chain contains the correct
+// number of items of the various chain components.
+func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
+	// Mark this method as a helper to report errors at callsite, not in here
+	t.Helper()
+
+	// The first fork counts in full, every further one only beyond the shared prefix
+	total := lengths[0]
+	for _, length := range lengths[1:] {
+		total += length - common
+	}
+	wantHeaders, wantBlocks, wantReceipts := total, total, total
+
+	// In light mode a single block and receipt entry is expected
+	if tester.downloader.getMode() == LightSync {
+		wantBlocks, wantReceipts = 1, 1
+	}
+	if have := len(tester.ownHeaders) + len(tester.ancientHeaders) - 1; have != wantHeaders {
+		t.Fatalf("synchronised headers mismatch: have %v, want %v", have, wantHeaders)
+	}
+	if have := len(tester.ownBlocks) + len(tester.ancientBlocks) - 1; have != wantBlocks {
+		t.Fatalf("synchronised blocks mismatch: have %v, want %v", have, wantBlocks)
+	}
+	if have := len(tester.ownReceipts) + len(tester.ancientReceipts) - 1; have != wantReceipts {
+		t.Fatalf("synchronised receipts mismatch: have %v, want %v", have, wantReceipts)
+	}
+}
+
+func TestCanonicalSynchronisation66Full(t *testing.T)  { testCanonSync(t, eth.ETH66, FullSync) }
+func TestCanonicalSynchronisation66Fast(t *testing.T)  { testCanonSync(t, eth.ETH66, FastSync) }
+func TestCanonicalSynchronisation66Light(t *testing.T) { testCanonSync(t, eth.ETH66, LightSync) }
+
+// testCanonSync syncs a short canonical chain from a single peer and checks
+// that the expected number of chain items ended up in the tester.
+func testCanonSync(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	// Create a small enough block chain to download
+	chain := testChainBase.shorten(blockCacheMaxItems - 15)
+	tester.newPeer("peer", protocol, chain)
+
+	// Synchronise with the peer and make sure all relevant data was retrieved
+	err := tester.sync("peer", nil, mode)
+	if err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that if a large batch of blocks are being downloaded, it is throttled
+// until the cached blocks are retrieved.
+func TestThrottling66Full(t *testing.T) { testThrottling(t, eth.ETH66, FullSync) }
+func TestThrottling66Fast(t *testing.T) { testThrottling(t, eth.ETH66, FastSync) }
+
+// testThrottling syncs the whole base chain while the import hook is held back
+// on a channel, and checks on every step that the number of completed-but-not-
+// imported results stays within the expected cache bounds.
+func testThrottling(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+	tester := newTester()
+
+	// Create a long block chain to download and the tester
+	targetBlocks := testChainBase.len() - 1
+	tester.newPeer("peer", protocol, testChainBase)
+
+	// Wrap the importer to allow stepping: the hook publishes the size of the
+	// batch it was handed and then waits until the test sends on proceed
+	blocked, proceed := uint32(0), make(chan struct{})
+	tester.downloader.chainInsertHook = func(results []*fetchResult) {
+		atomic.StoreUint32(&blocked, uint32(len(results)))
+		<-proceed
+	}
+	// Start a synchronisation concurrently
+	errc := make(chan error, 1)
+	go func() {
+		errc <- tester.sync("peer", nil, mode)
+	}()
+	// Iteratively take some blocks, always checking the retrieval count
+	for {
+		// Check the retrieval count synchronously (! reason for this ugly block)
+		tester.lock.RLock()
+		retrieved := len(tester.ownBlocks)
+		tester.lock.RUnlock()
+		if retrieved >= targetBlocks+1 {
+			break
+		}
+		// Wait a bit for sync to throttle itself (polling for at most 3 seconds)
+		var cached, frozen int
+		for start := time.Now(); time.Since(start) < 3*time.Second; {
+			time.Sleep(25 * time.Millisecond)
+
+			// Take all three locks so that the counters are sampled together
+			tester.lock.Lock()
+			tester.downloader.queue.lock.Lock()
+			tester.downloader.queue.resultCache.lock.Lock()
+			{
+				cached = tester.downloader.queue.resultCache.countCompleted()
+				frozen = int(atomic.LoadUint32(&blocked))
+				retrieved = len(tester.ownBlocks)
+			}
+			tester.downloader.queue.resultCache.lock.Unlock()
+			tester.downloader.queue.lock.Unlock()
+			tester.lock.Unlock()
+
+			// Stop polling once the cache is full or everything is accounted for,
+			// in both cases with or without the reorgProtHeaderDelay offset
+			if cached == blockCacheMaxItems ||
+				cached == blockCacheMaxItems-reorgProtHeaderDelay ||
+				retrieved+cached+frozen == targetBlocks+1 ||
+				retrieved+cached+frozen == targetBlocks+1-reorgProtHeaderDelay {
+				break
+			}
+		}
+		// Make sure we filled up the cache, then exhaust it
+		time.Sleep(25 * time.Millisecond) // give it a chance to screw up
+		tester.lock.RLock()
+		retrieved = len(tester.ownBlocks)
+		tester.lock.RUnlock()
+		if cached != blockCacheMaxItems && cached != blockCacheMaxItems-reorgProtHeaderDelay && retrieved+cached+frozen != targetBlocks+1 && retrieved+cached+frozen != targetBlocks+1-reorgProtHeaderDelay {
+			t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheMaxItems, retrieved, frozen, targetBlocks+1)
+		}
+
+		// Permit the blocked blocks to import
+		if atomic.LoadUint32(&blocked) > 0 {
+			atomic.StoreUint32(&blocked, uint32(0))
+			proceed <- struct{}{}
+		}
+	}
+	// Check that we haven't pulled more blocks than available
+	assertOwnChain(t, tester, targetBlocks+1)
+	if err := <-errc; err != nil {
+		t.Fatalf("block synchronization failed: %v", err)
+	}
+	// NOTE(review): terminate runs on the success path only, unlike the deferred
+	// calls in sibling tests; presumably because an importer still parked on
+	// proceed could stall a deferred teardown - confirm before changing.
+	tester.terminate()
+
+}
+
+// Tests that simple synchronization against a forked chain works correctly. In
+// this test common ancestor lookup should *not* be short circuited, and a full
+// binary search should be executed.
+func TestForkedSync66Full(t *testing.T)  { testForkedSync(t, eth.ETH66, FullSync) }
+func TestForkedSync66Fast(t *testing.T)  { testForkedSync(t, eth.ETH66, FastSync) }
+func TestForkedSync66Light(t *testing.T) { testForkedSync(t, eth.ETH66, LightSync) }
+
+// testForkedSync syncs with fork A first and fork B second, both cut to 80
+// items past the base chain, and checks that both forks end up in the tester.
+func testForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	chainA := testChainForkLightA.shorten(testChainBase.len() + 80)
+	chainB := testChainForkLightB.shorten(testChainBase.len() + 80)
+	tester.newPeer("fork A", protocol, chainA)
+	tester.newPeer("fork B", protocol, chainB)
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("fork A", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chainA.len())
+
+	// Synchronise with the second peer and make sure that fork is pulled too
+	if err := tester.sync("fork B", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnForkedChain(t, tester, testChainBase.len(), []int{chainA.len(), chainB.len()})
+}
+
+// Tests that synchronising against a much shorter but much heavier fork works
+// correctly and is not dropped.
+func TestHeavyForkedSync66Full(t *testing.T)  { testHeavyForkedSync(t, eth.ETH66, FullSync) }
+func TestHeavyForkedSync66Fast(t *testing.T)  { testHeavyForkedSync(t, eth.ETH66, FastSync) }
+func TestHeavyForkedSync66Light(t *testing.T) { testHeavyForkedSync(t, eth.ETH66, LightSync) }
+
+// testHeavyForkedSync syncs with the light fork first and the heavy fork
+// second, and checks that both forks end up in the tester.
+func testHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	chainA := testChainForkLightA.shorten(testChainBase.len() + 80)
+	chainB := testChainForkHeavy.shorten(testChainBase.len() + 80)
+	tester.newPeer("light", protocol, chainA)
+	tester.newPeer("heavy", protocol, chainB)
+
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("light", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chainA.len())
+
+	// Synchronise with the second peer and make sure that fork is pulled too
+	if err := tester.sync("heavy", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnForkedChain(t, tester, testChainBase.len(), []int{chainA.len(), chainB.len()})
+}
+
+// Tests that chain forks are contained within a certain interval of the current
+// chain head, ensuring that malicious peers cannot waste resources by feeding
+// long dead chains.
+func TestBoundedForkedSync66Full(t *testing.T)  { testBoundedForkedSync(t, eth.ETH66, FullSync) }
+func TestBoundedForkedSync66Fast(t *testing.T)  { testBoundedForkedSync(t, eth.ETH66, FastSync) }
+func TestBoundedForkedSync66Light(t *testing.T) { testBoundedForkedSync(t, eth.ETH66, LightSync) }
+
+// testBoundedForkedSync syncs the full-length fork A and then expects a sync
+// against the full-length fork B to be refused with errInvalidAncestor.
+func testBoundedForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	chainA := testChainForkLightA
+	chainB := testChainForkLightB
+	tester.newPeer("original", protocol, chainA)
+	tester.newPeer("rewriter", protocol, chainB)
+
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("original", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chainA.len())
+
+	// Synchronise with the second peer and ensure that the fork is rejected for
+	// being too old. errors.Is keeps the check valid should the sentinel ever
+	// be returned wrapped with additional context.
+	if err := tester.sync("rewriter", nil, mode); !errors.Is(err, errInvalidAncestor) {
+		t.Fatalf("sync failure mismatch: have %v, want %v", err, errInvalidAncestor)
+	}
+}
+
+// Tests that chain forks are contained within a certain interval of the current
+// chain head for short but heavy forks too. These are a bit special because they
+// take different ancestor lookup paths.
+func TestBoundedHeavyForkedSync66Full(t *testing.T) {
+	testBoundedHeavyForkedSync(t, eth.ETH66, FullSync)
+}
+func TestBoundedHeavyForkedSync66Fast(t *testing.T) {
+	testBoundedHeavyForkedSync(t, eth.ETH66, FastSync)
+}
+func TestBoundedHeavyForkedSync66Light(t *testing.T) {
+	testBoundedHeavyForkedSync(t, eth.ETH66, LightSync)
+}
+
+// testBoundedHeavyForkedSync syncs the full-length light fork and then expects
+// a sync against the heavy fork to be refused with errInvalidAncestor.
+func testBoundedHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	// Deferred so the tester is torn down on the t.Fatalf paths too; a trailing
+	// call is skipped when the test aborts early
+	defer tester.terminate()
+
+	// Create a long enough forked chain
+	chainA := testChainForkLightA
+	chainB := testChainForkHeavy
+	tester.newPeer("original", protocol, chainA)
+
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("original", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chainA.len())
+
+	tester.newPeer("heavy-rewriter", protocol, chainB)
+	// Synchronise with the second peer and ensure that the fork is rejected for
+	// being too old
+	if err := tester.sync("heavy-rewriter", nil, mode); !errors.Is(err, errInvalidAncestor) {
+		t.Fatalf("sync failure mismatch: have %v, want %v", err, errInvalidAncestor)
+	}
+}
+
+// Tests that an inactive downloader will not accept incoming block headers,
+// bodies and receipts.
+func TestInactiveDownloader63(t *testing.T) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	// Empty deliveries of each kind, attempted in order: headers, bodies, receipts
+	deliveries := []func() error{
+		func() error {
+			return tester.downloader.DeliverHeaders("bad peer", []*types.Header{})
+		},
+		func() error {
+			return tester.downloader.DeliverBodies("bad peer", [][]*types.Transaction{}, [][]*types.Header{})
+		},
+		func() error {
+			return tester.downloader.DeliverReceipts("bad peer", [][]*types.Receipt{})
+		},
+	}
+	// Check that none of them is accepted while no sync is running
+	for _, deliver := range deliveries {
+		if err := deliver(); err != errNoSyncActive {
+			t.Errorf("error mismatch: have %v, want %v", err, errNoSyncActive)
+		}
+	}
+}
+
+// Tests that a canceled download wipes all previously accumulated state.
+func TestCancel66Full(t *testing.T)  { testCancel(t, eth.ETH66, FullSync) }
+func TestCancel66Fast(t *testing.T)  { testCancel(t, eth.ETH66, FastSync) }
+func TestCancel66Light(t *testing.T) { testCancel(t, eth.ETH66, LightSync) }
+
+// testCancel cancels the downloader once before and once after a successful
+// sync, and checks both times that the download queue reports idle.
+func testCancel(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	chain := testChainBase.shorten(MaxHeaderFetch)
+	tester.newPeer("peer", protocol, chain)
+
+	cancelAndCheckIdle := func() {
+		tester.downloader.Cancel()
+		if !tester.downloader.queue.Idle() {
+			t.Errorf("download queue not idle")
+		}
+	}
+	// Make sure canceling works with a pristine downloader
+	cancelAndCheckIdle()
+
+	// Synchronise with the peer, but cancel afterwards
+	if err := tester.sync("peer", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	cancelAndCheckIdle()
+}
+
+// Tests that synchronisation from multiple peers works as intended (multi thread sanity test).
+func TestMultiSynchronisation66Full(t *testing.T)  { testMultiSynchronisation(t, eth.ETH66, FullSync) }
+func TestMultiSynchronisation66Fast(t *testing.T)  { testMultiSynchronisation(t, eth.ETH66, FastSync) }
+func TestMultiSynchronisation66Light(t *testing.T) { testMultiSynchronisation(t, eth.ETH66, LightSync) }
+
+// testMultiSynchronisation registers eight peers holding ever shorter prefixes
+// of one chain, syncs against the peer holding all of it, and checks that the
+// whole chain was retrieved.
+func testMultiSynchronisation(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	// Create various peers with various parts of the chain
+	const targetPeers = 8
+	chain := testChainBase.shorten(targetPeers * 100)
+
+	// Peer #k is handed the first 1/(k+1)-th of the chain
+	for divisor := 1; divisor <= targetPeers; divisor++ {
+		id := fmt.Sprintf("peer #%d", divisor-1)
+		tester.newPeer(id, protocol, chain.shorten(chain.len()/divisor))
+	}
+	err := tester.sync("peer #0", nil, mode)
+	if err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that synchronisations behave well in multi-version protocol environments
+// and not wreak havoc on other nodes in the network.
+func TestMultiProtoSynchronisation66Full(t *testing.T)  { testMultiProtoSync(t, eth.ETH66, FullSync) }
+func TestMultiProtoSynchronisation66Fast(t *testing.T)  { testMultiProtoSync(t, eth.ETH66, FastSync) }
+func TestMultiProtoSynchronisation66Light(t *testing.T) { testMultiProtoSync(t, eth.ETH66, LightSync) }
+
+// testMultiProtoSync registers one peer per protocol version (currently only
+// eth/66), syncs against the peer matching the requested protocol, and checks
+// that every registered peer is still known to the tester afterwards.
+func testMultiProtoSync(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	// Create a small enough block chain to download
+	chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+	// Create peers of every type
+	tester.newPeer("peer 66", eth.ETH66, chain)
+	//tester.newPeer("peer 65", eth.ETH67, chain)
+
+	// Synchronise with the requested peer and make sure all blocks were retrieved
+	origin := fmt.Sprintf("peer %d", protocol)
+	if err := tester.sync(origin, nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chain.len())
+
+	// Check that no peers have been dropped off
+	for _, version := range []int{66} {
+		peer := fmt.Sprintf("peer %d", version)
+		if _, ok := tester.peers[peer]; ok {
+			continue
+		}
+		t.Errorf("%s dropped", peer)
+	}
+}
+
+// Tests that if a block is empty (e.g. header only), no body request should be
+// made, and instead the header should be assembled into a whole block in itself.
+func TestEmptyShortCircuit66Full(t *testing.T)  { testEmptyShortCircuit(t, eth.ETH66, FullSync) }
+func TestEmptyShortCircuit66Fast(t *testing.T)  { testEmptyShortCircuit(t, eth.ETH66, FastSync) }
+func TestEmptyShortCircuit66Light(t *testing.T) { testEmptyShortCircuit(t, eth.ETH66, LightSync) }
+
+// testEmptyShortCircuit counts, through the fetch hooks, how many bodies and
+// receipts were requested during a sync, and compares those counts against the
+// number of non-empty entries in the peer's chain.
+func testEmptyShortCircuit(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	// Create a block chain to download
+	chain := testChainBase
+	tester.newPeer("peer", protocol, chain)
+
+	// Instrument the downloader to signal body requests
+	var bodiesHave, receiptsHave int32
+	tester.downloader.bodyFetchHook = func(headers []*types.Header) {
+		atomic.AddInt32(&bodiesHave, int32(len(headers)))
+	}
+	tester.downloader.receiptFetchHook = func(headers []*types.Header) {
+		atomic.AddInt32(&receiptsHave, int32(len(headers)))
+	}
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("peer", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chain.len())
+
+	// Validate the number of block bodies that should have been requested:
+	// none in light mode, otherwise one per non-genesis block carrying
+	// transactions or uncles
+	bodiesNeeded, receiptsNeeded := 0, 0
+	if mode != LightSync {
+		for _, block := range chain.blockm {
+			if block == tester.genesis {
+				continue
+			}
+			if len(block.Transactions()) > 0 || len(block.Uncles()) > 0 {
+				bodiesNeeded++
+			}
+		}
+	}
+	// Receipts are only expected in fast mode, one request per non-empty set
+	if mode == FastSync {
+		for _, receipt := range chain.receiptm {
+			if len(receipt) > 0 {
+				receiptsNeeded++
+			}
+		}
+	}
+	if int(bodiesHave) != bodiesNeeded {
+		t.Errorf("body retrieval count mismatch: have %v, want %v", bodiesHave, bodiesNeeded)
+	}
+	if int(receiptsHave) != receiptsNeeded {
+		t.Errorf("receipt retrieval count mismatch: have %v, want %v", receiptsHave, receiptsNeeded)
+	}
+}
+
+// Tests that headers are enqueued continuously, preventing malicious nodes from
+// stalling the downloader by feeding gapped header chains.
+func TestMissingHeaderAttack66Full(t *testing.T)  { testMissingHeaderAttack(t, eth.ETH66, FullSync) }
+func TestMissingHeaderAttack66Fast(t *testing.T)  { testMissingHeaderAttack(t, eth.ETH66, FastSync) }
+func TestMissingHeaderAttack66Light(t *testing.T) { testMissingHeaderAttack(t, eth.ETH66, LightSync) }
+
+// testMissingHeaderAttack first syncs against a peer whose chain lacks the
+// header in the middle, which must fail, and then against a peer serving the
+// intact chain, which must succeed.
+func testMissingHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+	// Derive the attacker's chain and remove the header halfway through
+	brokenChain := chain.shorten(chain.len())
+	gap := brokenChain.chain[brokenChain.len()/2]
+	delete(brokenChain.headerm, gap)
+	tester.newPeer("attack", protocol, brokenChain)
+
+	err := tester.sync("attack", nil, mode)
+	if err == nil {
+		t.Fatalf("succeeded attacker synchronisation")
+	}
+	// Synchronise with the valid peer and make sure sync succeeds
+	tester.newPeer("valid", protocol, chain)
+	err = tester.sync("valid", nil, mode)
+	if err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that if requested headers are shifted (i.e. first is missing), the queue
+// detects the invalid numbering.
+func TestShiftedHeaderAttack66Full(t *testing.T)  { testShiftedHeaderAttack(t, eth.ETH66, FullSync) }
+func TestShiftedHeaderAttack66Fast(t *testing.T)  { testShiftedHeaderAttack(t, eth.ETH66, FastSync) }
+func TestShiftedHeaderAttack66Light(t *testing.T) { testShiftedHeaderAttack(t, eth.ETH66, LightSync) }
+
+// testShiftedHeaderAttack first syncs against a peer whose chain lacks the
+// item at index 1 (header, block and receipts), which must fail, and then
+// against a peer serving the intact chain, which must succeed.
+func testShiftedHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+
+	chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+	// Attempt a full sync with an attacker feeding shifted headers
+	brokenChain := chain.shorten(chain.len())
+	delete(brokenChain.headerm, brokenChain.chain[1])
+	delete(brokenChain.blockm, brokenChain.chain[1])
+	delete(brokenChain.receiptm, brokenChain.chain[1])
+	tester.newPeer("attack", protocol, brokenChain)
+	if err := tester.sync("attack", nil, mode); err == nil {
+		t.Fatalf("succeeded attacker synchronisation")
+	}
+
+	// Synchronise with the valid peer and make sure sync succeeds
+	tester.newPeer("valid", protocol, chain)
+	if err := tester.sync("valid", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, chain.len())
+}
+
+// Tests that upon detecting an invalid header, the recent ones are rolled back
+// for various failure scenarios. Afterwards a full sync is attempted to make
+// sure no state was corrupted.
+func TestInvalidHeaderRollback66Fast(t *testing.T) { testInvalidHeaderRollback(t, eth.ETH66, FastSync) }
+
+// testInvalidHeaderRollback runs three failing syncs (junk during the fast
+// sync phase, junk during the block import phase, headers withheld after sync
+// start), checking after each that the local head stays within bounds, and
+// finishes with a sync against a valid peer.
+func testInvalidHeaderRollback(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	// Deferred so the tester is torn down on the t.Fatalf paths too; a trailing
+	// call is skipped when the test aborts early
+	defer tester.terminate()
+
+	// Create a small enough block chain to download
+	targetBlocks := 3*fsHeaderSafetyNet + 256 + fsMinFullBlocks
+	chain := testChainBase.shorten(targetBlocks)
+
+	// Attempt to sync with an attacker that feeds junk during the fast sync phase.
+	// This should result in the last fsHeaderSafetyNet headers being rolled back.
+	missing := fsHeaderSafetyNet + MaxHeaderFetch + 1
+	fastAttackChain := chain.shorten(chain.len())
+	delete(fastAttackChain.headerm, fastAttackChain.chain[missing])
+	tester.newPeer("fast-attack", protocol, fastAttackChain)
+
+	if err := tester.sync("fast-attack", nil, mode); err == nil {
+		t.Fatalf("succeeded fast attacker synchronisation")
+	}
+	if head := tester.CurrentHeader().Number.Int64(); int(head) > MaxHeaderFetch {
+		t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
+	}
+
+	// Attempt to sync with an attacker that feeds junk during the block import phase.
+	// This should result in both the last fsHeaderSafetyNet number of headers being
+	// rolled back, and also the pivot point being reverted to a non-block status.
+	missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
+	blockAttackChain := chain.shorten(chain.len())
+	delete(fastAttackChain.headerm, fastAttackChain.chain[missing]) // Make sure the fast-attacker doesn't fill in
+	delete(blockAttackChain.headerm, blockAttackChain.chain[missing])
+	tester.newPeer("block-attack", protocol, blockAttackChain)
+
+	if err := tester.sync("block-attack", nil, mode); err == nil {
+		t.Fatalf("succeeded block attacker synchronisation")
+	}
+	if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
+		t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
+	}
+	if mode == FastSync {
+		if head := tester.CurrentBlock().NumberU64(); head != 0 {
+			t.Errorf("fast sync pivot block #%d not rolled back", head)
+		}
+	}
+
+	// Attempt to sync with an attacker that withholds promised blocks after the
+	// fast sync pivot point. This could be a trial to leave the node with a bad
+	// but already imported pivot block.
+	withholdAttackChain := chain.shorten(chain.len())
+	tester.newPeer("withhold-attack", protocol, withholdAttackChain)
+	tester.downloader.syncInitHook = func(uint64, uint64) {
+		// Strip every header from `missing` onwards once the sync has started,
+		// then unhook so that later syncs are not affected
+		for i := missing; i < withholdAttackChain.len(); i++ {
+			delete(withholdAttackChain.headerm, withholdAttackChain.chain[i])
+		}
+		tester.downloader.syncInitHook = nil
+	}
+	if err := tester.sync("withhold-attack", nil, mode); err == nil {
+		t.Fatalf("succeeded withholding attacker synchronisation")
+	}
+	if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
+		t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
+	}
+	if mode == FastSync {
+		if head := tester.CurrentBlock().NumberU64(); head != 0 {
+			t.Errorf("fast sync pivot block #%d not rolled back", head)
+		}
+	}
+
+	// Synchronise with the valid peer and make sure sync succeeds. Since the last rollback
+	// should also disable fast syncing for this process, verify that we did a fresh full
+	// sync. Note, we can't assert anything about the receipts since we won't purge the
+	// database of them, hence we can't use assertOwnChain.
+	tester.newPeer("valid", protocol, chain)
+	if err := tester.sync("valid", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	if hs := len(tester.ownHeaders); hs != chain.len() {
+		t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, chain.len())
+	}
+	if mode != LightSync {
+		if bs := len(tester.ownBlocks); bs != chain.len() {
+			t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, chain.len())
+		}
+	}
+}
+
+// Tests that a peer advertising a high TD doesn't get to stall the downloader
+// afterwards by not sending any useful hashes.
+func TestHighTDStarvationAttack66Full(t *testing.T) {
+	testHighTDStarvationAttack(t, eth.ETH66, FullSync)
+}
+func TestHighTDStarvationAttack66Fast(t *testing.T) {
+	testHighTDStarvationAttack(t, eth.ETH66, FastSync)
+}
+func TestHighTDStarvationAttack66Light(t *testing.T) {
+	testHighTDStarvationAttack(t, eth.ETH66, LightSync)
+}
+
+// testHighTDStarvationAttack registers a peer serving a chain shortened to a
+// single entry, syncs against it while claiming a total difficulty of 1000000,
+// and expects the sync to fail with errStallingPeer.
+func testHighTDStarvationAttack(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	// Deferred so the tester is also torn down when t.Fatalf aborts the test.
+	defer tester.terminate()
+
+	chain := testChainBase.shorten(1)
+	tester.newPeer("attack", protocol, chain)
+	if err := tester.sync("attack", big.NewInt(1000000), mode); !errors.Is(err, errStallingPeer) {
+		t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer)
+	}
+}
+
+// Tests that misbehaving peers are disconnected, whilst behaving ones are not.
+func TestBlockHeaderAttackerDropping66(t *testing.T) { testBlockHeaderAttackerDropping(t, eth.ETH66) }
+
+// testBlockHeaderAttackerDropping mocks the synchronisation result with each
+// error in the table below and checks whether the originating peer is still
+// present in tester.peers afterwards.
+func testBlockHeaderAttackerDropping(t *testing.T, protocol uint) {
+	t.Parallel()
+
+	// Define the disconnection requirement for individual hash fetch errors
+	tests := []struct {
+		result error
+		drop   bool
+	}{
+		{nil, false},                        // Sync succeeded, all is well
+		{errBusy, false},                    // Sync is already in progress, no problem
+		{errUnknownPeer, false},             // Peer is unknown, was already dropped, don't double drop
+		{errBadPeer, true},                  // Peer was deemed bad for some reason, drop it
+		{errStallingPeer, true},             // Peer was detected to be stalling, drop it
+		{errUnsyncedPeer, true},             // Peer was detected to be unsynced, drop it
+		{errNoPeers, false},                 // No peers to download from, soft race, no issue
+		{errTimeout, true},                  // No hashes received in due time, drop the peer
+		{errEmptyHeaderSet, true},           // No headers were returned as a response, drop as it's a dead end
+		{errPeersUnavailable, true},         // Nobody had the advertised blocks, drop the advertiser
+		{errInvalidAncestor, true},          // Agreed upon ancestor is not acceptable, drop the chain rewriter
+		{errInvalidChain, true},             // Hash chain was detected as invalid, definitely drop
+		{errInvalidBody, false},             // A bad peer was detected, but not the sync origin
+		{errInvalidReceipt, false},          // A bad peer was detected, but not the sync origin
+		{errCancelContentProcessing, false}, // Synchronisation was canceled, origin may be innocent, don't drop
+	}
+	// Run the tests and check disconnection status
+	tester := newTester()
+	defer tester.terminate()
+	chain := testChainBase.shorten(1)
+
+	for i, tt := range tests {
+		// Register a new peer and ensure its presence
+		id := fmt.Sprintf("test %d", i)
+		if err := tester.newPeer(id, protocol, chain); err != nil {
+			t.Fatalf("test %d: failed to register new peer: %v", i, err)
+		}
+		if _, ok := tester.peers[id]; !ok {
+			t.Fatalf("test %d: registered peer not found", i)
+		}
+		// Simulate a synchronisation and check the required result
+		tester.downloader.synchroniseMock = func(string, common.Hash) error { return tt.result }
+
+		tester.downloader.Synchronise(id, tester.genesis.Hash(), big.NewInt(1000), FullSync)
+		// A peer that is no longer in tester.peers counts as dropped.
+		if _, ok := tester.peers[id]; !ok != tt.drop {
+			t.Errorf("test %d: peer drop mismatch for %v: have %v, want %v", i, tt.result, !ok, tt.drop)
+		}
+	}
+}
+
+// Tests that synchronisation progress (origin block number, current block number
+// and highest block number) is tracked and updated correctly.
+func TestSyncProgress66Full(t *testing.T)  { testSyncProgress(t, eth.ETH66, FullSync) }
+func TestSyncProgress66Fast(t *testing.T)  { testSyncProgress(t, eth.ETH66, FastSync) }
+func TestSyncProgress66Light(t *testing.T) { testSyncProgress(t, eth.ETH66, LightSync) }
+
+// testSyncProgress syncs against a peer holding half of the chain and then a
+// peer holding all of it. Each cycle is paused inside the sync init hook so
+// that Progress() can be compared with the expected values before and after.
+func testSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+	chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+	// Set a sync init hook to catch progress changes
+	starting := make(chan struct{})
+	progress := make(chan struct{})
+
+	// The hook signals on starting and then blocks until the test sends on progress.
+	tester.downloader.syncInitHook = func(origin, latest uint64) {
+		starting <- struct{}{}
+		<-progress
+	}
+	checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+	// Synchronise half the blocks and check initial progress
+	tester.newPeer("peer-half", protocol, chain.shorten(chain.len()/2))
+	pending := new(sync.WaitGroup)
+	pending.Add(1)
+
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("peer-half", nil, mode); err != nil {
+			panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+		}
+	}()
+	<-starting
+	checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+		HighestBlock: uint64(chain.len()/2 - 1),
+	})
+	progress <- struct{}{}
+	pending.Wait()
+
+	// Synchronise all the blocks and check continuation progress
+	tester.newPeer("peer-full", protocol, chain)
+	pending.Add(1)
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("peer-full", nil, mode); err != nil {
+			panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+		}
+	}()
+	<-starting
+	checkProgress(t, tester.downloader, "completing", ethereum.SyncProgress{
+		StartingBlock: uint64(chain.len()/2 - 1),
+		CurrentBlock:  uint64(chain.len()/2 - 1),
+		HighestBlock:  uint64(chain.len() - 1),
+	})
+
+	// Check final progress after successful sync
+	progress <- struct{}{}
+	pending.Wait()
+	checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+		StartingBlock: uint64(chain.len()/2 - 1),
+		CurrentBlock:  uint64(chain.len() - 1),
+		HighestBlock:  uint64(chain.len() - 1),
+	})
+}
+
+// checkProgress fails the test when the downloader's reported progress differs
+// from want. KnownStates and PulledStates are zeroed on both sides first, so
+// the state counters are excluded from the comparison. The stage string only
+// labels the failure message.
+func checkProgress(t *testing.T, d *Downloader, stage string, want ethereum.SyncProgress) {
+	// Mark this method as a helper to report errors at callsite, not in here
+	t.Helper()
+
+	p := d.Progress()
+	p.KnownStates, p.PulledStates = 0, 0
+	want.KnownStates, want.PulledStates = 0, 0
+	if p != want {
+		t.Fatalf("%s progress mismatch:\nhave %+v\nwant %+v", stage, p, want)
+	}
+}
+
+// Tests that synchronisation progress (origin block number and highest block
+// number) is tracked and updated correctly in case of a fork (or manual head
+// revertal).
+func TestForkedSyncProgress66Full(t *testing.T)  { testForkedSyncProgress(t, eth.ETH66, FullSync) }
+func TestForkedSyncProgress66Fast(t *testing.T)  { testForkedSyncProgress(t, eth.ETH66, FastSync) }
+func TestForkedSyncProgress66Light(t *testing.T) { testForkedSyncProgress(t, eth.ETH66, LightSync) }
+
+// testForkedSyncProgress syncs with a peer on fork A and then with a peer on
+// fork B, checking the reported progress at each pause of the sync init hook.
+func testForkedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+	chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHeaderFetch)
+	chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHeaderFetch)
+
+	// Set a sync init hook to catch progress changes
+	starting := make(chan struct{})
+	progress := make(chan struct{})
+
+	// The hook signals on starting and then blocks until the test sends on progress.
+	tester.downloader.syncInitHook = func(origin, latest uint64) {
+		starting <- struct{}{}
+		<-progress
+	}
+	checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+	// Synchronise with one of the forks and check progress
+	tester.newPeer("fork A", protocol, chainA)
+	pending := new(sync.WaitGroup)
+	pending.Add(1)
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("fork A", nil, mode); err != nil {
+			panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+		}
+	}()
+	<-starting
+
+	checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+		HighestBlock: uint64(chainA.len() - 1),
+	})
+	progress <- struct{}{}
+	pending.Wait()
+
+	// Simulate a successful sync above the fork
+	tester.downloader.syncStatsChainOrigin = tester.downloader.syncStatsChainHeight
+
+	// Synchronise with the second fork and check progress resets
+	tester.newPeer("fork B", protocol, chainB)
+	pending.Add(1)
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("fork B", nil, mode); err != nil {
+			panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+		}
+	}()
+	<-starting
+	checkProgress(t, tester.downloader, "forking", ethereum.SyncProgress{
+		StartingBlock: uint64(testChainBase.len()) - 1,
+		CurrentBlock:  uint64(chainA.len() - 1),
+		HighestBlock:  uint64(chainB.len() - 1),
+	})
+
+	// Check final progress after successful sync
+	progress <- struct{}{}
+	pending.Wait()
+	checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+		StartingBlock: uint64(testChainBase.len()) - 1,
+		CurrentBlock:  uint64(chainB.len() - 1),
+		HighestBlock:  uint64(chainB.len() - 1),
+	})
+}
+
+// Tests that if synchronisation is aborted due to some failure, then the progress
+// origin is not updated in the next sync cycle, as it should be considered the
+// continuation of the previous sync and not a new instance.
+func TestFailedSyncProgress66Full(t *testing.T)  { testFailedSyncProgress(t, eth.ETH66, FullSync) }
+func TestFailedSyncProgress66Fast(t *testing.T)  { testFailedSyncProgress(t, eth.ETH66, FastSync) }
+func TestFailedSyncProgress66Light(t *testing.T) { testFailedSyncProgress(t, eth.ETH66, LightSync) }
+
+// testFailedSyncProgress first syncs against a peer whose chain is missing the
+// header, block and receipt at its midpoint (the sync must fail), then against
+// a peer with the intact chain, checking progress at each hook pause.
+func testFailedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+	chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+	// Set a sync init hook to catch progress changes
+	starting := make(chan struct{})
+	progress := make(chan struct{})
+
+	// The hook signals on starting and then blocks until the test sends on progress.
+	tester.downloader.syncInitHook = func(origin, latest uint64) {
+		starting <- struct{}{}
+		<-progress
+	}
+	checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+	// Attempt a full sync with a faulty peer
+	brokenChain := chain.shorten(chain.len())
+	missing := brokenChain.len() / 2
+	delete(brokenChain.headerm, brokenChain.chain[missing])
+	delete(brokenChain.blockm, brokenChain.chain[missing])
+	delete(brokenChain.receiptm, brokenChain.chain[missing])
+	tester.newPeer("faulty", protocol, brokenChain)
+
+	pending := new(sync.WaitGroup)
+	pending.Add(1)
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("faulty", nil, mode); err == nil {
+			panic("succeeded faulty synchronisation")
+		}
+	}()
+	<-starting
+	checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+		HighestBlock: uint64(brokenChain.len() - 1),
+	})
+	progress <- struct{}{}
+	pending.Wait()
+	// Snapshot taken after the failed cycle; the next cycle must start from it.
+	afterFailedSync := tester.downloader.Progress()
+
+	// Synchronise with a good peer and check that the progress origin remains the same
+	// after a failure
+	tester.newPeer("valid", protocol, chain)
+	pending.Add(1)
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("valid", nil, mode); err != nil {
+			panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+		}
+	}()
+	<-starting
+	checkProgress(t, tester.downloader, "completing", afterFailedSync)
+
+	// Check final progress after successful sync
+	progress <- struct{}{}
+	pending.Wait()
+	checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+		CurrentBlock: uint64(chain.len() - 1),
+		HighestBlock: uint64(chain.len() - 1),
+	})
+}
+
+// Tests that if an attacker fakes a chain height, after the attack is detected,
+// the progress height is successfully reduced at the next sync invocation.
+func TestFakedSyncProgress66Full(t *testing.T)  { testFakedSyncProgress(t, eth.ETH66, FullSync) }
+func TestFakedSyncProgress66Fast(t *testing.T)  { testFakedSyncProgress(t, eth.ETH66, FastSync) }
+func TestFakedSyncProgress66Light(t *testing.T) { testFakedSyncProgress(t, eth.ETH66, LightSync) }
+
+// testFakedSyncProgress syncs against a peer whose chain lacks some headers
+// just below its head (the sync must fail), then against a peer with a shorter
+// intact chain, and checks that HighestBlock drops to the shorter chain's head.
+func testFakedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	tester := newTester()
+	defer tester.terminate()
+	chain := testChainBase.shorten(blockCacheMaxItems - 15)
+
+	// Set a sync init hook to catch progress changes
+	starting := make(chan struct{})
+	progress := make(chan struct{})
+	tester.downloader.syncInitHook = func(origin, latest uint64) {
+		starting <- struct{}{}
+		<-progress
+	}
+	checkProgress(t, tester.downloader, "pristine", ethereum.SyncProgress{})
+
+	// Create and sync with an attacker that promises a higher chain than available.
+	brokenChain := chain.shorten(chain.len())
+	numMissing := 5
+	// Removes the headers at indices len-2 down to len-numMissing+1; the head itself is kept.
+	for i := brokenChain.len() - 2; i > brokenChain.len()-numMissing; i-- {
+		delete(brokenChain.headerm, brokenChain.chain[i])
+	}
+	tester.newPeer("attack", protocol, brokenChain)
+
+	pending := new(sync.WaitGroup)
+	pending.Add(1)
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("attack", nil, mode); err == nil {
+			panic("succeeded attacker synchronisation")
+		}
+	}()
+	<-starting
+	checkProgress(t, tester.downloader, "initial", ethereum.SyncProgress{
+		HighestBlock: uint64(brokenChain.len() - 1),
+	})
+	progress <- struct{}{}
+	pending.Wait()
+	afterFailedSync := tester.downloader.Progress()
+
+	// Synchronise with a good peer and check that the progress height has been reduced to
+	// the true value.
+	validChain := chain.shorten(chain.len() - numMissing)
+	tester.newPeer("valid", protocol, validChain)
+	pending.Add(1)
+
+	go func() {
+		defer pending.Done()
+		if err := tester.sync("valid", nil, mode); err != nil {
+			panic(fmt.Sprintf("failed to synchronise blocks: %v", err))
+		}
+	}()
+	<-starting
+	checkProgress(t, tester.downloader, "completing", ethereum.SyncProgress{
+		CurrentBlock: afterFailedSync.CurrentBlock,
+		HighestBlock: uint64(validChain.len() - 1),
+	})
+
+	// Check final progress after successful sync.
+	progress <- struct{}{}
+	pending.Wait()
+	checkProgress(t, tester.downloader, "final", ethereum.SyncProgress{
+		CurrentBlock: uint64(validChain.len() - 1),
+		HighestBlock: uint64(validChain.len() - 1),
+	})
+}
+
+// This test reproduces an issue where unexpected deliveries would
+// block indefinitely if they arrived at the right time.
+func TestDeliverHeadersHang66Full(t *testing.T)  { testDeliverHeadersHang(t, eth.ETH66, FullSync) }
+func TestDeliverHeadersHang66Fast(t *testing.T)  { testDeliverHeadersHang(t, eth.ETH66, FastSync) }
+func TestDeliverHeadersHang66Light(t *testing.T) { testDeliverHeadersHang(t, eth.ETH66, LightSync) }
+
+// testDeliverHeadersHang runs 200 independent syncs of a 15-entry chain, each
+// against a peer wrapped in floodingTestPeer, and reports any sync that fails.
+func testDeliverHeadersHang(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	master := newTester()
+	defer master.terminate()
+	chain := testChainBase.shorten(15)
+
+	for i := 0; i < 200; i++ {
+		tester := newTester()
+		// Every iteration reuses the master tester's peer database.
+		tester.peerDb = master.peerDb
+		tester.newPeer("peer", protocol, chain)
+
+		// Whenever the downloader requests headers, flood it with
+		// a lot of unrequested header deliveries.
+		tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{
+			peer:   tester.downloader.peers.peers["peer"].peer,
+			tester: tester,
+		}
+		if err := tester.sync("peer", nil, mode); err != nil {
+			t.Errorf("test %d: sync failed: %v", i, err)
+		}
+		tester.terminate()
+	}
+}
+
+// floodingTestPeer wraps a Peer and forwards every request to it unchanged,
+// except RequestHeadersByNumber, which first floods the tester's downloader
+// with unrequested header deliveries.
+type floodingTestPeer struct {
+	peer   Peer
+	tester *downloadTester
+}
+
+func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() }
+func (ftp *floodingTestPeer) RequestHeadersByHash(hash common.Hash, count int, skip int, reverse bool) error {
+	return ftp.peer.RequestHeadersByHash(hash, count, skip, reverse)
+}
+func (ftp *floodingTestPeer) RequestBodies(hashes []common.Hash) error {
+	return ftp.peer.RequestBodies(hashes)
+}
+func (ftp *floodingTestPeer) RequestReceipts(hashes []common.Hash) error {
+	return ftp.peer.RequestReceipts(hashes)
+}
+func (ftp *floodingTestPeer) RequestNodeData(hashes []common.Hash) error {
+	return ftp.peer.RequestNodeData(hashes)
+}
+
+// RequestHeadersByNumber starts 499 goroutines that each deliver four empty
+// headers under a fake peer id, and forwards the real request to the wrapped
+// peer once the first of them has completed. It then waits for all 500
+// completions and panics if they have not all arrived within 60 seconds.
+// It always returns nil; the wrapped peer's error is discarded.
+func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int, reverse bool) error {
+	deliveriesDone := make(chan struct{}, 500)
+	for i := 0; i < cap(deliveriesDone)-1; i++ {
+		peer := fmt.Sprintf("fake-peer%d", i)
+		go func() {
+			ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}})
+			deliveriesDone <- struct{}{}
+		}()
+	}
+
+	// None of the extra deliveries should block.
+	timeout := time.After(60 * time.Second)
+	launched := false
+	for i := 0; i < cap(deliveriesDone); i++ {
+		select {
+		case <-deliveriesDone:
+			if !launched {
+				// Start delivering the requested headers
+				// after one of the flooding responses has arrived.
+				go func() {
+					ftp.peer.RequestHeadersByNumber(from, count, skip, reverse)
+					deliveriesDone <- struct{}{}
+				}()
+				launched = true
+			}
+		case <-timeout:
+			panic("blocked")
+		}
+	}
+	return nil
+}
+
+// TestRemoteHeaderRequestSpan checks calculateRequestSpan against a table of
+// remote/local heights. The block numbers implied by the returned from, count
+// and span are expanded and compared with the expected list, and the returned
+// max must equal the last expanded number.
+func TestRemoteHeaderRequestSpan(t *testing.T) {
+	testCases := []struct {
+		remoteHeight uint64
+		localHeight  uint64
+		expected     []int
+	}{
+		// Remote is way higher. We should ask for the remote head and go backwards
+		{1500, 1000,
+			[]int{1323, 1339, 1355, 1371, 1387, 1403, 1419, 1435, 1451, 1467, 1483, 1499},
+		},
+		{15000, 13006,
+			[]int{14823, 14839, 14855, 14871, 14887, 14903, 14919, 14935, 14951, 14967, 14983, 14999},
+		},
+		// Remote is pretty close to us. We don't have to fetch as many
+		{1200, 1150,
+			[]int{1149, 1154, 1159, 1164, 1169, 1174, 1179, 1184, 1189, 1194, 1199},
+		},
+		// Remote is equal to us (so on a fork with higher td)
+		// We should get the closest couple of ancestors
+		{1500, 1500,
+			[]int{1497, 1499},
+		},
+		// We're higher than the remote! Odd
+		{1000, 1500,
+			[]int{997, 999},
+		},
+		// Check some weird edgecases that it behaves somewhat rationally
+		{0, 1500,
+			[]int{0, 2},
+		},
+		{6000000, 0,
+			[]int{5999823, 5999839, 5999855, 5999871, 5999887, 5999903, 5999919, 5999935, 5999951, 5999967, 5999983, 5999999},
+		},
+		{0, 0,
+			[]int{0, 2},
+		},
+	}
+	// reqs expands (from, count, span) into the concrete list of requested numbers.
+	reqs := func(from, count, span int) []int {
+		var r []int
+		num := from
+		for len(r) < count {
+			r = append(r, num)
+			num += span + 1
+		}
+		return r
+	}
+	for i, tt := range testCases {
+		from, count, span, max := calculateRequestSpan(tt.remoteHeight, tt.localHeight)
+		data := reqs(int(from), count, span)
+
+		if max != uint64(data[len(data)-1]) {
+			t.Errorf("test %d: wrong last value %d != %d", i, data[len(data)-1], max)
+		}
+		failed := false
+		if len(data) != len(tt.expected) {
+			failed = true
+			t.Errorf("test %d: length wrong, expected %d got %d", i, len(tt.expected), len(data))
+		} else {
+			for j, n := range data {
+				if n != tt.expected[j] {
+					failed = true
+					break
+				}
+			}
+		}
+		if failed {
+			res := strings.Replace(fmt.Sprint(data), " ", ",", -1)
+			exp := strings.Replace(fmt.Sprint(tt.expected), " ", ",", -1)
+			t.Logf("got: %v\n", res)
+			t.Logf("exp: %v\n", exp)
+			t.Errorf("test %d: wrong values", i)
+		}
+	}
+}
+
+// Tests that peers below a pre-configured checkpoint block are prevented from
+// being fast-synced from, avoiding potential cheap eclipse attacks.
+func TestCheckpointEnforcement66Full(t *testing.T) { testCheckpointEnforcement(t, eth.ETH66, FullSync) }
+func TestCheckpointEnforcement66Fast(t *testing.T) { testCheckpointEnforcement(t, eth.ETH66, FastSync) }
+func TestCheckpointEnforcement66Light(t *testing.T) {
+	testCheckpointEnforcement(t, eth.ETH66, LightSync)
+}
+
+// testCheckpointEnforcement sets a checkpoint on the downloader and syncs with
+// a peer whose chain is shortened to one entry below it. Fast and light sync
+// must fail with errUnsyncedPeer and leave only a single-entry local chain;
+// full sync must succeed and import the whole peer chain.
+func testCheckpointEnforcement(t *testing.T, protocol uint, mode SyncMode) {
+	t.Parallel()
+
+	// Create a new tester with a particular hard coded checkpoint block
+	tester := newTester()
+	defer tester.terminate()
+
+	tester.downloader.checkpoint = uint64(fsMinFullBlocks) + 256
+	chain := testChainBase.shorten(int(tester.downloader.checkpoint) - 1)
+
+	// Attempt to sync with the peer and validate the result
+	tester.newPeer("peer", protocol, chain)
+
+	// expect stays nil for full sync, i.e. the sync is required to succeed.
+	var expect error
+	if mode == FastSync || mode == LightSync {
+		expect = errUnsyncedPeer
+	}
+	if err := tester.sync("peer", nil, mode); !errors.Is(err, expect) {
+		t.Fatalf("block sync error mismatch: have %v, want %v", err, expect)
+	}
+	if mode == FastSync || mode == LightSync {
+		assertOwnChain(t, tester, 1)
+	} else {
+		assertOwnChain(t, tester, chain.len())
+	}
+}
diff --git a/les/downloader/events.go b/les/downloader/events.go
new file mode 100644
index 0000000000000..25255a3a72e5f
--- /dev/null
+++ b/les/downloader/events.go
@@ -0,0 +1,25 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package downloader
+
+import "github.com/ethereum/go-ethereum/core/types"
+
+// DoneEvent carries a block header in its Latest field.
+// NOTE(review): presumably posted when a sync cycle finishes successfully,
+// with Latest being the resulting chain head - confirm against the emitter.
+type DoneEvent struct {
+	Latest *types.Header
+}
+
+// StartEvent carries no data.
+// NOTE(review): presumably posted when a sync cycle begins - confirm against the emitter.
+type StartEvent struct{}
+
+// FailedEvent carries an error in its Err field.
+// NOTE(review): presumably posted when a sync cycle aborts, with Err being the
+// cause - confirm against the emitter.
+type FailedEvent struct{ Err error }
diff --git a/les/downloader/metrics.go b/les/downloader/metrics.go
new file mode 100644
index 0000000000000..c38732043aa20
--- /dev/null
+++ b/les/downloader/metrics.go
@@ -0,0 +1,45 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Contains the metrics collected by the downloader.
+
+package downloader
+
+import (
+	"github.com/ethereum/go-ethereum/metrics"
+)
+
+// Meters, timers and counters for the downloader, grouped per data type
+// (headers, bodies, receipts, states) plus a throttle counter.
+//
+// NOTE(review): every metric is registered under the "eth/downloader/" prefix
+// although this file lives in les/downloader. Confirm that sharing names with
+// the eth downloader is intended for this duplicated package.
+var (
+	headerInMeter      = metrics.NewRegisteredMeter("eth/downloader/headers/in", nil)
+	headerReqTimer     = metrics.NewRegisteredTimer("eth/downloader/headers/req", nil)
+	headerDropMeter    = metrics.NewRegisteredMeter("eth/downloader/headers/drop", nil)
+	headerTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/headers/timeout", nil)
+
+	bodyInMeter      = metrics.NewRegisteredMeter("eth/downloader/bodies/in", nil)
+	bodyReqTimer     = metrics.NewRegisteredTimer("eth/downloader/bodies/req", nil)
+	bodyDropMeter    = metrics.NewRegisteredMeter("eth/downloader/bodies/drop", nil)
+	bodyTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/bodies/timeout", nil)
+
+	receiptInMeter      = metrics.NewRegisteredMeter("eth/downloader/receipts/in", nil)
+	receiptReqTimer     = metrics.NewRegisteredTimer("eth/downloader/receipts/req", nil)
+	receiptDropMeter    = metrics.NewRegisteredMeter("eth/downloader/receipts/drop", nil)
+	receiptTimeoutMeter = metrics.NewRegisteredMeter("eth/downloader/receipts/timeout", nil)
+
+	stateInMeter   = metrics.NewRegisteredMeter("eth/downloader/states/in", nil)
+	stateDropMeter = metrics.NewRegisteredMeter("eth/downloader/states/drop", nil)
+
+	throttleCounter = metrics.NewRegisteredCounter("eth/downloader/throttle", nil)
+)
diff --git a/les/downloader/modes.go b/les/downloader/modes.go
new file mode 100644
index 0000000000000..3ea14d22d7e09
--- /dev/null
+++ b/les/downloader/modes.go
@@ -0,0 +1,81 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package downloader
+
+import "fmt"
+
+// SyncMode represents the synchronisation mode of the downloader.
+// It is a uint32 as it is used with atomic operations.
+type SyncMode uint32
+
+const (
+	FullSync  SyncMode = iota // Synchronise the entire blockchain history from full blocks
+	FastSync                  // Quickly download the headers, full sync only at the chain head
+	SnapSync                  // Download the chain and the state via compact snapshots
+	LightSync                 // Download only the headers and terminate afterwards
+)
+
+// IsValid reports whether mode is one of the declared sync modes, i.e. lies
+// in the range FullSync..LightSync.
+func (mode SyncMode) IsValid() bool {
+	return mode >= FullSync && mode <= LightSync
+}
+
+// String implements the stringer interface.
+func (mode SyncMode) String() string {
+	switch mode {
+	case FullSync:
+		return "full"
+	case FastSync:
+		return "fast"
+	case SnapSync:
+		return "snap"
+	case LightSync:
+		return "light"
+	default:
+		return "unknown"
+	}
+}
+
+// MarshalText returns the textual name of the mode ("full", "fast", "snap" or
+// "light"). Unlike String, it returns an error for an undeclared mode instead
+// of "unknown".
+func (mode SyncMode) MarshalText() ([]byte, error) {
+	switch mode {
+	case FullSync:
+		return []byte("full"), nil
+	case FastSync:
+		return []byte("fast"), nil
+	case SnapSync:
+		return []byte("snap"), nil
+	case LightSync:
+		return []byte("light"), nil
+	default:
+		return nil, fmt.Errorf("unknown sync mode %d", mode)
+	}
+}
+
+// UnmarshalText sets the mode from its textual name. On an unrecognised name
+// it returns an error listing every accepted value and leaves *mode unchanged.
+func (mode *SyncMode) UnmarshalText(text []byte) error {
+	switch string(text) {
+	case "full":
+		*mode = FullSync
+	case "fast":
+		*mode = FastSync
+	case "snap":
+		*mode = SnapSync
+	case "light":
+		*mode = LightSync
+	default:
+		// The hint must name every value accepted above, "snap" included.
+		return fmt.Errorf(`unknown sync mode %q, want "full", "fast", "snap" or "light"`, text)
+	}
+	return nil
+}
diff --git a/les/downloader/peer.go b/les/downloader/peer.go
new file mode 100644
index 0000000000000..8632948329711
--- /dev/null
+++ b/les/downloader/peer.go
@@ -0,0 +1,501 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Contains the active peer-set of the downloader, maintaining both failures
+// as well as reputation metrics to prioritize the block retrievals.
+
+package downloader
+
+import (
+	"errors"
+	"math/big"
+	"sort"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/eth/protocols/eth"
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/p2p/msgrate"
+)
+
+const (
+	maxLackingHashes = 4096 // Maximum number of entries allowed on the list of lacking items
+)
+
+var (
+	errAlreadyFetching   = errors.New("already fetching blocks from peer")
+	errAlreadyRegistered = errors.New("peer is already registered")
+	errNotRegistered     = errors.New("peer is not registered")
+)
+
+// peerConnection represents an active peer from which hashes and blocks are retrieved.
+//
+// The four *Idle fields are accessed through sync/atomic; 0 means idle and
+// 1 means a request of that kind is in flight.
+type peerConnection struct {
+	id string // Unique identifier of the peer
+
+	headerIdle  int32 // Current header activity state of the peer (idle = 0, active = 1)
+	blockIdle   int32 // Current block activity state of the peer (idle = 0, active = 1)
+	receiptIdle int32 // Current receipt activity state of the peer (idle = 0, active = 1)
+	stateIdle   int32 // Current node data activity state of the peer (idle = 0, active = 1)
+
+	headerStarted  time.Time // Time instance when the last header fetch was started
+	blockStarted   time.Time // Time instance when the last block (body) fetch was started
+	receiptStarted time.Time // Time instance when the last receipt fetch was started
+	stateStarted   time.Time // Time instance when the last node data fetch was started
+
+	rates   *msgrate.Tracker         // Tracker to hone in on the number of items retrievable per second
+	lacking map[common.Hash]struct{} // Set of hashes not to request (didn't have previously)
+
+	peer Peer
+
+	version uint       // Eth protocol version number to switch strategies
+	log     log.Logger // Contextual logger to add extra infos to peer logs
+	lock    sync.RWMutex
+}
+
+// LightPeer encapsulates the methods required to synchronise with a remote light peer.
+type LightPeer interface {
+	Head() (common.Hash, *big.Int)
+	RequestHeadersByHash(common.Hash, int, int, bool) error
+	RequestHeadersByNumber(uint64, int, int, bool) error
+}
+
+// Peer encapsulates the methods required to synchronise with a remote full peer.
+type Peer interface {
+	LightPeer
+	RequestBodies([]common.Hash) error
+	RequestReceipts([]common.Hash) error
+	RequestNodeData([]common.Hash) error
+}
+
+// lightPeerWrapper wraps a LightPeer struct, stubbing out the Peer-only methods.
+// The LightPeer methods are forwarded to the wrapped peer; the three Peer-only
+// methods panic when called.
+type lightPeerWrapper struct {
+	peer LightPeer
+}
+
+func (w *lightPeerWrapper) Head() (common.Hash, *big.Int) { return w.peer.Head() }
+func (w *lightPeerWrapper) RequestHeadersByHash(h common.Hash, amount int, skip int, reverse bool) error {
+	return w.peer.RequestHeadersByHash(h, amount, skip, reverse)
+}
+func (w *lightPeerWrapper) RequestHeadersByNumber(i uint64, amount int, skip int, reverse bool) error {
+	return w.peer.RequestHeadersByNumber(i, amount, skip, reverse)
+}
+func (w *lightPeerWrapper) RequestBodies([]common.Hash) error {
+	panic("RequestBodies not supported in light client mode sync")
+}
+func (w *lightPeerWrapper) RequestReceipts([]common.Hash) error {
+	panic("RequestReceipts not supported in light client mode sync")
+}
+func (w *lightPeerWrapper) RequestNodeData([]common.Hash) error {
+	panic("RequestNodeData not supported in light client mode sync")
+}
+
+// newPeerConnection creates a new downloader peer.
+//
+// The lacking set starts out empty. The rates tracker is not set here.
+// NOTE(review): rates must be assigned before any Set*Idle or *Capacity method
+// is used - confirm where registration does this.
+func newPeerConnection(id string, version uint, peer Peer, logger log.Logger) *peerConnection {
+	return &peerConnection{
+		id:      id,
+		lacking: make(map[common.Hash]struct{}),
+		peer:    peer,
+		version: version,
+		log:     logger,
+	}
+}
+
+// Reset clears the internal state of a peer entity.
+func (p *peerConnection) Reset() {
+	p.lock.Lock()
+	defer p.lock.Unlock()
+
+	// Mark all four request kinds as idle (0) again
+	atomic.StoreInt32(&p.headerIdle, 0)
+	atomic.StoreInt32(&p.blockIdle, 0)
+	atomic.StoreInt32(&p.receiptIdle, 0)
+	atomic.StoreInt32(&p.stateIdle, 0)
+
+	// Forget everything the peer was recorded as lacking
+	p.lacking = make(map[common.Hash]struct{})
+}
+
+// FetchHeaders sends a header retrieval request to the remote peer.
+//
+// It returns errAlreadyFetching if a header request is already in flight.
+// Otherwise it records the start time and issues the request on a separate
+// goroutine; the request's own error is not reported back to the caller.
+func (p *peerConnection) FetchHeaders(from uint64, count int) error {
+	// Short circuit if the peer is already fetching
+	if !atomic.CompareAndSwapInt32(&p.headerIdle, 0, 1) {
+		return errAlreadyFetching
+	}
+	p.headerStarted = time.Now()
+
+	// Issue the header retrieval request (absolute upwards without gaps)
+	go p.peer.RequestHeadersByNumber(from, count, 0, false)
+
+	return nil
+}
+
+// FetchBodies sends a block body retrieval request to the remote peer.
+//
+// It returns errAlreadyFetching if a body request is already in flight.
+// Otherwise the hashes of request.Headers are collected and requested on a
+// separate goroutine; the request's own error is not reported back.
+func (p *peerConnection) FetchBodies(request *fetchRequest) error {
+	// Short circuit if the peer is already fetching
+	if !atomic.CompareAndSwapInt32(&p.blockIdle, 0, 1) {
+		return errAlreadyFetching
+	}
+	p.blockStarted = time.Now()
+
+	go func() {
+		// Convert the header set to a retrievable slice
+		hashes := make([]common.Hash, 0, len(request.Headers))
+		for _, header := range request.Headers {
+			hashes = append(hashes, header.Hash())
+		}
+		p.peer.RequestBodies(hashes)
+	}()
+
+	return nil
+}
+
+// FetchReceipts sends a receipt retrieval request to the remote peer.
+//
+// It returns errAlreadyFetching if a receipt request is already in flight.
+// Otherwise the hashes of request.Headers are collected and requested on a
+// separate goroutine; the request's own error is not reported back.
+func (p *peerConnection) FetchReceipts(request *fetchRequest) error {
+	// Short circuit if the peer is already fetching
+	if !atomic.CompareAndSwapInt32(&p.receiptIdle, 0, 1) {
+		return errAlreadyFetching
+	}
+	p.receiptStarted = time.Now()
+
+	go func() {
+		// Convert the header set to a retrievable slice
+		hashes := make([]common.Hash, 0, len(request.Headers))
+		for _, header := range request.Headers {
+			hashes = append(hashes, header.Hash())
+		}
+		p.peer.RequestReceipts(hashes)
+	}()
+
+	return nil
+}
+
+// FetchNodeData sends a node state data retrieval request to the remote peer.
+func (p *peerConnection) FetchNodeData(hashes []common.Hash) error { + // Short circuit if the peer is already fetching + if !atomic.CompareAndSwapInt32(&p.stateIdle, 0, 1) { + return errAlreadyFetching + } + p.stateStarted = time.Now() + + go p.peer.RequestNodeData(hashes) + + return nil +} + +// SetHeadersIdle sets the peer to idle, allowing it to execute new header retrieval +// requests. Its estimated header retrieval throughput is updated with that measured +// just now. +func (p *peerConnection) SetHeadersIdle(delivered int, deliveryTime time.Time) { + p.rates.Update(eth.BlockHeadersMsg, deliveryTime.Sub(p.headerStarted), delivered) + atomic.StoreInt32(&p.headerIdle, 0) +} + +// SetBodiesIdle sets the peer to idle, allowing it to execute block body retrieval +// requests. Its estimated body retrieval throughput is updated with that measured +// just now. +func (p *peerConnection) SetBodiesIdle(delivered int, deliveryTime time.Time) { + p.rates.Update(eth.BlockBodiesMsg, deliveryTime.Sub(p.blockStarted), delivered) + atomic.StoreInt32(&p.blockIdle, 0) +} + +// SetReceiptsIdle sets the peer to idle, allowing it to execute new receipt +// retrieval requests. Its estimated receipt retrieval throughput is updated +// with that measured just now. +func (p *peerConnection) SetReceiptsIdle(delivered int, deliveryTime time.Time) { + p.rates.Update(eth.ReceiptsMsg, deliveryTime.Sub(p.receiptStarted), delivered) + atomic.StoreInt32(&p.receiptIdle, 0) +} + +// SetNodeDataIdle sets the peer to idle, allowing it to execute new state trie +// data retrieval requests. Its estimated state retrieval throughput is updated +// with that measured just now. +func (p *peerConnection) SetNodeDataIdle(delivered int, deliveryTime time.Time) { + p.rates.Update(eth.NodeDataMsg, deliveryTime.Sub(p.stateStarted), delivered) + atomic.StoreInt32(&p.stateIdle, 0) +} + +// HeaderCapacity retrieves the peers header download allowance based on its +// previously discovered throughput. 
+func (p *peerConnection) HeaderCapacity(targetRTT time.Duration) int { + cap := p.rates.Capacity(eth.BlockHeadersMsg, targetRTT) + if cap > MaxHeaderFetch { + cap = MaxHeaderFetch + } + return cap +} + +// BlockCapacity retrieves the peers block download allowance based on its +// previously discovered throughput. +func (p *peerConnection) BlockCapacity(targetRTT time.Duration) int { + cap := p.rates.Capacity(eth.BlockBodiesMsg, targetRTT) + if cap > MaxBlockFetch { + cap = MaxBlockFetch + } + return cap +} + +// ReceiptCapacity retrieves the peers receipt download allowance based on its +// previously discovered throughput. +func (p *peerConnection) ReceiptCapacity(targetRTT time.Duration) int { + cap := p.rates.Capacity(eth.ReceiptsMsg, targetRTT) + if cap > MaxReceiptFetch { + cap = MaxReceiptFetch + } + return cap +} + +// NodeDataCapacity retrieves the peers state download allowance based on its +// previously discovered throughput. +func (p *peerConnection) NodeDataCapacity(targetRTT time.Duration) int { + cap := p.rates.Capacity(eth.NodeDataMsg, targetRTT) + if cap > MaxStateFetch { + cap = MaxStateFetch + } + return cap +} + +// MarkLacking appends a new entity to the set of items (blocks, receipts, states) +// that a peer is known not to have (i.e. have been requested before). If the +// set reaches its maximum allowed capacity, items are randomly dropped off. +func (p *peerConnection) MarkLacking(hash common.Hash) { + p.lock.Lock() + defer p.lock.Unlock() + + for len(p.lacking) >= maxLackingHashes { + for drop := range p.lacking { + delete(p.lacking, drop) + break + } + } + p.lacking[hash] = struct{}{} +} + +// Lacks retrieves whether the hash of a blockchain item is on the peers lacking +// list (i.e. whether we know that the peer does not have it). 
+func (p *peerConnection) Lacks(hash common.Hash) bool { + p.lock.RLock() + defer p.lock.RUnlock() + + _, ok := p.lacking[hash] + return ok +} + +// peerSet represents the collection of active peer participating in the chain +// download procedure. +type peerSet struct { + peers map[string]*peerConnection + rates *msgrate.Trackers // Set of rate trackers to give the sync a common beat + + newPeerFeed event.Feed + peerDropFeed event.Feed + + lock sync.RWMutex +} + +// newPeerSet creates a new peer set top track the active download sources. +func newPeerSet() *peerSet { + return &peerSet{ + peers: make(map[string]*peerConnection), + rates: msgrate.NewTrackers(log.New("proto", "eth")), + } +} + +// SubscribeNewPeers subscribes to peer arrival events. +func (ps *peerSet) SubscribeNewPeers(ch chan<- *peerConnection) event.Subscription { + return ps.newPeerFeed.Subscribe(ch) +} + +// SubscribePeerDrops subscribes to peer departure events. +func (ps *peerSet) SubscribePeerDrops(ch chan<- *peerConnection) event.Subscription { + return ps.peerDropFeed.Subscribe(ch) +} + +// Reset iterates over the current peer set, and resets each of the known peers +// to prepare for a next batch of block retrieval. +func (ps *peerSet) Reset() { + ps.lock.RLock() + defer ps.lock.RUnlock() + + for _, peer := range ps.peers { + peer.Reset() + } +} + +// Register injects a new peer into the working set, or returns an error if the +// peer is already known. +// +// The method also sets the starting throughput values of the new peer to the +// average of all existing peers, to give it a realistic chance of being used +// for data retrievals. 
+func (ps *peerSet) Register(p *peerConnection) error { + // Register the new peer with some meaningful defaults + ps.lock.Lock() + if _, ok := ps.peers[p.id]; ok { + ps.lock.Unlock() + return errAlreadyRegistered + } + p.rates = msgrate.NewTracker(ps.rates.MeanCapacities(), ps.rates.MedianRoundTrip()) + if err := ps.rates.Track(p.id, p.rates); err != nil { + return err + } + ps.peers[p.id] = p + ps.lock.Unlock() + + ps.newPeerFeed.Send(p) + return nil +} + +// Unregister removes a remote peer from the active set, disabling any further +// actions to/from that particular entity. +func (ps *peerSet) Unregister(id string) error { + ps.lock.Lock() + p, ok := ps.peers[id] + if !ok { + ps.lock.Unlock() + return errNotRegistered + } + delete(ps.peers, id) + ps.rates.Untrack(id) + ps.lock.Unlock() + + ps.peerDropFeed.Send(p) + return nil +} + +// Peer retrieves the registered peer with the given id. +func (ps *peerSet) Peer(id string) *peerConnection { + ps.lock.RLock() + defer ps.lock.RUnlock() + + return ps.peers[id] +} + +// Len returns if the current number of peers in the set. +func (ps *peerSet) Len() int { + ps.lock.RLock() + defer ps.lock.RUnlock() + + return len(ps.peers) +} + +// AllPeers retrieves a flat list of all the peers within the set. +func (ps *peerSet) AllPeers() []*peerConnection { + ps.lock.RLock() + defer ps.lock.RUnlock() + + list := make([]*peerConnection, 0, len(ps.peers)) + for _, p := range ps.peers { + list = append(list, p) + } + return list +} + +// HeaderIdlePeers retrieves a flat list of all the currently header-idle peers +// within the active peer set, ordered by their reputation. 
+func (ps *peerSet) HeaderIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { + return atomic.LoadInt32(&p.headerIdle) == 0 + } + throughput := func(p *peerConnection) int { + return p.rates.Capacity(eth.BlockHeadersMsg, time.Second) + } + return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput) +} + +// BodyIdlePeers retrieves a flat list of all the currently body-idle peers within +// the active peer set, ordered by their reputation. +func (ps *peerSet) BodyIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { + return atomic.LoadInt32(&p.blockIdle) == 0 + } + throughput := func(p *peerConnection) int { + return p.rates.Capacity(eth.BlockBodiesMsg, time.Second) + } + return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput) +} + +// ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers +// within the active peer set, ordered by their reputation. +func (ps *peerSet) ReceiptIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { + return atomic.LoadInt32(&p.receiptIdle) == 0 + } + throughput := func(p *peerConnection) int { + return p.rates.Capacity(eth.ReceiptsMsg, time.Second) + } + return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput) +} + +// NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle +// peers within the active peer set, ordered by their reputation. +func (ps *peerSet) NodeDataIdlePeers() ([]*peerConnection, int) { + idle := func(p *peerConnection) bool { + return atomic.LoadInt32(&p.stateIdle) == 0 + } + throughput := func(p *peerConnection) int { + return p.rates.Capacity(eth.NodeDataMsg, time.Second) + } + return ps.idlePeers(eth.ETH66, eth.ETH66, idle, throughput) +} + +// idlePeers retrieves a flat list of all currently idle peers satisfying the +// protocol version constraints, using the provided function to check idleness. +// The resulting set of peers are sorted by their capacity. 
func (ps *peerSet) idlePeers(minProtocol, maxProtocol uint, idleCheck func(*peerConnection) bool, capacity func(*peerConnection) int) ([]*peerConnection, int) {
	ps.lock.RLock()
	defer ps.lock.RUnlock()

	var (
		total = 0                                         // Peers within the version range, idle or not
		idle  = make([]*peerConnection, 0, len(ps.peers)) // Idle peers within the version range
		tps   = make([]int, 0, len(ps.peers))             // Capacity of each idle peer, parallel to idle
	)
	for _, p := range ps.peers {
		if p.version >= minProtocol && p.version <= maxProtocol {
			if idleCheck(p) {
				idle = append(idle, p)
				tps = append(tps, capacity(p))
			}
			// Counted regardless of idleness, so total can exceed len(idle)
			total++
		}
	}

	// And sort them (highest capacity first, see peerCapacitySort.Less)
	sortPeers := &peerCapacitySort{idle, tps}
	sort.Sort(sortPeers)
	return sortPeers.p, total
}

// peerCapacitySort implements sort.Interface.
// It sorts peer connections by capacity (descending).
type peerCapacitySort struct {
	p  []*peerConnection // Peers being sorted
	tp []int             // Capacity of the peer at the same index in p
}

// Len returns the number of peers being sorted.
func (ps *peerCapacitySort) Len() int {
	return len(ps.p)
}

// Less reports whether peer i has a strictly larger capacity than peer j,
// which is what makes the resulting order descending.
func (ps *peerCapacitySort) Less(i, j int) bool {
	return ps.tp[i] > ps.tp[j]
}

// Swap exchanges both the peers and their capacities, keeping the two slices
// aligned index by index.
func (ps *peerCapacitySort) Swap(i, j int) {
	ps.p[i], ps.p[j] = ps.p[j], ps.p[i]
	ps.tp[i], ps.tp[j] = ps.tp[j], ps.tp[i]
}
diff --git a/les/downloader/queue.go b/les/downloader/queue.go new file mode 100644 index 0000000000000..04ec12cfa9e7c --- /dev/null +++ b/les/downloader/queue.go @@ -0,0 +1,913 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details.
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// Contains the block download scheduler to collect download tasks and schedule +// them in an ordered, and throttled way. + +package downloader + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/prque" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/trie" +) + +const ( + bodyType = uint(0) + receiptType = uint(1) +) + +var ( + blockCacheMaxItems = 8192 // Maximum number of blocks to cache before throttling the download + blockCacheInitialItems = 2048 // Initial number of blocks to start fetching, before we know the sizes of the blocks + blockCacheMemory = 256 * 1024 * 1024 // Maximum amount of memory to use for block caching + blockCacheSizeWeight = 0.1 // Multiplier to approximate the average block size based on past ones +) + +var ( + errNoFetchesPending = errors.New("no fetches pending") + errStaleDelivery = errors.New("stale delivery") +) + +// fetchRequest is a currently running data retrieval operation. +type fetchRequest struct { + Peer *peerConnection // Peer to which the request was sent + From uint64 // [eth/62] Requested chain element index (used for skeleton fills only) + Headers []*types.Header // [eth/62] Requested headers, sorted by request order + Time time.Time // Time when the request was made +} + +// fetchResult is a struct collecting partial results from data fetchers until +// all outstanding pieces complete and the result as a whole can be processed.
+type fetchResult struct { + pending int32 // Flag telling what deliveries are outstanding + + Header *types.Header + Uncles []*types.Header + Transactions types.Transactions + Receipts types.Receipts +} + +func newFetchResult(header *types.Header, fastSync bool) *fetchResult { + item := &fetchResult{ + Header: header, + } + if !header.EmptyBody() { + item.pending |= (1 << bodyType) + } + if fastSync && !header.EmptyReceipts() { + item.pending |= (1 << receiptType) + } + return item +} + +// SetBodyDone flags the body as finished. +func (f *fetchResult) SetBodyDone() { + if v := atomic.LoadInt32(&f.pending); (v & (1 << bodyType)) != 0 { + atomic.AddInt32(&f.pending, -1) + } +} + +// AllDone checks if item is done. +func (f *fetchResult) AllDone() bool { + return atomic.LoadInt32(&f.pending) == 0 +} + +// SetReceiptsDone flags the receipts as finished. +func (f *fetchResult) SetReceiptsDone() { + if v := atomic.LoadInt32(&f.pending); (v & (1 << receiptType)) != 0 { + atomic.AddInt32(&f.pending, -2) + } +} + +// Done checks if the given type is done already +func (f *fetchResult) Done(kind uint) bool { + v := atomic.LoadInt32(&f.pending) + return v&(1< 0 +} + +// InFlightBlocks retrieves whether there are block fetch requests currently in +// flight. +func (q *queue) InFlightBlocks() bool { + q.lock.Lock() + defer q.lock.Unlock() + + return len(q.blockPendPool) > 0 +} + +// InFlightReceipts retrieves whether there are receipt fetch requests currently +// in flight. +func (q *queue) InFlightReceipts() bool { + q.lock.Lock() + defer q.lock.Unlock() + + return len(q.receiptPendPool) > 0 +} + +// Idle returns if the queue is fully idle or has some data still inside. 
+func (q *queue) Idle() bool { + q.lock.Lock() + defer q.lock.Unlock() + + queued := q.blockTaskQueue.Size() + q.receiptTaskQueue.Size() + pending := len(q.blockPendPool) + len(q.receiptPendPool) + + return (queued + pending) == 0 +} + +// ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill +// up an already retrieved header skeleton. +func (q *queue) ScheduleSkeleton(from uint64, skeleton []*types.Header) { + q.lock.Lock() + defer q.lock.Unlock() + + // No skeleton retrieval can be in progress, fail hard if so (huge implementation bug) + if q.headerResults != nil { + panic("skeleton assembly already in progress") + } + // Schedule all the header retrieval tasks for the skeleton assembly + q.headerTaskPool = make(map[uint64]*types.Header) + q.headerTaskQueue = prque.New(nil) + q.headerPeerMiss = make(map[string]map[uint64]struct{}) // Reset availability to correct invalid chains + q.headerResults = make([]*types.Header, len(skeleton)*MaxHeaderFetch) + q.headerProced = 0 + q.headerOffset = from + q.headerContCh = make(chan bool, 1) + + for i, header := range skeleton { + index := from + uint64(i*MaxHeaderFetch) + + q.headerTaskPool[index] = header + q.headerTaskQueue.Push(index, -int64(index)) + } +} + +// RetrieveHeaders retrieves the header chain assemble based on the scheduled +// skeleton. +func (q *queue) RetrieveHeaders() ([]*types.Header, int) { + q.lock.Lock() + defer q.lock.Unlock() + + headers, proced := q.headerResults, q.headerProced + q.headerResults, q.headerProced = nil, 0 + + return headers, proced +} + +// Schedule adds a set of headers for the download queue for scheduling, returning +// the new headers encountered. 
+func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header { + q.lock.Lock() + defer q.lock.Unlock() + + // Insert all the headers prioritised by the contained block number + inserts := make([]*types.Header, 0, len(headers)) + for _, header := range headers { + // Make sure chain order is honoured and preserved throughout + hash := header.Hash() + if header.Number == nil || header.Number.Uint64() != from { + log.Warn("Header broke chain ordering", "number", header.Number, "hash", hash, "expected", from) + break + } + if q.headerHead != (common.Hash{}) && q.headerHead != header.ParentHash { + log.Warn("Header broke chain ancestry", "number", header.Number, "hash", hash) + break + } + // Make sure no duplicate requests are executed + // We cannot skip this, even if the block is empty, since this is + // what triggers the fetchResult creation. + if _, ok := q.blockTaskPool[hash]; ok { + log.Warn("Header already scheduled for block fetch", "number", header.Number, "hash", hash) + } else { + q.blockTaskPool[hash] = header + q.blockTaskQueue.Push(header, -int64(header.Number.Uint64())) + } + // Queue for receipt retrieval + if q.mode == FastSync && !header.EmptyReceipts() { + if _, ok := q.receiptTaskPool[hash]; ok { + log.Warn("Header already scheduled for receipt fetch", "number", header.Number, "hash", hash) + } else { + q.receiptTaskPool[hash] = header + q.receiptTaskQueue.Push(header, -int64(header.Number.Uint64())) + } + } + inserts = append(inserts, header) + q.headerHead = hash + from++ + } + return inserts +} + +// Results retrieves and permanently removes a batch of fetch results from +// the cache. the result slice will be empty if the queue has been closed. 
// Results can be called concurrently with Deliver and Schedule,
// but assumes that there are not two simultaneous callers to Results
//
// If block is false and nothing is completed, nil is returned right away.
// Otherwise the call waits until the result cache reports completed items or
// the queue is marked closed, and then returns up to maxResultsProcess items.
func (q *queue) Results(block bool) []*fetchResult {
	// Abort early if there are no items and non-blocking requested
	if !block && !q.resultCache.HasCompletedItems() {
		return nil
	}
	closed := false
	for !closed && !q.resultCache.HasCompletedItems() {
		// In order to wait on 'active', we need to obtain the lock.
		// That may take a while, if someone is delivering at the same
		// time, so after obtaining the lock, we check again if there
		// are any results to fetch.
		// Also, in-between we ask for the lock and the lock is obtained,
		// someone can have closed the queue. In that case, we should
		// return the available results and stop blocking
		q.lock.Lock()
		if q.resultCache.HasCompletedItems() || q.closed {
			q.lock.Unlock()
			break
		}
		// No items available, and not closed
		q.active.Wait()
		closed = q.closed
		q.lock.Unlock()
	}
	// Regardless if closed or not, we can still deliver whatever we have
	results := q.resultCache.GetCompleted(maxResultsProcess)
	for _, result := range results {
		// Recalculate the result item weights to prevent memory exhaustion
		size := result.Header.Size()
		for _, uncle := range result.Uncles {
			size += uncle.Size()
		}
		for _, receipt := range result.Receipts {
			size += receipt.Size()
		}
		for _, tx := range result.Transactions {
			size += tx.Size()
		}
		// Weighted moving average: the new item contributes blockCacheSizeWeight,
		// the previous estimate keeps the remainder.
		// NOTE(review): q.resultSize is written here without q.lock held, while
		// Stats reads it under the read lock - confirm this is intended.
		q.resultSize = common.StorageSize(blockCacheSizeWeight)*size +
			(1-common.StorageSize(blockCacheSizeWeight))*q.resultSize
	}
	// Using the newly calibrated resultsize, figure out the new throttle limit
	// on the result cache
	// NOTE(review): this divides by q.resultSize; presumably it is non-zero by
	// the time this runs - confirm for the "closed with no results" path.
	throttleThreshold := uint64((common.StorageSize(blockCacheMemory) + q.resultSize - 1) / q.resultSize)
	throttleThreshold = q.resultCache.SetThrottleThreshold(throttleThreshold)

	// Log some info at certain times
	if time.Since(q.lastStatLog) > 60*time.Second {
		q.lastStatLog = time.Now()
		info := q.Stats()
		info = append(info, "throttle", throttleThreshold)
		log.Info("Downloader queue stats", info...)
	}
	return results
}

// Stats returns the queue statistics as a flat list of alternating keys and
// values, gathered under the queue's read lock.
func (q *queue) Stats() []interface{} {
	q.lock.RLock()
	defer q.lock.RUnlock()

	return q.stats()
}

// stats builds the key/value list reported by Stats. It does no locking of
// its own.
func (q *queue) stats() []interface{} {
	return []interface{}{
		"receiptTasks", q.receiptTaskQueue.Size(),
		"blockTasks", q.blockTaskQueue.Size(),
		"itemSize", q.resultSize,
	}
}

// ReserveHeaders reserves a set of headers for the given peer, skipping any
// previously failed batches.
//
// It returns nil if the peer already has a header request pending or if no
// suitable batch could be found. The count parameter is not used by the body.
func (q *queue) ReserveHeaders(p *peerConnection, count int) *fetchRequest {
	q.lock.Lock()
	defer q.lock.Unlock()

	// Short circuit if the peer's already downloading something (sanity check to
	// not corrupt state)
	if _, ok := q.headerPendPool[p.id]; ok {
		return nil
	}
	// Retrieve a batch of hashes, skipping previously failed ones
	// (send doubles as the "nothing found yet" marker while it is zero)
	send, skip := uint64(0), []uint64{}
	for send == 0 && !q.headerTaskQueue.Empty() {
		from, _ := q.headerTaskQueue.Pop()
		if q.headerPeerMiss[p.id] != nil {
			if _, ok := q.headerPeerMiss[p.id][from.(uint64)]; ok {
				skip = append(skip, from.(uint64))
				continue
			}
		}
		send = from.(uint64)
	}
	// Merge all the skipped batches back
	for _, from := range skip {
		q.headerTaskQueue.Push(from, -int64(from))
	}
	// Assemble and return the block download request
	if send == 0 {
		return nil
	}
	request := &fetchRequest{
		Peer: p,
		From: send,
		Time: time.Now(),
	}
	q.headerPendPool[p.id] = request
	return request
}

// ReserveBodies reserves a set of body fetches for the given peer, skipping any
// previously failed downloads. Beside the next batch of needed fetches, it also
// returns a flag whether empty blocks were queued requiring processing.
+func (q *queue) ReserveBodies(p *peerConnection, count int) (*fetchRequest, bool, bool) { + q.lock.Lock() + defer q.lock.Unlock() + + return q.reserveHeaders(p, count, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, bodyType) +} + +// ReserveReceipts reserves a set of receipt fetches for the given peer, skipping +// any previously failed downloads. Beside the next batch of needed fetches, it +// also returns a flag whether empty receipts were queued requiring importing. +func (q *queue) ReserveReceipts(p *peerConnection, count int) (*fetchRequest, bool, bool) { + q.lock.Lock() + defer q.lock.Unlock() + + return q.reserveHeaders(p, count, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool, receiptType) +} + +// reserveHeaders reserves a set of data download operations for a given peer, +// skipping any previously failed ones. This method is a generic version used +// by the individual special reservation functions. +// +// Note, this method expects the queue lock to be already held for writing. The +// reason the lock is not obtained in here is because the parameters already need +// to access the queue, so they already need a lock anyway. 
//
// Returns:
//   item     - the fetchRequest
//   progress - whether any progress was made
//   throttle - if the caller should throttle for a while
func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common.Hash]*types.Header, taskQueue *prque.Prque,
	pendPool map[string]*fetchRequest, kind uint) (*fetchRequest, bool, bool) {
	// Short circuit if the pool has been depleted, or if the peer's already
	// downloading something (sanity check not to corrupt state)
	if taskQueue.Empty() {
		return nil, false, true
	}
	if _, ok := pendPool[p.id]; ok {
		return nil, false, false
	}
	// Retrieve a batch of tasks, skipping previously failed ones
	send := make([]*types.Header, 0, count)
	skip := make([]*types.Header, 0)
	progress := false
	throttled := false
	// proc is only adjusted below and never read; the loop is bounded by
	// len(send) and the task queue alone.
	for proc := 0; len(send) < count && !taskQueue.Empty(); proc++ {
		// the task queue will pop items in order, so the highest prio block
		// is also the lowest block number.
		h, _ := taskQueue.Peek()
		header := h.(*types.Header)
		// we can ask the resultcache if this header is within the
		// "prioritized" segment of blocks. If it is not, we need to throttle

		stale, throttle, item, err := q.resultCache.AddFetch(header, q.mode == FastSync)
		if stale {
			// Don't put back in the task queue, this item has already been
			// delivered upstream
			taskQueue.PopItem()
			progress = true
			delete(taskPool, header.Hash())
			proc = proc - 1
			log.Error("Fetch reservation already delivered", "number", header.Number.Uint64())
			continue
		}
		if throttle {
			// There are no resultslots available. Leave it in the task queue
			// However, if there are any left as 'skipped', we should not tell
			// the caller to throttle, since we still want some other
			// peer to fetch those for us
			throttled = len(skip) == 0
			break
		}
		if err != nil {
			// this most definitely should _not_ happen
			log.Warn("Failed to reserve headers", "err", err)
			// The header was only peeked, so it stays in the task queue; stop
			// reserving for this round
			break
		}
		if item.Done(kind) {
			// If it's a noop, we can skip this task
			delete(taskPool, header.Hash())
			taskQueue.PopItem()
			proc = proc - 1
			progress = true
			continue
		}
		// Remove it from the task queue
		taskQueue.PopItem()
		// Otherwise unless the peer is known not to have the data, add to the retrieve list
		if p.Lacks(header.Hash()) {
			skip = append(skip, header)
		} else {
			send = append(send, header)
		}
	}
	// Merge all the skipped headers back
	for _, header := range skip {
		taskQueue.Push(header, -int64(header.Number.Uint64()))
	}
	if q.resultCache.HasCompletedItems() {
		// Wake Results, resultCache was modified
		q.active.Signal()
	}
	// Assemble and return the block download request
	if len(send) == 0 {
		return nil, progress, throttled
	}
	request := &fetchRequest{
		Peer:    p,
		Headers: send,
		Time:    time.Now(),
	}
	pendPool[p.id] = request
	return request, progress, throttled
}

// CancelHeaders aborts a fetch request, returning all pending skeleton indexes to the queue.
func (q *queue) CancelHeaders(request *fetchRequest) {
	q.lock.Lock()
	defer q.lock.Unlock()
	q.cancel(request, q.headerTaskQueue, q.headerPendPool)
}

// CancelBodies aborts a body fetch request, returning all pending headers to the
// task queue.
func (q *queue) CancelBodies(request *fetchRequest) {
	q.lock.Lock()
	defer q.lock.Unlock()
	q.cancel(request, q.blockTaskQueue, q.blockPendPool)
}

// CancelReceipts aborts a receipt fetch request, returning all pending headers to
// the task queue.
func (q *queue) CancelReceipts(request *fetchRequest) {
	q.lock.Lock()
	defer q.lock.Unlock()
	q.cancel(request, q.receiptTaskQueue, q.receiptPendPool)
}

// cancel aborts a fetch request, returning all pending hashes to the task queue.
+func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool map[string]*fetchRequest) { + if request.From > 0 { + taskQueue.Push(request.From, -int64(request.From)) + } + for _, header := range request.Headers { + taskQueue.Push(header, -int64(header.Number.Uint64())) + } + delete(pendPool, request.Peer.id) +} + +// Revoke cancels all pending requests belonging to a given peer. This method is +// meant to be called during a peer drop to quickly reassign owned data fetches +// to remaining nodes. +func (q *queue) Revoke(peerID string) { + q.lock.Lock() + defer q.lock.Unlock() + + if request, ok := q.blockPendPool[peerID]; ok { + for _, header := range request.Headers { + q.blockTaskQueue.Push(header, -int64(header.Number.Uint64())) + } + delete(q.blockPendPool, peerID) + } + if request, ok := q.receiptPendPool[peerID]; ok { + for _, header := range request.Headers { + q.receiptTaskQueue.Push(header, -int64(header.Number.Uint64())) + } + delete(q.receiptPendPool, peerID) + } +} + +// ExpireHeaders checks for in flight requests that exceeded a timeout allowance, +// canceling them and returning the responsible peers for penalisation. +func (q *queue) ExpireHeaders(timeout time.Duration) map[string]int { + q.lock.Lock() + defer q.lock.Unlock() + + return q.expire(timeout, q.headerPendPool, q.headerTaskQueue, headerTimeoutMeter) +} + +// ExpireBodies checks for in flight block body requests that exceeded a timeout +// allowance, canceling them and returning the responsible peers for penalisation. +func (q *queue) ExpireBodies(timeout time.Duration) map[string]int { + q.lock.Lock() + defer q.lock.Unlock() + + return q.expire(timeout, q.blockPendPool, q.blockTaskQueue, bodyTimeoutMeter) +} + +// ExpireReceipts checks for in flight receipt requests that exceeded a timeout +// allowance, canceling them and returning the responsible peers for penalisation. 
+func (q *queue) ExpireReceipts(timeout time.Duration) map[string]int { + q.lock.Lock() + defer q.lock.Unlock() + + return q.expire(timeout, q.receiptPendPool, q.receiptTaskQueue, receiptTimeoutMeter) +} + +// expire is the generic check that move expired tasks from a pending pool back +// into a task pool, returning all entities caught with expired tasks. +// +// Note, this method expects the queue lock to be already held. The +// reason the lock is not obtained in here is because the parameters already need +// to access the queue, so they already need a lock anyway. +func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, taskQueue *prque.Prque, timeoutMeter metrics.Meter) map[string]int { + // Iterate over the expired requests and return each to the queue + expiries := make(map[string]int) + for id, request := range pendPool { + if time.Since(request.Time) > timeout { + // Update the metrics with the timeout + timeoutMeter.Mark(1) + + // Return any non satisfied requests to the pool + if request.From > 0 { + taskQueue.Push(request.From, -int64(request.From)) + } + for _, header := range request.Headers { + taskQueue.Push(header, -int64(header.Number.Uint64())) + } + // Add the peer to the expiry report along the number of failed requests + expiries[id] = len(request.Headers) + + // Remove the expired requests from the pending pool directly + delete(pendPool, id) + } + } + return expiries +} + +// DeliverHeaders injects a header retrieval response into the header results +// cache. This method either accepts all headers it received, or none of them +// if they do not map correctly to the skeleton. +// +// If the headers are accepted, the method makes an attempt to deliver the set +// of ready headers to the processor to keep the pipeline full. However it will +// not block to prevent stalling other pending deliveries. 
func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh chan []*types.Header) (int, error) {
	q.lock.Lock()
	defer q.lock.Unlock()

	var logger log.Logger
	if len(id) < 16 {
		// Tests use short IDs, don't choke on them
		logger = log.New("peer", id)
	} else {
		logger = log.New("peer", id[:16])
	}
	// Short circuit if the data was never requested
	request := q.headerPendPool[id]
	if request == nil {
		return 0, errNoFetchesPending
	}
	headerReqTimer.UpdateSince(request.Time)
	delete(q.headerPendPool, id)

	// Ensure headers can be mapped onto the skeleton chain
	// NOTE(review): this dereferences the task pool entry for request.From;
	// presumably it is always present while the request is pending - confirm.
	target := q.headerTaskPool[request.From].Hash()

	// A batch is only considered if it is exactly MaxHeaderFetch long, starts
	// at the requested number and ends on the skeleton header
	accepted := len(headers) == MaxHeaderFetch
	if accepted {
		if headers[0].Number.Uint64() != request.From {
			logger.Trace("First header broke chain ordering", "number", headers[0].Number, "hash", headers[0].Hash(), "expected", request.From)
			accepted = false
		} else if headers[len(headers)-1].Hash() != target {
			logger.Trace("Last header broke skeleton structure ", "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target)
			accepted = false
		}
	}
	if accepted {
		// Verify that the remaining headers are consecutive and linked by
		// parent hash
		parentHash := headers[0].Hash()
		for i, header := range headers[1:] {
			hash := header.Hash()
			if want := request.From + 1 + uint64(i); header.Number.Uint64() != want {
				logger.Warn("Header broke chain ordering", "number", header.Number, "hash", hash, "expected", want)
				accepted = false
				break
			}
			if parentHash != header.ParentHash {
				logger.Warn("Header broke chain ancestry", "number", header.Number, "hash", hash)
				accepted = false
				break
			}
			// Set-up parent hash for next round
			parentHash = hash
		}
	}
	// If the batch of headers wasn't accepted, mark as unavailable
	if !accepted {
		logger.Trace("Skeleton filling not accepted", "from", request.From)

		// Remember that this peer failed this batch, so ReserveHeaders can
		// skip it for the same peer
		miss := q.headerPeerMiss[id]
		if miss == nil {
			q.headerPeerMiss[id] = make(map[uint64]struct{})
			miss = q.headerPeerMiss[id]
		}
		miss[request.From] = struct{}{}

		// Put the batch back so that it can be reserved again
		q.headerTaskQueue.Push(request.From, -int64(request.From))
		return 0, errors.New("delivery not accepted")
	}
	// Clean up a successful fetch and try to deliver any sub-results
	copy(q.headerResults[request.From-q.headerOffset:], headers)
	delete(q.headerTaskPool, request.From)

	// Count how many contiguous, already filled batches follow the processed
	// mark; only the first slot of each batch is inspected
	ready := 0
	for q.headerProced+ready < len(q.headerResults) && q.headerResults[q.headerProced+ready] != nil {
		ready += MaxHeaderFetch
	}
	if ready > 0 {
		// Headers are ready for delivery, gather them and push forward (non blocking)
		process := make([]*types.Header, ready)
		copy(process, q.headerResults[q.headerProced:q.headerProced+ready])

		select {
		case headerProcCh <- process:
			logger.Trace("Pre-scheduled new headers", "count", len(process), "from", process[0].Number)
			q.headerProced += len(process)
		default:
			// Receiver not ready; the headers stay in headerResults
		}
	}
	// Check for termination and return
	if len(q.headerTaskPool) == 0 {
		q.headerContCh <- false
	}
	return len(headers), nil
}

// DeliverBodies injects a block body retrieval response into the results queue.
// The method returns the number of blocks bodies accepted from the delivery and
// also wakes any threads waiting for data delivery.
+func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLists [][]*types.Header) (int, error) { + q.lock.Lock() + defer q.lock.Unlock() + trieHasher := trie.NewStackTrie(nil) + validate := func(index int, header *types.Header) error { + if types.DeriveSha(types.Transactions(txLists[index]), trieHasher) != header.TxHash { + return errInvalidBody + } + if types.CalcUncleHash(uncleLists[index]) != header.UncleHash { + return errInvalidBody + } + return nil + } + + reconstruct := func(index int, result *fetchResult) { + result.Transactions = txLists[index] + result.Uncles = uncleLists[index] + result.SetBodyDone() + } + return q.deliver(id, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, + bodyReqTimer, len(txLists), validate, reconstruct) +} + +// DeliverReceipts injects a receipt retrieval response into the results queue. +// The method returns the number of transaction receipts accepted from the delivery +// and also wakes any threads waiting for data delivery. +func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int, error) { + q.lock.Lock() + defer q.lock.Unlock() + trieHasher := trie.NewStackTrie(nil) + validate := func(index int, header *types.Header) error { + if types.DeriveSha(types.Receipts(receiptList[index]), trieHasher) != header.ReceiptHash { + return errInvalidReceipt + } + return nil + } + reconstruct := func(index int, result *fetchResult) { + result.Receipts = receiptList[index] + result.SetReceiptsDone() + } + return q.deliver(id, q.receiptTaskPool, q.receiptTaskQueue, q.receiptPendPool, + receiptReqTimer, len(receiptList), validate, reconstruct) +} + +// deliver injects a data retrieval response into the results queue. +// +// Note, this method expects the queue lock to be already held for writing. The +// reason this lock is not obtained in here is because the parameters already need +// to access the queue, so they already need a lock anyway. 
+func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, + taskQueue *prque.Prque, pendPool map[string]*fetchRequest, reqTimer metrics.Timer, + results int, validate func(index int, header *types.Header) error, + reconstruct func(index int, result *fetchResult)) (int, error) { + + // Short circuit if the data was never requested + request := pendPool[id] + if request == nil { + return 0, errNoFetchesPending + } + reqTimer.UpdateSince(request.Time) + delete(pendPool, id) + + // If no data items were retrieved, mark them as unavailable for the origin peer + if results == 0 { + for _, header := range request.Headers { + request.Peer.MarkLacking(header.Hash()) + } + } + // Assemble each of the results with their headers and retrieved data parts + var ( + accepted int + failure error + i int + hashes []common.Hash + ) + for _, header := range request.Headers { + // Short circuit assembly if no more fetch results are found + if i >= results { + break + } + // Validate the fields + if err := validate(i, header); err != nil { + failure = err + break + } + hashes = append(hashes, header.Hash()) + i++ + } + + for _, header := range request.Headers[:i] { + if res, stale, err := q.resultCache.GetDeliverySlot(header.Number.Uint64()); err == nil { + reconstruct(accepted, res) + } else { + // else: betweeen here and above, some other peer filled this result, + // or it was indeed a no-op. 
This should not happen, but if it does it's + // not something to panic about + log.Error("Delivery stale", "stale", stale, "number", header.Number.Uint64(), "err", err) + failure = errStaleDelivery + } + // Clean up a successful fetch + delete(taskPool, hashes[accepted]) + accepted++ + } + // Return all failed or missing fetches to the queue + for _, header := range request.Headers[accepted:] { + taskQueue.Push(header, -int64(header.Number.Uint64())) + } + // Wake up Results + if accepted > 0 { + q.active.Signal() + } + if failure == nil { + return accepted, nil + } + // If none of the data was good, it's a stale delivery + if accepted > 0 { + return accepted, fmt.Errorf("partial failure: %v", failure) + } + return accepted, fmt.Errorf("%w: %v", failure, errStaleDelivery) +} + +// Prepare configures the result cache to allow accepting and caching inbound +// fetch results. +func (q *queue) Prepare(offset uint64, mode SyncMode) { + q.lock.Lock() + defer q.lock.Unlock() + + // Prepare the queue for sync results + q.resultCache.Prepare(offset) + q.mode = mode +} diff --git a/les/downloader/queue_test.go b/les/downloader/queue_test.go new file mode 100644 index 0000000000000..cde5f306a2c07 --- /dev/null +++ b/les/downloader/queue_test.go @@ -0,0 +1,452 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package downloader + +import ( + "fmt" + "math/big" + "math/rand" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/params" +) + +var ( + testdb = rawdb.NewMemoryDatabase() + genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000000000)) +) + +// makeChain creates a chain of n blocks starting at and including parent. +// the returned hash chain is ordered head->parent. In addition, every 3rd block +// contains a transaction and every 5th an uncle to allow testing correct block +// reassembly. +func makeChain(n int, seed byte, parent *types.Block, empty bool) ([]*types.Block, []types.Receipts) { + blocks, receipts := core.GenerateChain(params.TestChainConfig, parent, ethash.NewFaker(), testdb, n, func(i int, block *core.BlockGen) { + block.SetCoinbase(common.Address{seed}) + // Add one tx to every secondblock + if !empty && i%2 == 0 { + signer := types.MakeSigner(params.TestChainConfig, block.Number()) + tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, block.BaseFee(), nil), signer, testKey) + if err != nil { + panic(err) + } + block.AddTx(tx) + } + }) + return blocks, receipts +} + +type chainData struct { + blocks []*types.Block + offset int +} + +var chain *chainData +var emptyChain *chainData + +func init() { + // Create a chain of blocks to import + targetBlocks := 128 + blocks, _ := makeChain(targetBlocks, 0, genesis, false) + chain = &chainData{blocks, 0} + + blocks, _ = makeChain(targetBlocks, 0, genesis, true) + emptyChain = 
&chainData{blocks, 0} +} + +func (chain *chainData) headers() []*types.Header { + hdrs := make([]*types.Header, len(chain.blocks)) + for i, b := range chain.blocks { + hdrs[i] = b.Header() + } + return hdrs +} + +func (chain *chainData) Len() int { + return len(chain.blocks) +} + +func dummyPeer(id string) *peerConnection { + p := &peerConnection{ + id: id, + lacking: make(map[common.Hash]struct{}), + } + return p +} + +func TestBasics(t *testing.T) { + numOfBlocks := len(emptyChain.blocks) + numOfReceipts := len(emptyChain.blocks) / 2 + + q := newQueue(10, 10) + if !q.Idle() { + t.Errorf("new queue should be idle") + } + q.Prepare(1, FastSync) + if res := q.Results(false); len(res) != 0 { + t.Fatal("new queue should have 0 results") + } + + // Schedule a batch of headers + q.Schedule(chain.headers(), 1) + if q.Idle() { + t.Errorf("queue should not be idle") + } + if got, exp := q.PendingBlocks(), chain.Len(); got != exp { + t.Errorf("wrong pending block count, got %d, exp %d", got, exp) + } + // Only non-empty receipts get added to task-queue + if got, exp := q.PendingReceipts(), 64; got != exp { + t.Errorf("wrong pending receipt count, got %d, exp %d", got, exp) + } + // Items are now queued for downloading, next step is that we tell the + // queue that a certain peer will deliver them for us + { + peer := dummyPeer("peer-1") + fetchReq, _, throttle := q.ReserveBodies(peer, 50) + if !throttle { + // queue size is only 10, so throttling should occur + t.Fatal("should throttle") + } + // But we should still get the first things to fetch + if got, exp := len(fetchReq.Headers), 5; got != exp { + t.Fatalf("expected %d requests, got %d", exp, got) + } + if got, exp := fetchReq.Headers[0].Number.Uint64(), uint64(1); got != exp { + t.Fatalf("expected header %d, got %d", exp, got) + } + } + if exp, got := q.blockTaskQueue.Size(), numOfBlocks-10; exp != got { + t.Errorf("expected block task queue to be %d, got %d", exp, got) + } + if exp, got := q.receiptTaskQueue.Size(), 
numOfReceipts; exp != got { + t.Errorf("expected receipt task queue to be %d, got %d", exp, got) + } + { + peer := dummyPeer("peer-2") + fetchReq, _, throttle := q.ReserveBodies(peer, 50) + + // The second peer should hit throttling + if !throttle { + t.Fatalf("should not throttle") + } + // And not get any fetches at all, since it was throttled to begin with + if fetchReq != nil { + t.Fatalf("should have no fetches, got %d", len(fetchReq.Headers)) + } + } + if exp, got := q.blockTaskQueue.Size(), numOfBlocks-10; exp != got { + t.Errorf("expected block task queue to be %d, got %d", exp, got) + } + if exp, got := q.receiptTaskQueue.Size(), numOfReceipts; exp != got { + t.Errorf("expected receipt task queue to be %d, got %d", exp, got) + } + { + // The receipt delivering peer should not be affected + // by the throttling of body deliveries + peer := dummyPeer("peer-3") + fetchReq, _, throttle := q.ReserveReceipts(peer, 50) + if !throttle { + // queue size is only 10, so throttling should occur + t.Fatal("should throttle") + } + // But we should still get the first things to fetch + if got, exp := len(fetchReq.Headers), 5; got != exp { + t.Fatalf("expected %d requests, got %d", exp, got) + } + if got, exp := fetchReq.Headers[0].Number.Uint64(), uint64(1); got != exp { + t.Fatalf("expected header %d, got %d", exp, got) + } + + } + if exp, got := q.blockTaskQueue.Size(), numOfBlocks-10; exp != got { + t.Errorf("expected block task queue to be %d, got %d", exp, got) + } + if exp, got := q.receiptTaskQueue.Size(), numOfReceipts-5; exp != got { + t.Errorf("expected receipt task queue to be %d, got %d", exp, got) + } + if got, exp := q.resultCache.countCompleted(), 0; got != exp { + t.Errorf("wrong processable count, got %d, exp %d", got, exp) + } +} + +func TestEmptyBlocks(t *testing.T) { + numOfBlocks := len(emptyChain.blocks) + + q := newQueue(10, 10) + + q.Prepare(1, FastSync) + // Schedule a batch of headers + q.Schedule(emptyChain.headers(), 1) + if q.Idle() { + 
t.Errorf("queue should not be idle") + } + if got, exp := q.PendingBlocks(), len(emptyChain.blocks); got != exp { + t.Errorf("wrong pending block count, got %d, exp %d", got, exp) + } + if got, exp := q.PendingReceipts(), 0; got != exp { + t.Errorf("wrong pending receipt count, got %d, exp %d", got, exp) + } + // They won't be processable, because the fetchresults haven't been + // created yet + if got, exp := q.resultCache.countCompleted(), 0; got != exp { + t.Errorf("wrong processable count, got %d, exp %d", got, exp) + } + + // Items are now queued for downloading, next step is that we tell the + // queue that a certain peer will deliver them for us + // That should trigger all of them to suddenly become 'done' + { + // Reserve blocks + peer := dummyPeer("peer-1") + fetchReq, _, _ := q.ReserveBodies(peer, 50) + + // there should be nothing to fetch, blocks are empty + if fetchReq != nil { + t.Fatal("there should be no body fetch tasks remaining") + } + + } + if q.blockTaskQueue.Size() != numOfBlocks-10 { + t.Errorf("expected block task queue to be %d, got %d", numOfBlocks-10, q.blockTaskQueue.Size()) + } + if q.receiptTaskQueue.Size() != 0 { + t.Errorf("expected receipt task queue to be %d, got %d", 0, q.receiptTaskQueue.Size()) + } + { + peer := dummyPeer("peer-3") + fetchReq, _, _ := q.ReserveReceipts(peer, 50) + + // there should be nothing to fetch, blocks are empty + if fetchReq != nil { + t.Fatal("there should be no body fetch tasks remaining") + } + } + if q.blockTaskQueue.Size() != numOfBlocks-10 { + t.Errorf("expected block task queue to be %d, got %d", numOfBlocks-10, q.blockTaskQueue.Size()) + } + if q.receiptTaskQueue.Size() != 0 { + t.Errorf("expected receipt task queue to be %d, got %d", 0, q.receiptTaskQueue.Size()) + } + if got, exp := q.resultCache.countCompleted(), 10; got != exp { + t.Errorf("wrong processable count, got %d, exp %d", got, exp) + } +} + +// XTestDelivery does some more extensive testing of events that happen, +// blocks that 
become known and peers that make reservations and deliveries. +// disabled since it's not really a unit-test, but can be executed to test +// some more advanced scenarios +func XTestDelivery(t *testing.T) { + // the outside network, holding blocks + blo, rec := makeChain(128, 0, genesis, false) + world := newNetwork() + world.receipts = rec + world.chain = blo + world.progress(10) + if false { + log.Root().SetHandler(log.StdoutHandler) + + } + q := newQueue(10, 10) + var wg sync.WaitGroup + q.Prepare(1, FastSync) + wg.Add(1) + go func() { + // deliver headers + defer wg.Done() + c := 1 + for { + //fmt.Printf("getting headers from %d\n", c) + hdrs := world.headers(c) + l := len(hdrs) + //fmt.Printf("scheduling %d headers, first %d last %d\n", + // l, hdrs[0].Number.Uint64(), hdrs[len(hdrs)-1].Number.Uint64()) + q.Schedule(hdrs, uint64(c)) + c += l + } + }() + wg.Add(1) + go func() { + // collect results + defer wg.Done() + tot := 0 + for { + res := q.Results(true) + tot += len(res) + fmt.Printf("got %d results, %d tot\n", len(res), tot) + // Now we can forget about these + world.forget(res[len(res)-1].Header.Number.Uint64()) + + } + }() + wg.Add(1) + go func() { + defer wg.Done() + // reserve body fetch + i := 4 + for { + peer := dummyPeer(fmt.Sprintf("peer-%d", i)) + f, _, _ := q.ReserveBodies(peer, rand.Intn(30)) + if f != nil { + var emptyList []*types.Header + var txs [][]*types.Transaction + var uncles [][]*types.Header + numToSkip := rand.Intn(len(f.Headers)) + for _, hdr := range f.Headers[0 : len(f.Headers)-numToSkip] { + txs = append(txs, world.getTransactions(hdr.Number.Uint64())) + uncles = append(uncles, emptyList) + } + time.Sleep(100 * time.Millisecond) + _, err := q.DeliverBodies(peer.id, txs, uncles) + if err != nil { + fmt.Printf("delivered %d bodies %v\n", len(txs), err) + } + } else { + i++ + time.Sleep(200 * time.Millisecond) + } + } + }() + go func() { + defer wg.Done() + // reserve receiptfetch + peer := dummyPeer("peer-3") + for { + f, _, _ := 
q.ReserveReceipts(peer, rand.Intn(50)) + if f != nil { + var rcs [][]*types.Receipt + for _, hdr := range f.Headers { + rcs = append(rcs, world.getReceipts(hdr.Number.Uint64())) + } + _, err := q.DeliverReceipts(peer.id, rcs) + if err != nil { + fmt.Printf("delivered %d receipts %v\n", len(rcs), err) + } + time.Sleep(100 * time.Millisecond) + } else { + time.Sleep(200 * time.Millisecond) + } + } + }() + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 50; i++ { + time.Sleep(300 * time.Millisecond) + //world.tick() + //fmt.Printf("trying to progress\n") + world.progress(rand.Intn(100)) + } + for i := 0; i < 50; i++ { + time.Sleep(2990 * time.Millisecond) + + } + }() + wg.Add(1) + go func() { + defer wg.Done() + for { + time.Sleep(990 * time.Millisecond) + fmt.Printf("world block tip is %d\n", + world.chain[len(world.chain)-1].Header().Number.Uint64()) + fmt.Println(q.Stats()) + } + }() + wg.Wait() +} + +func newNetwork() *network { + var l sync.RWMutex + return &network{ + cond: sync.NewCond(&l), + offset: 1, // block 1 is at blocks[0] + } +} + +// represents the network +type network struct { + offset int + chain []*types.Block + receipts []types.Receipts + lock sync.RWMutex + cond *sync.Cond +} + +func (n *network) getTransactions(blocknum uint64) types.Transactions { + index := blocknum - uint64(n.offset) + return n.chain[index].Transactions() +} +func (n *network) getReceipts(blocknum uint64) types.Receipts { + index := blocknum - uint64(n.offset) + if got := n.chain[index].Header().Number.Uint64(); got != blocknum { + fmt.Printf("Err, got %d exp %d\n", got, blocknum) + panic("sd") + } + return n.receipts[index] +} + +func (n *network) forget(blocknum uint64) { + index := blocknum - uint64(n.offset) + n.chain = n.chain[index:] + n.receipts = n.receipts[index:] + n.offset = int(blocknum) + +} +func (n *network) progress(numBlocks int) { + + n.lock.Lock() + defer n.lock.Unlock() + //fmt.Printf("progressing...\n") + newBlocks, newR := 
makeChain(numBlocks, 0, n.chain[len(n.chain)-1], false) + n.chain = append(n.chain, newBlocks...) + n.receipts = append(n.receipts, newR...) + n.cond.Broadcast() + +} + +func (n *network) headers(from int) []*types.Header { + numHeaders := 128 + var hdrs []*types.Header + index := from - n.offset + + for index >= len(n.chain) { + // wait for progress + n.cond.L.Lock() + //fmt.Printf("header going into wait\n") + n.cond.Wait() + index = from - n.offset + n.cond.L.Unlock() + } + n.lock.RLock() + defer n.lock.RUnlock() + for i, b := range n.chain[index:] { + hdrs = append(hdrs, b.Header()) + if i >= numHeaders { + break + } + } + return hdrs +} diff --git a/les/downloader/resultstore.go b/les/downloader/resultstore.go new file mode 100644 index 0000000000000..21928c2a00baf --- /dev/null +++ b/les/downloader/resultstore.go @@ -0,0 +1,194 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package downloader + +import ( + "fmt" + "sync" + "sync/atomic" + + "github.com/ethereum/go-ethereum/core/types" +) + +// resultStore implements a structure for maintaining fetchResults, tracking their +// download-progress and delivering (finished) results. 
+type resultStore struct { + items []*fetchResult // Downloaded but not yet delivered fetch results + resultOffset uint64 // Offset of the first cached fetch result in the block chain + + // Internal index of first non-completed entry, updated atomically when needed. + // If all items are complete, this will equal length(items), so + // *important* : is not safe to use for indexing without checking against length + indexIncomplete int32 // atomic access + + // throttleThreshold is the limit up to which we _want_ to fill the + // results. If blocks are large, we want to limit the results to less + // than the number of available slots, and maybe only fill 1024 out of + // 8192 possible places. The queue will, at certain times, recalibrate + // this index. + throttleThreshold uint64 + + lock sync.RWMutex +} + +func newResultStore(size int) *resultStore { + return &resultStore{ + resultOffset: 0, + items: make([]*fetchResult, size), + throttleThreshold: uint64(size), + } +} + +// SetThrottleThreshold updates the throttling threshold based on the requested +// limit and the total queue capacity. It returns the (possibly capped) threshold +func (r *resultStore) SetThrottleThreshold(threshold uint64) uint64 { + r.lock.Lock() + defer r.lock.Unlock() + + limit := uint64(len(r.items)) + if threshold >= limit { + threshold = limit + } + r.throttleThreshold = threshold + return r.throttleThreshold +} + +// AddFetch adds a header for body/receipt fetching. This is used when the queue +// wants to reserve headers for fetching. 
+// +// It returns the following: +// stale - if true, this item is already passed, and should not be requested again +// throttled - if true, the store is at capacity, this particular header is not prio now +// item - the result to store data into +// err - any error that occurred +func (r *resultStore) AddFetch(header *types.Header, fastSync bool) (stale, throttled bool, item *fetchResult, err error) { + r.lock.Lock() + defer r.lock.Unlock() + + var index int + item, index, stale, throttled, err = r.getFetchResult(header.Number.Uint64()) + if err != nil || stale || throttled { + return stale, throttled, item, err + } + if item == nil { + item = newFetchResult(header, fastSync) + r.items[index] = item + } + return stale, throttled, item, err +} + +// GetDeliverySlot returns the fetchResult for the given header. If the 'stale' flag +// is true, that means the header has already been delivered 'upstream'. This method +// does not bubble up the 'throttle' flag, since it's moot at the point in time when +// the item is downloaded and ready for delivery +func (r *resultStore) GetDeliverySlot(headerNumber uint64) (*fetchResult, bool, error) { + r.lock.RLock() + defer r.lock.RUnlock() + + res, _, stale, _, err := r.getFetchResult(headerNumber) + return res, stale, err +} + +// getFetchResult returns the fetchResult corresponding to the given item, and +// the index where the result is stored. 
+func (r *resultStore) getFetchResult(headerNumber uint64) (item *fetchResult, index int, stale, throttle bool, err error) { + index = int(int64(headerNumber) - int64(r.resultOffset)) + throttle = index >= int(r.throttleThreshold) + stale = index < 0 + + if index >= len(r.items) { + err = fmt.Errorf("%w: index allocation went beyond available resultStore space "+ + "(index [%d] = header [%d] - resultOffset [%d], len(resultStore) = %d", errInvalidChain, + index, headerNumber, r.resultOffset, len(r.items)) + return nil, index, stale, throttle, err + } + if stale { + return nil, index, stale, throttle, nil + } + item = r.items[index] + return item, index, stale, throttle, nil +} + +// hasCompletedItems returns true if there are processable items available +// this method is cheaper than countCompleted +func (r *resultStore) HasCompletedItems() bool { + r.lock.RLock() + defer r.lock.RUnlock() + + if len(r.items) == 0 { + return false + } + if item := r.items[0]; item != nil && item.AllDone() { + return true + } + return false +} + +// countCompleted returns the number of items ready for delivery, stopping at +// the first non-complete item. +// +// The mthod assumes (at least) rlock is held. 
+func (r *resultStore) countCompleted() int { + // We iterate from the already known complete point, and see + // if any more has completed since last count + index := atomic.LoadInt32(&r.indexIncomplete) + for ; ; index++ { + if index >= int32(len(r.items)) { + break + } + result := r.items[index] + if result == nil || !result.AllDone() { + break + } + } + atomic.StoreInt32(&r.indexIncomplete, index) + return int(index) +} + +// GetCompleted returns the next batch of completed fetchResults +func (r *resultStore) GetCompleted(limit int) []*fetchResult { + r.lock.Lock() + defer r.lock.Unlock() + + completed := r.countCompleted() + if limit > completed { + limit = completed + } + results := make([]*fetchResult, limit) + copy(results, r.items[:limit]) + + // Delete the results from the cache and clear the tail. + copy(r.items, r.items[limit:]) + for i := len(r.items) - limit; i < len(r.items); i++ { + r.items[i] = nil + } + // Advance the expected block number of the first cache entry + r.resultOffset += uint64(limit) + atomic.AddInt32(&r.indexIncomplete, int32(-limit)) + + return results +} + +// Prepare initialises the offset with the given block number +func (r *resultStore) Prepare(offset uint64) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.resultOffset < offset { + r.resultOffset = offset + } +} diff --git a/les/downloader/statesync.go b/les/downloader/statesync.go new file mode 100644 index 0000000000000..6c53e5577a87b --- /dev/null +++ b/les/downloader/statesync.go @@ -0,0 +1,615 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package downloader

import (
	"fmt"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/log"
	"github.com/ethereum/go-ethereum/trie"
	"golang.org/x/crypto/sha3"
)

// stateReq represents a batch of state fetch requests grouped together into
// a single data retrieval network packet.
type stateReq struct {
	nItems    uint16                    // Number of items requested for download (max is 384, so uint16 is sufficient)
	trieTasks map[common.Hash]*trieTask // Trie node download tasks to track previous attempts
	codeTasks map[common.Hash]*codeTask // Byte code download tasks to track previous attempts
	timeout   time.Duration             // Maximum round trip time for this to complete
	timer     *time.Timer               // Timer to fire when the RTT timeout expires
	peer      *peerConnection           // Peer that we're requesting from
	delivered time.Time                 // Time when the packet was delivered (independent when we process it)
	response  [][]byte                  // Response data of the peer (nil for timeouts)
	dropped   bool                      // Flag whether the peer dropped off early
}

// timedOut reports whether the request finished without any response data,
// i.e. whether response is nil.
//
// Note that requests finalised because their peer dropped carry no response
// either, so callers that need to tell the two apart must also check dropped.
func (req *stateReq) timedOut() bool {
	return req.response == nil
}

// stateSyncStats is a collection of progress stats to report during a state trie
// sync to RPC requests as well as to display in user logs.
type stateSyncStats struct {
	processed  uint64 // Number of state entries processed
	duplicate  uint64 // Number of state entries downloaded twice
	unexpected uint64 // Number of non-requested state entries received
	pending    uint64 // Number of still pending state entries
}

// syncState starts downloading state with the given root hash.
//
// The new sync is handed over on stateSyncStart and the call blocks until the
// sync signals that it has started. If the downloader quits first, the returned
// sync is already finished with errCancelStateFetch set as its error.
func (d *Downloader) syncState(root common.Hash) *stateSync {
	// Create the state sync
	s := newStateSync(d, root)
	select {
	case d.stateSyncStart <- s:
		// If we tell the statesync to restart with a new root, we also need
		// to wait for it to actually also start -- when old requests have timed
		// out or been delivered
		<-s.started
	case <-d.quitCh:
		s.err = errCancelStateFetch
		close(s.done)
	}
	return s
}

// stateFetcher manages the active state sync and accepts requests
// on its behalf.
//
// It loops until quitCh fires. Each sync received on stateSyncStart is run to
// completion, immediately followed by whichever sync replaced it while running.
func (d *Downloader) stateFetcher() {
	for {
		select {
		case s := <-d.stateSyncStart:
			// runStateSync returns the sync that superseded the current one, or
			// nil once the current one finished without being replaced.
			for next := s; next != nil; {
				next = d.runStateSync(next)
			}
		case <-d.stateCh:
			// Ignore state responses while no sync is running.
		case <-d.quitCh:
			return
		}
	}
}

// runStateSync runs a state synchronisation until it completes or another root
// hash is requested to be switched over to.
//
// It returns the sync to switch over to, or nil if s finished on its own. On
// both exit paths the outstanding requests are drained via spindownStateSync
// before returning.
func (d *Downloader) runStateSync(s *stateSync) *stateSync {
	var (
		active   = make(map[string]*stateReq) // Currently in-flight requests
		finished []*stateReq                  // Completed or failed requests
		timeout  = make(chan *stateReq)       // Timed out active requests
	)
	log.Trace("State sync starting", "root", s.root)

	defer func() {
		// Cancel active request timers on exit. Also set peers to idle so they're
		// available for the next sync.
		for _, req := range active {
			req.timer.Stop()
			req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
		}
	}()
	go s.run()
	defer s.Cancel()

	// Listen for peer departure events to cancel assigned tasks
	peerDrop := make(chan *peerConnection, 1024)
	peerSub := s.d.peers.SubscribePeerDrops(peerDrop)
	defer peerSub.Unsubscribe()

	for {
		// Enable sending of the first buffered element if there is one.
		// While nothing is finished, deliverReqCh stays nil; a send on a nil
		// channel never proceeds, which disables the delivery case below.
		var (
			deliverReq   *stateReq
			deliverReqCh chan *stateReq
		)
		if len(finished) > 0 {
			deliverReq = finished[0]
			deliverReqCh = s.deliver
		}

		select {
		// The stateSync lifecycle:
		case next := <-d.stateSyncStart:
			d.spindownStateSync(active, finished, timeout, peerDrop)
			return next

		case <-s.done:
			d.spindownStateSync(active, finished, timeout, peerDrop)
			return nil

		// Send the next finished request to the current sync:
		case deliverReqCh <- deliverReq:
			// Shift out the first request, but also set the emptied slot to nil for GC
			copy(finished, finished[1:])
			finished[len(finished)-1] = nil
			finished = finished[:len(finished)-1]

		// Handle incoming state packs:
		case pack := <-d.stateCh:
			// Discard any data not requested (or previously timed out)
			req := active[pack.PeerId()]
			if req == nil {
				log.Debug("Unrequested node data", "peer", pack.PeerId(), "len", pack.Items())
				continue
			}
			// Finalize the request and queue up for processing
			req.timer.Stop()
			req.response = pack.(*statePack).states
			req.delivered = time.Now()

			finished = append(finished, req)
			delete(active, pack.PeerId())

		// Handle dropped peer connections:
		case p := <-peerDrop:
			// Skip if no request is currently pending
			req := active[p.id]
			if req == nil {
				continue
			}
			// Finalize the request and queue up for processing
			req.timer.Stop()
			req.dropped = true
			req.delivered = time.Now()

			finished = append(finished, req)
			delete(active, p.id)

		// Handle timed-out requests:
		case req := <-timeout:
			// If the peer is already requesting something else, ignore the stale timeout.
			// This can happen when the timeout and the delivery happen simultaneously,
			// causing both pathways to trigger.
			if active[req.peer.id] != req {
				continue
			}
			req.delivered = time.Now()
			// Move the timed out data back into the download queue
			finished = append(finished, req)
			delete(active, req.peer.id)

		// Track outgoing state requests:
		case req := <-d.trackStateReq:
			// If an active request already exists for this peer, we have a problem. In
			// theory the trie node schedule must never assign two requests to the same
			// peer. In practice however, a peer might receive a request, disconnect and
			// immediately reconnect before the previous times out. In this case the first
			// request is never honored, alas we must not silently overwrite it, as that
			// causes valid requests to go missing and sync to get stuck.
			if old := active[req.peer.id]; old != nil {
				log.Warn("Busy peer assigned new state fetch", "peer", old.peer.id)
				// Move the previous request to the finished set
				old.timer.Stop()
				old.dropped = true
				old.delivered = time.Now()
				finished = append(finished, old)
			}
			// Start a timer to notify the sync loop if the peer stalled.
			// The timer callback sends on the unbuffered timeout channel, so it
			// blocks until this loop (or the spindown) receives from it.
			req.timer = time.AfterFunc(req.timeout, func() {
				timeout <- req
			})
			active[req.peer.id] = req
		}
	}
}

// spindownStateSync 'drains' the outstanding requests; some will be delivered and others
// will time out. This is to ensure that when the next stateSync starts working, all peers
// are marked as idle and de facto _are_ idle.
+func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []*stateReq, timeout chan *stateReq, peerDrop chan *peerConnection) {
+	// Parameters mirror the bookkeeping of runStateSync:
+	//   - active:   in-flight requests keyed by peer id, emptied by this method
+	//   - finished: requests already completed but not yet handed to the sync
+	//   - timeout:  channel on which request timers report expired requests
+	//   - peerDrop: channel on which disconnected peers are reported
+	log.Trace("State sync spinning down", "active", len(active), "finished", len(finished))
+	for len(active) > 0 {
+		var (
+			req    *stateReq
+			reason string
+		)
+		select {
+		// Handle (drop) incoming state packs:
+		case pack := <-d.stateCh:
+			req = active[pack.PeerId()]
+			reason = "delivered"
+		// Handle dropped peer connections:
+		case p := <-peerDrop:
+			req = active[p.id]
+			reason = "peerdrop"
+		// Handle timed-out requests:
+		case req = <-timeout:
+			// Ignore stale timeouts, exactly like runStateSync does. A timer may
+			// fire concurrently with a delivery, in which case the request is
+			// no longer the peer's active one: it is either already sitting in
+			// the finished set (and gets idled below) or has been superseded by
+			// a newer request that must not be dropped on its behalf.
+			if active[req.peer.id] != req {
+				continue
+			}
+			reason = "timeout"
+		}
+		// Deliveries and drops from peers without an in-flight request are moot
+		if req == nil {
+			continue
+		}
+		req.peer.log.Trace("State peer marked idle (spindown)", "req.items", int(req.nItems), "reason", reason)
+		req.timer.Stop()
+		delete(active, req.peer.id)
+		req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
+	}
+	// The 'finished' set contains deliveries that we were going to pass to processing.
+	// Those are now moot, but we still need to set those peers as idle, which would
+	// otherwise have been done after processing
+	for _, req := range finished {
+		req.peer.SetNodeDataIdle(int(req.nItems), time.Now())
+	}
+}
+
+// stateSync schedules requests for downloading a particular state trie defined
+// by a given state root.
+type stateSync struct {
+	d *Downloader // Downloader instance to access and manage current peerset
+
+	root   common.Hash        // State root currently being synced
+	sched  *trie.Sync         // State trie sync scheduler defining the tasks
+	keccak crypto.KeccakState // Keccak256 hasher to verify deliveries with
+
+	trieTasks map[common.Hash]*trieTask // Set of trie node tasks currently queued for retrieval
+	codeTasks map[common.Hash]*codeTask // Set of byte code tasks currently queued for retrieval
+
+	numUncommitted   int // Number of entries injected since the last commit
+	bytesUncommitted int // Size of the blobs injected since the last commit
+
+	started chan struct{} // Started is signalled once the sync loop starts
+
+	deliver    chan *stateReq // Delivery channel multiplexing peer responses
+	cancel     chan struct{}  // Channel to signal a termination request
+	cancelOnce sync.Once      // Ensures cancel only ever gets called once
+	done       chan struct{}  // Channel to signal termination completion
+	err        error          // Any error hit during sync (set before completion)
+}
+
+// trieTask represents a single trie node download task, containing a set of
+// peers already attempted retrieval from to detect stalled syncs and abort.
+type trieTask struct {
+	path     [][]byte            // Path of the node as handed out by the sync scheduler
+	attempts map[string]struct{} // Ids of the peers this node was already requested from
+}
+
+// codeTask represents a single byte code download task, containing a set of
+// peers already attempted retrieval from to detect stalled syncs and abort.
+type codeTask struct {
+	attempts map[string]struct{} // Ids of the peers this code was already requested from
+}
+
+// newStateSync creates a new state trie download scheduler. This method does not
+// yet start the sync. The user needs to call run to initiate.
+func newStateSync(d *Downloader, root common.Hash) *stateSync {
+	// Set up the scheduler and the hasher first, the plumbing afterwards
+	scheduler := state.NewStateSync(root, d.stateDB, d.stateBloom, nil)
+	hasher := sha3.NewLegacyKeccak256().(crypto.KeccakState)
+
+	syncer := &stateSync{d: d, root: root, sched: scheduler, keccak: hasher}
+	syncer.trieTasks = make(map[common.Hash]*trieTask)
+	syncer.codeTasks = make(map[common.Hash]*codeTask)
+	syncer.deliver = make(chan *stateReq)
+	syncer.cancel = make(chan struct{})
+	syncer.done = make(chan struct{})
+	syncer.started = make(chan struct{})
+	return syncer
+}
+
+// run signals that the sync has started, executes it to completion (through the
+// snap syncer if snap sync is enabled, through the local loop otherwise), stores
+// the outcome and finally releases everybody blocked on the done channel.
+func (s *stateSync) run() {
+	close(s.started)
+
+	var failure error
+	switch {
+	case s.d.snapSync:
+		failure = s.d.SnapSyncer.Sync(s.root, s.cancel)
+	default:
+		failure = s.loop()
+	}
+	s.err = failure
+	close(s.done)
+}
+
+// Wait blocks until the sync has terminated, either by finishing or by being
+// canceled, and reports the error it terminated with.
+func (s *stateSync) Wait() error {
+	<-s.done
+	return s.err
+}
+
+// Cancel requests termination of the sync (at most once, however often it is
+// called) and then blocks until the sync has shut down, returning its error.
+func (s *stateSync) Cancel() error {
+	s.cancelOnce.Do(func() { close(s.cancel) })
+
+	<-s.done
+	return s.err
+}
+
+// loop is the main event loop of a state trie sync. It is responsible for the
+// assignment of new tasks to peers (including sending it to them) as well as
+// for the processing of inbound data. Note, that the loop does not directly
+// receive data from peers, rather those are buffered up in the downloader and
+// pushed here async. The reason is to decouple processing from data receipt
+// and timeouts.
+func (s *stateSync) loop() (err error) {
+	// Listen for new peer events to assign tasks to them
+	newPeer := make(chan *peerConnection, 1024)
+	peerSub := s.d.peers.SubscribeNewPeers(newPeer)
+	defer peerSub.Unsubscribe()
+	defer func() {
+		// Always flush whatever is still uncommitted on the way out. A commit
+		// failure is only reported if the loop itself ended without an error.
+		cerr := s.commit(true)
+		if err == nil {
+			err = cerr
+		}
+	}()
+
+	// Keep assigning new tasks until the sync completes or aborts
+	for s.sched.Pending() > 0 {
+		if err = s.commit(false); err != nil {
+			return err
+		}
+		s.assignTasks()
+		// Tasks assigned, wait for something to happen
+		select {
+		case <-newPeer:
+			// New peer arrived, try to assign it download tasks
+
+		case <-s.cancel:
+			return errCancelStateFetch
+
+		case <-s.d.cancelCh:
+			return errCanceled
+
+		case req := <-s.deliver:
+			// Response, disconnect or timeout triggered, drop the peer if stalling
+			log.Trace("Received node data response", "peer", req.peer.id, "count", len(req.response), "dropped", req.dropped, "timeout", !req.dropped && req.timedOut())
+			if req.nItems <= 2 && !req.dropped && req.timedOut() {
+				// 2 items are the minimum requested, if even that times out, we've no use of
+				// this peer at the moment.
+				log.Warn("Stalling state sync, dropping peer", "peer", req.peer.id)
+				if s.d.dropPeer == nil {
+					// The dropPeer method is nil when `--copydb` is used for a local copy.
+					// Timeouts can occur if e.g. compaction hits at the wrong time, and can be ignored
+					req.peer.log.Warn("Downloader wants to drop peer, but peerdrop-function is not set", "peer", req.peer.id)
+				} else {
+					s.d.dropPeer(req.peer.id)
+
+					// If this peer was the master peer, abort sync immediately
+					s.d.cancelLock.RLock()
+					master := req.peer.id == s.d.cancelPeer
+					s.d.cancelLock.RUnlock()
+
+					if master {
+						s.d.cancel()
+						return errTimeout
+					}
+				}
+			}
+			// Process all the received blobs and check for stale delivery.
+			// Note, err is shadowed within this case; the explicit return below
+			// still assigns it to the named result seen by the deferred commit.
+			delivered, err := s.process(req)
+			req.peer.SetNodeDataIdle(delivered, req.delivered)
+			if err != nil {
+				log.Warn("Node data write error", "err", err)
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// commit flushes the data accumulated in the sync scheduler into the state
+// database as a single batch and resets the uncommitted counters. Unless force
+// is set, nothing is written while fewer than ethdb.IdealBatchSize bytes have
+// piled up.
+func (s *stateSync) commit(force bool) error {
+	if !force && s.bytesUncommitted < ethdb.IdealBatchSize {
+		return nil
+	}
+	start := time.Now()
+	b := s.d.stateDB.NewBatch()
+	if err := s.sched.Commit(b); err != nil {
+		return err
+	}
+	if err := b.Write(); err != nil {
+		return fmt.Errorf("DB write error: %v", err)
+	}
+	// Only report the entries as written once they actually hit the database
+	s.updateStats(s.numUncommitted, 0, 0, time.Since(start))
+	s.numUncommitted = 0
+	s.bytesUncommitted = 0
+	return nil
+}
+
+// assignTasks attempts to assign new tasks to all idle peers, either from the
+// batch currently being retried, or fetching new data from the trie sync itself.
+func (s *stateSync) assignTasks() {
+	// Walk the peers that are currently idle and hand each of them a batch
+	idle, _ := s.d.peers.NodeDataIdlePeers()
+	for _, peer := range idle {
+		// Size the batch according to the peer's estimated capacity
+		capacity := peer.NodeDataCapacity(s.d.peers.rates.TargetRoundTrip())
+		req := &stateReq{peer: peer, timeout: s.d.peers.rates.TargetTimeout()}
+
+		nodes, _, codes := s.fillTasks(capacity, req)
+		if len(nodes)+len(codes) == 0 {
+			// Nothing this peer could help with, move on to the next one
+			continue
+		}
+		req.peer.log.Trace("Requesting batch of state data", "nodes", len(nodes), "codes", len(codes), "root", s.root)
+
+		// Register the request with the downloader before hitting the network,
+		// bailing out if the sync gets torn down in the meantime.
+		select {
+		case s.d.trackStateReq <- req:
+			req.peer.FetchNodeData(append(nodes, codes...)) // Unified retrieval under eth/6x
+		case <-s.cancel:
+		case <-s.d.cancelCh:
+		}
+	}
+}
+
+// fillTasks moves up to n queued download tasks that the request's peer has not
+// been asked for yet into the given request, topping the queues up from the
+// scheduler beforehand if they hold fewer than n tasks. It returns the hashes
+// (and paths) of the trie nodes and the hashes of the byte codes to request.
+func (s *stateSync) fillTasks(n int, req *stateReq) (nodes []common.Hash, paths []trie.SyncPath, codes []common.Hash) {
+	// Top up the local queues from the scheduler if they are running low
+	if queued := len(s.trieTasks) + len(s.codeTasks); queued < n {
+		missNodes, missPaths, missCodes := s.sched.Missing(n - queued)
+		for idx := range missNodes {
+			s.trieTasks[missNodes[idx]] = &trieTask{path: missPaths[idx], attempts: make(map[string]struct{})}
+		}
+		for _, code := range missCodes {
+			s.codeTasks[code] = &codeTask{attempts: make(map[string]struct{})}
+		}
+	}
+	nodes = make([]common.Hash, 0, n)
+	paths = make([]trie.SyncPath, 0, n)
+	codes = make([]common.Hash, 0, n)
+
+	req.trieTasks = make(map[common.Hash]*trieTask, n)
+	req.codeTasks = make(map[common.Hash]*codeTask, n)
+
+	id := req.peer.id
+
+	// Byte codes go first: they can be written to disk and forgotten about
+	for hash, task := range s.codeTasks {
+		if len(nodes)+len(codes) == n {
+			break // Request is full
+		}
+		if _, tried := task.attempts[id]; tried {
+			continue // This peer was already asked for the item
+		}
+		task.attempts[id] = struct{}{}
+
+		codes = append(codes, hash)
+		req.codeTasks[hash] = task
+		delete(s.codeTasks, hash)
+	}
+	// Use the remaining room for trie nodes
+	for hash, task := range s.trieTasks {
+		if len(nodes)+len(codes) == n {
+			break // Request is full
+		}
+		if _, tried := task.attempts[id]; tried {
+			continue // This peer was already asked for the item
+		}
+		task.attempts[id] = struct{}{}
+
+		nodes = append(nodes, hash)
+		paths = append(paths, task.path)
+		req.trieTasks[hash] = task
+		delete(s.trieTasks, hash)
+	}
+	req.nItems = uint16(len(nodes) + len(codes))
+	return nodes, paths, codes
+}
+
+// process iterates over a batch of delivered state data, injecting each item
+// into a running state sync, re-queuing any items that were requested but not
+// delivered. Returns the number of items that were successfully injected, and
+// any error that occurred.
+func (s *stateSync) process(req *stateReq) (int, error) {
+	// Collect processing stats and update progress if valid data was received
+	duplicate, unexpected, successful := 0, 0, 0
+
+	// Successful items are not reported here, those are accounted for when
+	// they get committed; only duplicates and unexpected items are.
+	defer func(start time.Time) {
+		if duplicate > 0 || unexpected > 0 {
+			s.updateStats(0, duplicate, unexpected, time.Since(start))
+		}
+	}(time.Now())
+
+	// Iterate over all the delivered data and inject one-by-one into the trie
+	for _, blob := range req.response {
+		hash, err := s.processNodeData(blob)
+		switch err {
+		case nil:
+			s.numUncommitted++
+			s.bytesUncommitted += len(blob)
+			successful++
+		case trie.ErrNotRequested:
+			unexpected++
+		case trie.ErrAlreadyProcessed:
+			duplicate++
+		default:
+			return successful, fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err)
+		}
+		// Delete from both queues (one delivery is enough for the syncer)
+		delete(req.trieTasks, hash)
+		delete(req.codeTasks, hash)
+	}
+	// Put unfulfilled tasks back into the retry queue
+	npeers := s.d.peers.Len()
+	for hash, task := range req.trieTasks {
+		// If the node did deliver something, missing items may be due to a protocol
+		// limit or a previous timeout + delayed delivery. Both cases should permit
+		// the node to retry the missing items (to avoid single-peer stalls).
+		if len(req.response) > 0 || req.timedOut() {
+			delete(task.attempts, req.peer.id)
+		}
+		// If we've requested the node too many times already, it may be a malicious
+		// sync where nobody has the right data. Abort.
+		if len(task.attempts) >= npeers {
+			return successful, fmt.Errorf("trie node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
+		}
+		// Missing item, place into the retry queue.
+		s.trieTasks[hash] = task
+	}
+	for hash, task := range req.codeTasks {
+		// If the node did deliver something, missing items may be due to a protocol
+		// limit or a previous timeout + delayed delivery. Both cases should permit
+		// the node to retry the missing items (to avoid single-peer stalls).
+		if len(req.response) > 0 || req.timedOut() {
+			delete(task.attempts, req.peer.id)
+		}
+		// If we've requested the node too many times already, it may be a malicious
+		// sync where nobody has the right data. Abort.
+		if len(task.attempts) >= npeers {
+			return successful, fmt.Errorf("byte code %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
+		}
+		// Missing item, place into the retry queue.
+		s.codeTasks[hash] = task
+	}
+	return successful, nil
+}
+
+// processNodeData tries to inject a data blob delivered from a remote peer into
+// the sync scheduler. It returns the Keccak256 hash of the blob, under which the
+// scheduler looks the item up, along with any error reported by the scheduler.
+func (s *stateSync) processNodeData(blob []byte) (common.Hash, error) {
+	res := trie.SyncResult{Data: blob}
+	// Hash the blob with the reusable hasher, reading the digest straight into
+	// the result
+	s.keccak.Reset()
+	s.keccak.Write(blob)
+	s.keccak.Read(res.Hash[:])
+	err := s.sched.Process(res)
+	return res.Hash, err
+}
+
+// updateStats bumps the various state sync progress counters and displays a log
+// message for the user to see.
+func (s *stateSync) updateStats(written, duplicate, unexpected int, duration time.Duration) {
+	// All the progress counters are guarded by the downloader's stats lock
+	s.d.syncStatsLock.Lock()
+	defer s.d.syncStatsLock.Unlock()
+
+	// Pending is refreshed from the scheduler, the other counters accumulate
+	s.d.syncStatsState.pending = uint64(s.sched.Pending())
+	s.d.syncStatsState.processed += uint64(written)
+	s.d.syncStatsState.duplicate += uint64(duplicate)
+	s.d.syncStatsState.unexpected += uint64(unexpected)
+
+	if written > 0 || duplicate > 0 || unexpected > 0 {
+		log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "trieretry", len(s.trieTasks), "coderetry", len(s.codeTasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected)
+	}
+	// Persist the progress counter only if it actually advanced
+	if written > 0 {
+		rawdb.WriteFastTrieProgress(s.d.stateDB, s.d.syncStatsState.processed)
+	}
+}
diff --git a/les/downloader/testchain_test.go b/les/downloader/testchain_test.go
new file mode 100644
index 0000000000000..b9865f7e032b3
--- /dev/null
+++ b/les/downloader/testchain_test.go
@@ -0,0 +1,230 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package downloader
+
+import (
+	"fmt"
+	"math/big"
+	"sync"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/params"
+)
+
+// Test chain parameters.
+var (
+	testKey, _  = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+	testAddress = crypto.PubkeyToAddress(testKey.PublicKey)
+	testDB      = rawdb.NewMemoryDatabase()
+	testGenesis = core.GenesisBlockForTesting(testDB, testAddress, big.NewInt(1000000000000000))
+)
+
+// The common prefix of all test chains:
+var testChainBase = newTestChain(blockCacheMaxItems+200, testGenesis)
+
+// Different forks on top of the base chain:
+var testChainForkLightA, testChainForkLightB, testChainForkHeavy *testChain
+
+// init builds the three forks on top of the base chain concurrently and waits
+// until all of them are generated. The forks differ in the seed passed to
+// makeFork (1, 2 and 3); only the last one is requested as heavy.
+func init() {
+	var forkLen = int(fullMaxForkAncestry + 50)
+	var wg sync.WaitGroup
+	wg.Add(3)
+	go func() { testChainForkLightA = testChainBase.makeFork(forkLen, false, 1); wg.Done() }()
+	go func() { testChainForkLightB = testChainBase.makeFork(forkLen, false, 2); wg.Done() }()
+	go func() { testChainForkHeavy = testChainBase.makeFork(forkLen, true, 3); wg.Done() }()
+	wg.Wait()
+}
+
+// testChain is a pre-generated chain of blocks for the tests, along with lookup
+// tables for the headers, blocks, receipts and total difficulties.
+type testChain struct {
+	genesis  *types.Block                     // Genesis block the chain is rooted in
+	chain    []common.Hash                    // Block hashes, the index being the position in the chain
+	headerm  map[common.Hash]*types.Header    // Headers by block hash
+	blockm   map[common.Hash]*types.Block     // Blocks by block hash
+	receiptm map[common.Hash][]*types.Receipt // Receipts by block hash
+	tdm      map[common.Hash]*big.Int         // Total difficulties by block hash
+}
+
+// newTestChain creates a blockchain of the given length.
+func newTestChain(length int, genesis *types.Block) *testChain {
+	// Start out with empty lookup tables sized for the whole chain
+	chain := new(testChain).copy(length)
+	chain.genesis = genesis
+
+	// Register the genesis block by hand, it is not produced by the generator
+	root := genesis.Hash()
+	chain.chain = append(chain.chain, root)
+	chain.headerm[root] = genesis.Header()
+	chain.tdm[root] = genesis.Difficulty()
+	chain.blockm[root] = genesis
+
+	// The remaining blocks are generated on top of the genesis
+	chain.generate(length-1, 0, genesis, false)
+	return chain
+}
+
+// makeFork returns a copy of the test chain that is extended by length freshly
+// generated blocks on top of the current head.
+func (tc *testChain) makeFork(length int, heavy bool, seed byte) *testChain {
+	total := tc.len() + length
+
+	fork := tc.copy(total)
+	fork.generate(length, seed, tc.headBlock(), heavy)
+	return fork
+}
+
+// shorten returns a copy of the chain truncated to the given length. Asking for
+// more blocks than the chain holds is a programming error and panics.
+func (tc *testChain) shorten(length int) *testChain {
+	if have := tc.len(); length > have {
+		panic(fmt.Errorf("can't shorten test chain to %d blocks, it's only %d blocks long", length, have))
+	}
+	return tc.copy(length)
+}
+
+// copy duplicates the hash chain and the lookup tables of at most the first
+// newlen blocks into a new test chain. The blocks themselves are shared.
+func (tc *testChain) copy(newlen int) *testChain {
+	cpy := new(testChain)
+	cpy.genesis = tc.genesis
+	cpy.headerm = make(map[common.Hash]*types.Header, newlen)
+	cpy.blockm = make(map[common.Hash]*types.Block, newlen)
+	cpy.receiptm = make(map[common.Hash][]*types.Receipt, newlen)
+	cpy.tdm = make(map[common.Hash]*big.Int, newlen)
+
+	for idx, hash := range tc.chain {
+		if idx >= newlen {
+			break
+		}
+		cpy.chain = append(cpy.chain, hash)
+		cpy.tdm[hash] = tc.tdm[hash]
+		cpy.blockm[hash] = tc.blockm[hash]
+		cpy.headerm[hash] = tc.headerm[hash]
+		cpy.receiptm[hash] = tc.receiptm[hash]
+	}
+	return cpy
+}
+
+// generate creates n blocks on top of parent and appends them to the chain and
+// its lookup tables in the order they were generated. When building directly on
+// the genesis block, every 22nd generated block contains a transaction; every
+// 5th generated block (bar the first) contains an uncle, to allow testing
+// correct block reassembly.
+func (tc *testChain) generate(n int, seed byte, parent *types.Block, heavy bool) {
+	// populate customises the i-th generated block
+	populate := func(i int, block *core.BlockGen) {
+		block.SetCoinbase(common.Address{seed})
+
+		// Heavy chains shift the block timestamps to end up with a higher difficulty
+		if heavy {
+			block.OffsetTime(-1)
+		}
+		// Chains built straight on the genesis periodically pay the miner
+		if parent == tc.genesis && i%22 == 0 {
+			signer := types.MakeSigner(params.TestChainConfig, block.Number())
+			payment := types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, block.BaseFee(), nil)
+
+			tx, err := types.SignTx(payment, signer, testKey)
+			if err != nil {
+				panic(err)
+			}
+			block.AddTx(tx)
+		}
+		// Every 5th generated block (bar the very first) gets a bonus uncle
+		if i > 0 && i%5 == 0 {
+			uncle := &types.Header{
+				ParentHash: block.PrevBlock(i - 1).Hash(),
+				Number:     big.NewInt(block.Number().Int64() - 1),
+			}
+			block.AddUncle(uncle)
+		}
+	}
+	blocks, receipts := core.GenerateChain(params.TestChainConfig, parent, ethash.NewFaker(), testDB, n, populate)
+
+	// Index the generated blocks, accumulating the total difficulty on the way
+	total := new(big.Int).Set(tc.td(parent.Hash()))
+	for i, b := range blocks {
+		total.Add(total, b.Difficulty())
+
+		hash := b.Hash()
+		tc.chain = append(tc.chain, hash)
+		tc.blockm[hash] = b
+		tc.headerm[hash] = b.Header()
+		tc.receiptm[hash] = receipts[i]
+		tc.tdm[hash] = new(big.Int).Set(total)
+	}
+}
+
+// len reports how many blocks the chain consists of.
+func (tc *testChain) len() int {
+	return len(tc.chain)
+}
+
+// headBlock retrieves the last block of the chain.
+func (tc *testChain) headBlock() *types.Block {
+	tip := tc.chain[len(tc.chain)-1]
+	return tc.blockm[tip]
+}
+
+// td looks up the total difficulty recorded for the block with the given hash.
+func (tc *testChain) td(hash common.Hash) *big.Int {
+	return tc.tdm[hash]
+}
+
+// headersByHash returns headers in order from the given hash.
+func (tc *testChain) headersByHash(origin common.Hash, amount int, skip int, reverse bool) []*types.Header {
+	// An unknown origin resolves to number zero, the lookup result is not checked
+	start, _ := tc.hashToNumber(origin)
+	return tc.headersByNumber(start, amount, skip, reverse)
+}
+
+// headersByNumber collects at most amount headers beginning at the origin block
+// number, leaving out skip blocks between two consecutive headers. The chain is
+// walked towards the head, or towards the genesis if reverse is set.
+func (tc *testChain) headersByNumber(origin uint64, amount int, skip int, reverse bool) []*types.Header {
+	headers := make([]*types.Header, 0, amount)
+
+	if reverse {
+		// Walk backwards until the genesis was passed or enough was gathered
+		stride := int64(skip) + 1
+		for number := int64(origin); number >= 0 && len(headers) < amount; number -= stride {
+			if header, ok := tc.headerm[tc.chain[int(number)]]; ok {
+				headers = append(headers, header)
+			}
+		}
+		return headers
+	}
+	// Walk forwards until the head was passed or enough was gathered
+	var (
+		stride = uint64(skip) + 1
+		limit  = uint64(len(tc.chain))
+	)
+	for number := origin; number < limit && len(headers) < amount; number += stride {
+		if header, ok := tc.headerm[tc.chain[int(number)]]; ok {
+			headers = append(headers, header)
+		}
+	}
+	return headers
+}
+
+// receipts gathers the receipt lists of the blocks with the given hashes,
+// silently leaving out the hashes the chain has no receipts recorded for.
+func (tc *testChain) receipts(hashes []common.Hash) [][]*types.Receipt {
+	found := make([][]*types.Receipt, 0, len(hashes))
+	for _, hash := range hashes {
+		list, ok := tc.receiptm[hash]
+		if !ok {
+			continue
+		}
+		found = append(found, list)
+	}
+	return found
+}
+
+// bodies returns the block bodies of the given block hashes.
+func (tc *testChain) bodies(hashes []common.Hash) ([][]*types.Transaction, [][]*types.Header) {
+	transactions := make([][]*types.Transaction, 0, len(hashes))
+	uncles := make([][]*types.Header, 0, len(hashes))
+	for _, hash := range hashes {
+		// Unknown hashes are skipped, so the results may be shorter than hashes
+		if block, ok := tc.blockm[hash]; ok {
+			transactions = append(transactions, block.Transactions())
+			uncles = append(uncles, block.Uncles())
+		}
+	}
+	return transactions, uncles
+}
+
+// hashToNumber returns the position of the given hash within the chain and
+// whether it was found at all. The lookup is a linear scan over the chain.
+func (tc *testChain) hashToNumber(target common.Hash) (uint64, bool) {
+	for num, hash := range tc.chain {
+		if hash == target {
+			return uint64(num), true
+		}
+	}
+	return 0, false
+}
diff --git a/les/downloader/types.go b/les/downloader/types.go
new file mode 100644
index 0000000000000..ff70bfa0e3cc4
--- /dev/null
+++ b/les/downloader/types.go
@@ -0,0 +1,79 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package downloader
+
+import (
+	"fmt"
+
+	"github.com/ethereum/go-ethereum/core/types"
+)
+
+// peerDropFn is a callback type for dropping a peer detected as malicious.
+type peerDropFn func(id string)
+
+// dataPack is a data message returned by a peer for some query.
+type dataPack interface {
+	PeerId() string // Id of the peer that returned the data
+	Items() int     // Number of items contained in the message
+	Stats() string  // Item counts formatted as a string
+}
+
+// headerPack is a batch of block headers returned by a peer.
+type headerPack struct {
+	peerID  string
+	headers []*types.Header
+}
+
+// PeerId returns the id of the peer that delivered the headers.
+func (p *headerPack) PeerId() string {
+	return p.peerID
+}
+
+// Items returns the number of headers in the batch.
+func (p *headerPack) Items() int {
+	return len(p.headers)
+}
+
+// Stats returns the number of headers in the batch as a string.
+func (p *headerPack) Stats() string {
+	return fmt.Sprintf("%d", p.Items())
+}
+
+// bodyPack is a batch of block bodies returned by a peer.
+type bodyPack struct {
+	peerID       string
+	transactions [][]*types.Transaction
+	uncles       [][]*types.Header
+}
+
+// PeerId returns the id of the peer that delivered the bodies.
+func (p *bodyPack) PeerId() string {
+	return p.peerID
+}
+
+// Items returns the number of bodies in the batch, which is the smaller one of
+// the transaction list count and the uncle list count.
+func (p *bodyPack) Items() int {
+	count := len(p.uncles)
+	if txs := len(p.transactions); txs <= count {
+		count = txs
+	}
+	return count
+}
+
+// Stats returns the transaction list and uncle list counts, colon separated.
+func (p *bodyPack) Stats() string {
+	txs, uncles := len(p.transactions), len(p.uncles)
+	return fmt.Sprintf("%d:%d", txs, uncles)
+}
+
+// receiptPack is a batch of receipts returned by a peer.
+type receiptPack struct {
+	peerID   string
+	receipts [][]*types.Receipt
+}
+
+// PeerId returns the id of the peer that delivered the receipts.
+func (p *receiptPack) PeerId() string {
+	return p.peerID
+}
+
+// Items returns the number of receipt lists in the batch.
+func (p *receiptPack) Items() int {
+	return len(p.receipts)
+}
+
+// Stats returns the number of receipt lists in the batch as a string.
+func (p *receiptPack) Stats() string {
+	return fmt.Sprintf("%d", p.Items())
+}
+
+// statePack is a batch of states returned by a peer.
+type statePack struct {
+	peerID string
+	states [][]byte
+}
+
+func (p *statePack) PeerId() string { return p.peerID }
+func (p *statePack) Items() int     { return len(p.states) }
+func (p *statePack) Stats() string  { return fmt.Sprintf("%d", len(p.states)) }
diff --git a/les/fetcher.go b/les/fetcher.go
index 5eea996748745..d944d32858e7b 100644
--- a/les/fetcher.go
+++ b/les/fetcher.go
@@ -27,8 +27,8 @@ import (
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/core/types"
-	"github.com/ethereum/go-ethereum/eth/fetcher"
 	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/les/fetcher"
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/p2p/enode"
diff --git a/les/fetcher/block_fetcher.go b/les/fetcher/block_fetcher.go
new file mode 100644
index 0000000000000..283008db0f1e5
--- /dev/null
+++ b/les/fetcher/block_fetcher.go
@@ -0,0 +1,889 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// This is a temporary package whilst working on the eth/66 blocking refactors.
+// After that work is done, les needs to be refactored to use the new package,
+// or alternatively use a stripped down version of it. Either way, we need to
+// keep the changes scoped so duplicating temporarily seems the sanest.
+package fetcher
+
+import (
+	"errors"
+	"math/rand"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/prque"
+	"github.com/ethereum/go-ethereum/consensus"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
+// Time allowances of the fetcher.
+const (
+	lightTimeout  = time.Millisecond       // Time allowance before an announced header is explicitly requested
+	arriveTimeout = 500 * time.Millisecond // Time allowance before an announced block/transaction is explicitly requested
+	gatherSlack   = 100 * time.Millisecond // Interval used to collate almost-expired announces with fetches
+	fetchTimeout  = 5 * time.Second        // Maximum allotted time to return an explicitly requested block/transaction
+)
+
+// Distance and per-peer limits of the fetcher.
+const (
+	maxUncleDist = 7   // Maximum allowed backward distance from the chain head
+	maxQueueDist = 32  // Maximum allowed distance from the chain head to queue
+	hashLimit    = 256 // Maximum number of unique blocks or headers a peer may have announced
+	blockLimit   = 64  // Maximum number of unique blocks a peer may have delivered
+)
+
+// Metrics of the fetcher. Note, they are registered under the eth/fetcher
+// names, the same prefix the package this one was duplicated from uses.
+var (
+	blockAnnounceInMeter   = metrics.NewRegisteredMeter("eth/fetcher/block/announces/in", nil)
+	blockAnnounceOutTimer  = metrics.NewRegisteredTimer("eth/fetcher/block/announces/out", nil)
+	blockAnnounceDropMeter = metrics.NewRegisteredMeter("eth/fetcher/block/announces/drop", nil)
+	blockAnnounceDOSMeter  = metrics.NewRegisteredMeter("eth/fetcher/block/announces/dos", nil)
+
+	blockBroadcastInMeter   = metrics.NewRegisteredMeter("eth/fetcher/block/broadcasts/in", nil)
+	blockBroadcastOutTimer  = metrics.NewRegisteredTimer("eth/fetcher/block/broadcasts/out", nil)
+	blockBroadcastDropMeter = metrics.NewRegisteredMeter("eth/fetcher/block/broadcasts/drop", nil)
+	blockBroadcastDOSMeter  = metrics.NewRegisteredMeter("eth/fetcher/block/broadcasts/dos", nil)
+
+	headerFetchMeter = metrics.NewRegisteredMeter("eth/fetcher/block/headers", nil)
+	bodyFetchMeter   = metrics.NewRegisteredMeter("eth/fetcher/block/bodies", nil)
+
+	headerFilterInMeter  = metrics.NewRegisteredMeter("eth/fetcher/block/filter/headers/in", nil)
+	headerFilterOutMeter = metrics.NewRegisteredMeter("eth/fetcher/block/filter/headers/out", nil)
+	bodyFilterInMeter    = metrics.NewRegisteredMeter("eth/fetcher/block/filter/bodies/in", nil)
+	bodyFilterOutMeter   = metrics.NewRegisteredMeter("eth/fetcher/block/filter/bodies/out", nil)
+)
+
+var errTerminated = errors.New("terminated")
+
+// HeaderRetrievalFn is a callback type for retrieving a header from the local chain.
+type HeaderRetrievalFn func(common.Hash) *types.Header
+
+// blockRetrievalFn is a callback type for retrieving a block from the local chain.
+type blockRetrievalFn func(common.Hash) *types.Block
+
+// headerRequesterFn is a callback type for sending a header retrieval request.
+type headerRequesterFn func(common.Hash) error
+
+// bodyRequesterFn is a callback type for sending a body retrieval request.
+type bodyRequesterFn func([]common.Hash) error
+
+// headerVerifierFn is a callback type to verify a block's header for fast propagation.
+type headerVerifierFn func(header *types.Header) error
+
+// blockBroadcasterFn is a callback type for broadcasting a block to connected peers.
+type blockBroadcasterFn func(block *types.Block, propagate bool)
+
+// chainHeightFn is a callback type to retrieve the current chain height.
+type chainHeightFn func() uint64
+
+// headersInsertFn is a callback type to insert a batch of headers into the local chain.
+type headersInsertFn func(headers []*types.Header) (int, error)
+
+// chainInsertFn is a callback type to insert a batch of blocks into the local chain.
+type chainInsertFn func(types.Blocks) (int, error)
+
+// peerDropFn is a callback type for dropping a peer detected as malicious.
+type peerDropFn func(id string) + +// blockAnnounce is the hash notification of the availability of a new block in the +// network. +type blockAnnounce struct { + hash common.Hash // Hash of the block being announced + number uint64 // Number of the block being announced (0 = unknown | old protocol) + header *types.Header // Header of the block partially reassembled (new protocol) + time time.Time // Timestamp of the announcement + + origin string // Identifier of the peer originating the notification + + fetchHeader headerRequesterFn // Fetcher function to retrieve the header of an announced block + fetchBodies bodyRequesterFn // Fetcher function to retrieve the body of an announced block +} + +// headerFilterTask represents a batch of headers needing fetcher filtering. +type headerFilterTask struct { + peer string // The source peer of block headers + headers []*types.Header // Collection of headers to filter + time time.Time // Arrival time of the headers +} + +// bodyFilterTask represents a batch of block bodies (transactions and uncles) +// needing fetcher filtering. +type bodyFilterTask struct { + peer string // The source peer of block bodies + transactions [][]*types.Transaction // Collection of transactions per block bodies + uncles [][]*types.Header // Collection of uncles per block bodies + time time.Time // Arrival time of the blocks' contents +} + +// blockOrHeaderInject represents a scheduled import operation. +type blockOrHeaderInject struct { + origin string + + header *types.Header // Used for light mode fetcher which only cares about header. + block *types.Block // Used for normal mode fetcher which imports full block. +} + +// number returns the block number of the injected object. +func (inject *blockOrHeaderInject) number() uint64 { + if inject.header != nil { + return inject.header.Number.Uint64() + } + return inject.block.NumberU64() +} + +// hash returns the block hash of the injected object.
+func (inject *blockOrHeaderInject) hash() common.Hash { + if inject.header != nil { + return inject.header.Hash() + } + return inject.block.Hash() +} + +// BlockFetcher is responsible for accumulating block announcements from various peers +// and scheduling them for retrieval. +type BlockFetcher struct { + light bool // The indicator whether it's a light fetcher or normal one. + + // Various event channels + notify chan *blockAnnounce + inject chan *blockOrHeaderInject + + headerFilter chan chan *headerFilterTask + bodyFilter chan chan *bodyFilterTask + + done chan common.Hash + quit chan struct{} + + // Announce states + announces map[string]int // Per peer blockAnnounce counts to prevent memory exhaustion + announced map[common.Hash][]*blockAnnounce // Announced blocks, scheduled for fetching + fetching map[common.Hash]*blockAnnounce // Announced blocks, currently fetching + fetched map[common.Hash][]*blockAnnounce // Blocks with headers fetched, scheduled for body retrieval + completing map[common.Hash]*blockAnnounce // Blocks with headers, currently body-completing + + // Block cache + queue *prque.Prque // Queue containing the import operations (block number sorted) + queues map[string]int // Per peer block counts to prevent memory exhaustion + queued map[common.Hash]*blockOrHeaderInject // Set of already queued blocks (to dedup imports) + + // Callbacks + getHeader HeaderRetrievalFn // Retrieves a header from the local chain + getBlock blockRetrievalFn // Retrieves a block from the local chain + verifyHeader headerVerifierFn // Checks if a block's headers have a valid proof of work + broadcastBlock blockBroadcasterFn // Broadcasts a block to connected peers + chainHeight chainHeightFn // Retrieves the current chain's height + insertHeaders headersInsertFn // Injects a batch of headers into the chain + insertChain chainInsertFn // Injects a batch of blocks into the chain + dropPeer peerDropFn // Drops a peer for misbehaving + + // Testing hooks + 
announceChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a hash from the blockAnnounce list + queueChangeHook func(common.Hash, bool) // Method to call upon adding or deleting a block from the import queue + fetchingHook func([]common.Hash) // Method to call upon starting a block (eth/61) or header (eth/62) fetch + completingHook func([]common.Hash) // Method to call upon starting a block body fetch (eth/62) + importedHook func(*types.Header, *types.Block) // Method to call upon successful header or block import (both eth/61 and eth/62) +} + +// NewBlockFetcher creates a block fetcher to retrieve blocks based on hash announcements. +func NewBlockFetcher(light bool, getHeader HeaderRetrievalFn, getBlock blockRetrievalFn, verifyHeader headerVerifierFn, broadcastBlock blockBroadcasterFn, chainHeight chainHeightFn, insertHeaders headersInsertFn, insertChain chainInsertFn, dropPeer peerDropFn) *BlockFetcher { + return &BlockFetcher{ + light: light, + notify: make(chan *blockAnnounce), + inject: make(chan *blockOrHeaderInject), + headerFilter: make(chan chan *headerFilterTask), + bodyFilter: make(chan chan *bodyFilterTask), + done: make(chan common.Hash), + quit: make(chan struct{}), + announces: make(map[string]int), + announced: make(map[common.Hash][]*blockAnnounce), + fetching: make(map[common.Hash]*blockAnnounce), + fetched: make(map[common.Hash][]*blockAnnounce), + completing: make(map[common.Hash]*blockAnnounce), + queue: prque.New(nil), + queues: make(map[string]int), + queued: make(map[common.Hash]*blockOrHeaderInject), + getHeader: getHeader, + getBlock: getBlock, + verifyHeader: verifyHeader, + broadcastBlock: broadcastBlock, + chainHeight: chainHeight, + insertHeaders: insertHeaders, + insertChain: insertChain, + dropPeer: dropPeer, + } +} + +// Start boots up the announcement based synchroniser, accepting and processing +// hash notifications and block fetches until termination requested. 
+func (f *BlockFetcher) Start() { + go f.loop() +} + +// Stop terminates the announcement based synchroniser, canceling all pending +// operations. +func (f *BlockFetcher) Stop() { + close(f.quit) +} + +// Notify announces the fetcher of the potential availability of a new block in +// the network. +func (f *BlockFetcher) Notify(peer string, hash common.Hash, number uint64, time time.Time, + headerFetcher headerRequesterFn, bodyFetcher bodyRequesterFn) error { + block := &blockAnnounce{ + hash: hash, + number: number, + time: time, + origin: peer, + fetchHeader: headerFetcher, + fetchBodies: bodyFetcher, + } + select { + case f.notify <- block: + return nil + case <-f.quit: + return errTerminated + } +} + +// Enqueue tries to fill gaps in the fetcher's future import queue. +func (f *BlockFetcher) Enqueue(peer string, block *types.Block) error { + op := &blockOrHeaderInject{ + origin: peer, + block: block, + } + select { + case f.inject <- op: + return nil + case <-f.quit: + return errTerminated + } +} + +// FilterHeaders extracts all the headers that were explicitly requested by the fetcher, +// returning those that should be handled differently. +func (f *BlockFetcher) FilterHeaders(peer string, headers []*types.Header, time time.Time) []*types.Header { + log.Trace("Filtering headers", "peer", peer, "headers", len(headers)) + + // Send the filter channel to the fetcher + filter := make(chan *headerFilterTask) + + select { + case f.headerFilter <- filter: + case <-f.quit: + return nil + } + // Request the filtering of the header list + select { + case filter <- &headerFilterTask{peer: peer, headers: headers, time: time}: + case <-f.quit: + return nil + } + // Retrieve the headers remaining after filtering + select { + case task := <-filter: + return task.headers + case <-f.quit: + return nil + } +} + +// FilterBodies extracts all the block bodies that were explicitly requested by +// the fetcher, returning those that should be handled differently.
+func (f *BlockFetcher) FilterBodies(peer string, transactions [][]*types.Transaction, uncles [][]*types.Header, time time.Time) ([][]*types.Transaction, [][]*types.Header) { + log.Trace("Filtering bodies", "peer", peer, "txs", len(transactions), "uncles", len(uncles)) + + // Send the filter channel to the fetcher + filter := make(chan *bodyFilterTask) + + select { + case f.bodyFilter <- filter: + case <-f.quit: + return nil, nil + } + // Request the filtering of the body list + select { + case filter <- &bodyFilterTask{peer: peer, transactions: transactions, uncles: uncles, time: time}: + case <-f.quit: + return nil, nil + } + // Retrieve the bodies remaining after filtering + select { + case task := <-filter: + return task.transactions, task.uncles + case <-f.quit: + return nil, nil + } +} + +// Loop is the main fetcher loop, checking and processing various notification +// events. +func (f *BlockFetcher) loop() { + // Iterate the block fetching until a quit is requested + var ( + fetchTimer = time.NewTimer(0) + completeTimer = time.NewTimer(0) + ) + <-fetchTimer.C // clear out the channel + <-completeTimer.C + defer fetchTimer.Stop() + defer completeTimer.Stop() + + for { + // Clean up any expired block fetches + for hash, announce := range f.fetching { + if time.Since(announce.time) > fetchTimeout { + f.forgetHash(hash) + } + } + // Import any queued blocks that could potentially fit + height := f.chainHeight() + for !f.queue.Empty() { + op := f.queue.PopItem().(*blockOrHeaderInject) + hash := op.hash() + if f.queueChangeHook != nil { + f.queueChangeHook(hash, false) + } + // If too high up the chain or phase, continue later + number := op.number() + if number > height+1 { + f.queue.Push(op, -int64(number)) + if f.queueChangeHook != nil { + f.queueChangeHook(hash, true) + } + break + } + // Otherwise if fresh and still unknown, try and import + if (number+maxUncleDist < height) || (f.light && f.getHeader(hash) != nil) || (!f.light && f.getBlock(hash) != nil) { 
+ f.forgetBlock(hash) + continue + } + if f.light { + f.importHeaders(op.origin, op.header) + } else { + f.importBlocks(op.origin, op.block) + } + } + // Wait for an outside event to occur + select { + case <-f.quit: + // BlockFetcher terminating, abort all operations + return + + case notification := <-f.notify: + // A block was announced, make sure the peer isn't DOSing us + blockAnnounceInMeter.Mark(1) + + count := f.announces[notification.origin] + 1 + if count > hashLimit { + log.Debug("Peer exceeded outstanding announces", "peer", notification.origin, "limit", hashLimit) + blockAnnounceDOSMeter.Mark(1) + break + } + // If we have a valid block number, check that it's potentially useful + if notification.number > 0 { + if dist := int64(notification.number) - int64(f.chainHeight()); dist < -maxUncleDist || dist > maxQueueDist { + log.Debug("Peer discarded announcement", "peer", notification.origin, "number", notification.number, "hash", notification.hash, "distance", dist) + blockAnnounceDropMeter.Mark(1) + break + } + } + // All is well, schedule the announce if block's not yet downloading + if _, ok := f.fetching[notification.hash]; ok { + break + } + if _, ok := f.completing[notification.hash]; ok { + break + } + f.announces[notification.origin] = count + f.announced[notification.hash] = append(f.announced[notification.hash], notification) + if f.announceChangeHook != nil && len(f.announced[notification.hash]) == 1 { + f.announceChangeHook(notification.hash, true) + } + if len(f.announced) == 1 { + f.rescheduleFetch(fetchTimer) + } + + case op := <-f.inject: + // A direct block insertion was requested, try and fill any pending gaps + blockBroadcastInMeter.Mark(1) + + // Now only direct block injection is allowed, drop the header injection + // here silently if we receive. 
+ if f.light { + continue + } + f.enqueue(op.origin, nil, op.block) + + case hash := <-f.done: + // A pending import finished, remove all traces of the notification + f.forgetHash(hash) + f.forgetBlock(hash) + + case <-fetchTimer.C: + // At least one block's timer ran out, check for needing retrieval + request := make(map[string][]common.Hash) + + for hash, announces := range f.announced { + // In current LES protocol(les2/les3), only header announce is + // available, no need to wait too much time for header broadcast. + timeout := arriveTimeout - gatherSlack + if f.light { + timeout = 0 + } + if time.Since(announces[0].time) > timeout { + // Pick a random peer to retrieve from, reset all others + announce := announces[rand.Intn(len(announces))] + f.forgetHash(hash) + + // If the block still didn't arrive, queue for fetching + if (f.light && f.getHeader(hash) == nil) || (!f.light && f.getBlock(hash) == nil) { + request[announce.origin] = append(request[announce.origin], hash) + f.fetching[hash] = announce + } + } + } + // Send out all block header requests + for peer, hashes := range request { + log.Trace("Fetching scheduled headers", "peer", peer, "list", hashes) + + // Create a closure of the fetch and schedule in on a new thread + fetchHeader, hashes := f.fetching[hashes[0]].fetchHeader, hashes + go func() { + if f.fetchingHook != nil { + f.fetchingHook(hashes) + } + for _, hash := range hashes { + headerFetchMeter.Mark(1) + fetchHeader(hash) // Suboptimal, but protocol doesn't allow batch header retrievals + } + }() + } + // Schedule the next fetch if blocks are still pending + f.rescheduleFetch(fetchTimer) + + case <-completeTimer.C: + // At least one header's timer ran out, retrieve everything + request := make(map[string][]common.Hash) + + for hash, announces := range f.fetched { + // Pick a random peer to retrieve from, reset all others + announce := announces[rand.Intn(len(announces))] + f.forgetHash(hash) + + // If the block still didn't arrive, queue 
for completion + if f.getBlock(hash) == nil { + request[announce.origin] = append(request[announce.origin], hash) + f.completing[hash] = announce + } + } + // Send out all block body requests + for peer, hashes := range request { + log.Trace("Fetching scheduled bodies", "peer", peer, "list", hashes) + + // Create a closure of the fetch and schedule in on a new thread + if f.completingHook != nil { + f.completingHook(hashes) + } + bodyFetchMeter.Mark(int64(len(hashes))) + go f.completing[hashes[0]].fetchBodies(hashes) + } + // Schedule the next fetch if blocks are still pending + f.rescheduleComplete(completeTimer) + + case filter := <-f.headerFilter: + // Headers arrived from a remote peer. Extract those that were explicitly + // requested by the fetcher, and return everything else so it's delivered + // to other parts of the system. + var task *headerFilterTask + select { + case task = <-filter: + case <-f.quit: + return + } + headerFilterInMeter.Mark(int64(len(task.headers))) + + // Split the batch of headers into unknown ones (to return to the caller), + // known incomplete ones (requiring body retrievals) and completed blocks. 
+ unknown, incomplete, complete, lightHeaders := []*types.Header{}, []*blockAnnounce{}, []*types.Block{}, []*blockAnnounce{} + for _, header := range task.headers { + hash := header.Hash() + + // Filter fetcher-requested headers from other synchronisation algorithms + if announce := f.fetching[hash]; announce != nil && announce.origin == task.peer && f.fetched[hash] == nil && f.completing[hash] == nil && f.queued[hash] == nil { + // If the delivered header does not match the promised number, drop the announcer + if header.Number.Uint64() != announce.number { + log.Trace("Invalid block number fetched", "peer", announce.origin, "hash", header.Hash(), "announced", announce.number, "provided", header.Number) + f.dropPeer(announce.origin) + f.forgetHash(hash) + continue + } + // Collect all headers only if we are running in light + // mode and the headers are not imported by other means. + if f.light { + if f.getHeader(hash) == nil { + announce.header = header + lightHeaders = append(lightHeaders, announce) + } + f.forgetHash(hash) + continue + } + // Only keep if not imported by other means + if f.getBlock(hash) == nil { + announce.header = header + announce.time = task.time + + // If the block is empty (header only), short circuit into the final import queue + if header.TxHash == types.EmptyRootHash && header.UncleHash == types.EmptyUncleHash { + log.Trace("Block empty, skipping body retrieval", "peer", announce.origin, "number", header.Number, "hash", header.Hash()) + + block := types.NewBlockWithHeader(header) + block.ReceivedAt = task.time + + complete = append(complete, block) + f.completing[hash] = announce + continue + } + // Otherwise add to the list of blocks needing completion + incomplete = append(incomplete, announce) + } else { + log.Trace("Block already imported, discarding header", "peer", announce.origin, "number", header.Number, "hash", header.Hash()) + f.forgetHash(hash) + } + } else { + // BlockFetcher doesn't know about it, add to the return list + 
unknown = append(unknown, header) + } + } + headerFilterOutMeter.Mark(int64(len(unknown))) + select { + case filter <- &headerFilterTask{headers: unknown, time: task.time}: + case <-f.quit: + return + } + // Schedule the retrieved headers for body completion + for _, announce := range incomplete { + hash := announce.header.Hash() + if _, ok := f.completing[hash]; ok { + continue + } + f.fetched[hash] = append(f.fetched[hash], announce) + if len(f.fetched) == 1 { + f.rescheduleComplete(completeTimer) + } + } + // Schedule the header for light fetcher import + for _, announce := range lightHeaders { + f.enqueue(announce.origin, announce.header, nil) + } + // Schedule the header-only blocks for import + for _, block := range complete { + if announce := f.completing[block.Hash()]; announce != nil { + f.enqueue(announce.origin, nil, block) + } + } + + case filter := <-f.bodyFilter: + // Block bodies arrived, extract any explicitly requested blocks, return the rest + var task *bodyFilterTask + select { + case task = <-filter: + case <-f.quit: + return + } + bodyFilterInMeter.Mark(int64(len(task.transactions))) + blocks := []*types.Block{} + // abort early if there's nothing explicitly requested + if len(f.completing) > 0 { + for i := 0; i < len(task.transactions) && i < len(task.uncles); i++ { + // Match up a body to any possible completion request + var ( + matched = false + uncleHash common.Hash // calculated lazily and reused + txnHash common.Hash // calculated lazily and reused + ) + for hash, announce := range f.completing { + if f.queued[hash] != nil || announce.origin != task.peer { + continue + } + if uncleHash == (common.Hash{}) { + uncleHash = types.CalcUncleHash(task.uncles[i]) + } + if uncleHash != announce.header.UncleHash { + continue + } + if txnHash == (common.Hash{}) { + txnHash = types.DeriveSha(types.Transactions(task.transactions[i]), trie.NewStackTrie(nil)) + } + if txnHash != announce.header.TxHash { + continue + } + // Mark the body matched, 
reassemble if still unknown + matched = true + if f.getBlock(hash) == nil { + block := types.NewBlockWithHeader(announce.header).WithBody(task.transactions[i], task.uncles[i]) + block.ReceivedAt = task.time + blocks = append(blocks, block) + } else { + f.forgetHash(hash) + } + + } + if matched { + task.transactions = append(task.transactions[:i], task.transactions[i+1:]...) + task.uncles = append(task.uncles[:i], task.uncles[i+1:]...) + i-- + continue + } + } + } + bodyFilterOutMeter.Mark(int64(len(task.transactions))) + select { + case filter <- task: + case <-f.quit: + return + } + // Schedule the retrieved blocks for ordered import + for _, block := range blocks { + if announce := f.completing[block.Hash()]; announce != nil { + f.enqueue(announce.origin, nil, block) + } + } + } + } +} + +// rescheduleFetch resets the specified fetch timer to the next blockAnnounce timeout. +func (f *BlockFetcher) rescheduleFetch(fetch *time.Timer) { + // Short circuit if no blocks are announced + if len(f.announced) == 0 { + return + } + // Schedule announcement retrieval quickly for light mode + // since server won't send any headers to client. + if f.light { + fetch.Reset(lightTimeout) + return + } + // Otherwise find the earliest expiring announcement + earliest := time.Now() + for _, announces := range f.announced { + if earliest.After(announces[0].time) { + earliest = announces[0].time + } + } + fetch.Reset(arriveTimeout - time.Since(earliest)) +} + +// rescheduleComplete resets the specified completion timer to the next fetch timeout. 
+func (f *BlockFetcher) rescheduleComplete(complete *time.Timer) { + // Short circuit if no headers are fetched + if len(f.fetched) == 0 { + return + } + // Otherwise find the earliest expiring announcement + earliest := time.Now() + for _, announces := range f.fetched { + if earliest.After(announces[0].time) { + earliest = announces[0].time + } + } + complete.Reset(gatherSlack - time.Since(earliest)) +} + +// enqueue schedules a new header or block import operation, if the component +// to be imported has not yet been seen. +func (f *BlockFetcher) enqueue(peer string, header *types.Header, block *types.Block) { + var ( + hash common.Hash + number uint64 + ) + if header != nil { + hash, number = header.Hash(), header.Number.Uint64() + } else { + hash, number = block.Hash(), block.NumberU64() + } + // Ensure the peer isn't DOSing us + count := f.queues[peer] + 1 + if count > blockLimit { + log.Debug("Discarded delivered header or block, exceeded allowance", "peer", peer, "number", number, "hash", hash, "limit", blockLimit) + blockBroadcastDOSMeter.Mark(1) + f.forgetHash(hash) + return + } + // Discard any past or too distant blocks + if dist := int64(number) - int64(f.chainHeight()); dist < -maxUncleDist || dist > maxQueueDist { + log.Debug("Discarded delivered header or block, too far away", "peer", peer, "number", number, "hash", hash, "distance", dist) + blockBroadcastDropMeter.Mark(1) + f.forgetHash(hash) + return + } + // Schedule the block for future importing + if _, ok := f.queued[hash]; !ok { + op := &blockOrHeaderInject{origin: peer} + if header != nil { + op.header = header + } else { + op.block = block + } + f.queues[peer] = count + f.queued[hash] = op + f.queue.Push(op, -int64(number)) + if f.queueChangeHook != nil { + f.queueChangeHook(hash, true) + } + log.Debug("Queued delivered header or block", "peer", peer, "number", number, "hash", hash, "queued", f.queue.Size()) + } +} + +// importHeaders spawns a new goroutine to run a header insertion into the 
chain. +// If the header's number is at the same height as the current import phase, it +// updates the phase states accordingly. +func (f *BlockFetcher) importHeaders(peer string, header *types.Header) { + hash := header.Hash() + log.Debug("Importing propagated header", "peer", peer, "number", header.Number, "hash", hash) + + go func() { + defer func() { f.done <- hash }() + // If the parent's unknown, abort insertion + parent := f.getHeader(header.ParentHash) + if parent == nil { + log.Debug("Unknown parent of propagated header", "peer", peer, "number", header.Number, "hash", hash, "parent", header.ParentHash) + return + } + // Validate the header and if something went wrong, drop the peer + if err := f.verifyHeader(header); err != nil && err != consensus.ErrFutureBlock { + log.Debug("Propagated header verification failed", "peer", peer, "number", header.Number, "hash", hash, "err", err) + f.dropPeer(peer) + return + } + // Run the actual import and log any issues + if _, err := f.insertHeaders([]*types.Header{header}); err != nil { + log.Debug("Propagated header import failed", "peer", peer, "number", header.Number, "hash", hash, "err", err) + return + } + // Invoke the testing hook if needed + if f.importedHook != nil { + f.importedHook(header, nil) + } + }() +} + +// importBlocks spawns a new goroutine to run a block insertion into the chain. If the +// block's number is at the same height as the current import phase, it updates +// the phase states accordingly. 
+func (f *BlockFetcher) importBlocks(peer string, block *types.Block) { + hash := block.Hash() + + // Run the import on a new thread + log.Debug("Importing propagated block", "peer", peer, "number", block.Number(), "hash", hash) + go func() { + defer func() { f.done <- hash }() + + // If the parent's unknown, abort insertion + parent := f.getBlock(block.ParentHash()) + if parent == nil { + log.Debug("Unknown parent of propagated block", "peer", peer, "number", block.Number(), "hash", hash, "parent", block.ParentHash()) + return + } + // Quickly validate the header and propagate the block if it passes + switch err := f.verifyHeader(block.Header()); err { + case nil: + // All ok, quickly propagate to our peers + blockBroadcastOutTimer.UpdateSince(block.ReceivedAt) + go f.broadcastBlock(block, true) + + case consensus.ErrFutureBlock: + // Weird future block, don't fail, but neither propagate + + default: + // Something went very wrong, drop the peer + log.Debug("Propagated block verification failed", "peer", peer, "number", block.Number(), "hash", hash, "err", err) + f.dropPeer(peer) + return + } + // Run the actual import and log any issues + if _, err := f.insertChain(types.Blocks{block}); err != nil { + log.Debug("Propagated block import failed", "peer", peer, "number", block.Number(), "hash", hash, "err", err) + return + } + // If import succeeded, broadcast the block + blockAnnounceOutTimer.UpdateSince(block.ReceivedAt) + go f.broadcastBlock(block, false) + + // Invoke the testing hook if needed + if f.importedHook != nil { + f.importedHook(nil, block) + } + }() +} + +// forgetHash removes all traces of a block announcement from the fetcher's +// internal state. 
+func (f *BlockFetcher) forgetHash(hash common.Hash) { + // Remove all pending announces and decrement DOS counters + if announceMap, ok := f.announced[hash]; ok { + for _, announce := range announceMap { + f.announces[announce.origin]-- + if f.announces[announce.origin] <= 0 { + delete(f.announces, announce.origin) + } + } + delete(f.announced, hash) + if f.announceChangeHook != nil { + f.announceChangeHook(hash, false) + } + } + // Remove any pending fetches and decrement the DOS counters + if announce := f.fetching[hash]; announce != nil { + f.announces[announce.origin]-- + if f.announces[announce.origin] <= 0 { + delete(f.announces, announce.origin) + } + delete(f.fetching, hash) + } + + // Remove any pending completion requests and decrement the DOS counters + for _, announce := range f.fetched[hash] { + f.announces[announce.origin]-- + if f.announces[announce.origin] <= 0 { + delete(f.announces, announce.origin) + } + } + delete(f.fetched, hash) + + // Remove any pending completions and decrement the DOS counters + if announce := f.completing[hash]; announce != nil { + f.announces[announce.origin]-- + if f.announces[announce.origin] <= 0 { + delete(f.announces, announce.origin) + } + delete(f.completing, hash) + } +} + +// forgetBlock removes all traces of a queued block from the fetcher's internal +// state. +func (f *BlockFetcher) forgetBlock(hash common.Hash) { + if insert := f.queued[hash]; insert != nil { + f.queues[insert.origin]-- + if f.queues[insert.origin] == 0 { + delete(f.queues, insert.origin) + } + delete(f.queued, hash) + } +} diff --git a/les/fetcher/block_fetcher_test.go b/les/fetcher/block_fetcher_test.go new file mode 100644 index 0000000000000..b6d1125b56537 --- /dev/null +++ b/les/fetcher/block_fetcher_test.go @@ -0,0 +1,896 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package fetcher + +import ( + "errors" + "math/big" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" +) + +var ( + testdb = rawdb.NewMemoryDatabase() + testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + testAddress = crypto.PubkeyToAddress(testKey.PublicKey) + genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000000000)) + unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit, BaseFee: big.NewInt(params.InitialBaseFee)}, nil, nil, nil, trie.NewStackTrie(nil)) +) + +// makeChain creates a chain of n blocks starting at and including parent. +// the returned hash chain is ordered head->parent. In addition, every 3rd block +// contains a transaction and every 5th an uncle to allow testing correct block +// reassembly.
+func makeChain(n int, seed byte, parent *types.Block) ([]common.Hash, map[common.Hash]*types.Block) { + blocks, _ := core.GenerateChain(params.TestChainConfig, parent, ethash.NewFaker(), testdb, n, func(i int, block *core.BlockGen) { + block.SetCoinbase(common.Address{seed}) + + // If the block number is multiple of 3, send a bonus transaction to the miner + if parent == genesis && i%3 == 0 { + signer := types.MakeSigner(params.TestChainConfig, block.Number()) + tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, block.BaseFee(), nil), signer, testKey) + if err != nil { + panic(err) + } + block.AddTx(tx) + } + // If the block number is a multiple of 5, add a bonus uncle to the block + if i%5 == 0 { + block.AddUncle(&types.Header{ParentHash: block.PrevBlock(i - 1).Hash(), Number: big.NewInt(int64(i - 1))}) + } + }) + hashes := make([]common.Hash, n+1) + hashes[len(hashes)-1] = parent.Hash() + blockm := make(map[common.Hash]*types.Block, n+1) + blockm[parent.Hash()] = parent + for i, b := range blocks { + hashes[len(hashes)-i-2] = b.Hash() + blockm[b.Hash()] = b + } + return hashes, blockm +} + +// fetcherTester is a test simulator for mocking out local block chain. +type fetcherTester struct { + fetcher *BlockFetcher + + hashes []common.Hash // Hash chain belonging to the tester + headers map[common.Hash]*types.Header // Headers belonging to the tester + blocks map[common.Hash]*types.Block // Blocks belonging to the tester + drops map[string]bool // Map of peers dropped by the fetcher + + lock sync.RWMutex +} + +// newTester creates a new fetcher test mocker. 
+func newTester(light bool) *fetcherTester { + tester := &fetcherTester{ + hashes: []common.Hash{genesis.Hash()}, + headers: map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()}, + blocks: map[common.Hash]*types.Block{genesis.Hash(): genesis}, + drops: make(map[string]bool), + } + tester.fetcher = NewBlockFetcher(light, tester.getHeader, tester.getBlock, tester.verifyHeader, tester.broadcastBlock, tester.chainHeight, tester.insertHeaders, tester.insertChain, tester.dropPeer) + tester.fetcher.Start() + + return tester +} + +// getHeader retrieves a header from the tester's block chain. +func (f *fetcherTester) getHeader(hash common.Hash) *types.Header { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.headers[hash] +} + +// getBlock retrieves a block from the tester's block chain. +func (f *fetcherTester) getBlock(hash common.Hash) *types.Block { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.blocks[hash] +} + +// verifyHeader is a nop placeholder for the block header verification. +func (f *fetcherTester) verifyHeader(header *types.Header) error { + return nil +} + +// broadcastBlock is a nop placeholder for the block broadcasting. +func (f *fetcherTester) broadcastBlock(block *types.Block, propagate bool) { +} + +// chainHeight retrieves the current height (block number) of the chain. +func (f *fetcherTester) chainHeight() uint64 { + f.lock.RLock() + defer f.lock.RUnlock() + + if f.fetcher.light { + return f.headers[f.hashes[len(f.hashes)-1]].Number.Uint64() + } + return f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() +} + +// insertHeaders injects new headers into the simulated chain.
+func (f *fetcherTester) insertHeaders(headers []*types.Header) (int, error) { + f.lock.Lock() + defer f.lock.Unlock() + + for i, header := range headers { + // Make sure the parent in known + if _, ok := f.headers[header.ParentHash]; !ok { + return i, errors.New("unknown parent") + } + // Discard any new blocks if the same height already exists + if header.Number.Uint64() <= f.headers[f.hashes[len(f.hashes)-1]].Number.Uint64() { + return i, nil + } + // Otherwise build our current chain + f.hashes = append(f.hashes, header.Hash()) + f.headers[header.Hash()] = header + } + return 0, nil +} + +// insertChain injects a new blocks into the simulated chain. +func (f *fetcherTester) insertChain(blocks types.Blocks) (int, error) { + f.lock.Lock() + defer f.lock.Unlock() + + for i, block := range blocks { + // Make sure the parent in known + if _, ok := f.blocks[block.ParentHash()]; !ok { + return i, errors.New("unknown parent") + } + // Discard any new blocks if the same height already exists + if block.NumberU64() <= f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() { + return i, nil + } + // Otherwise build our current chain + f.hashes = append(f.hashes, block.Hash()) + f.blocks[block.Hash()] = block + } + return 0, nil +} + +// dropPeer is an emulator for the peer removal, simply accumulating the various +// peers dropped by the fetcher. +func (f *fetcherTester) dropPeer(peer string) { + f.lock.Lock() + defer f.lock.Unlock() + + f.drops[peer] = true +} + +// makeHeaderFetcher retrieves a block header fetcher associated with a simulated peer. 
+func (f *fetcherTester) makeHeaderFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) headerRequesterFn { + closure := make(map[common.Hash]*types.Block) + for hash, block := range blocks { + closure[hash] = block + } + // Create a function that return a header from the closure + return func(hash common.Hash) error { + // Gather the blocks to return + headers := make([]*types.Header, 0, 1) + if block, ok := closure[hash]; ok { + headers = append(headers, block.Header()) + } + // Return on a new thread + go f.fetcher.FilterHeaders(peer, headers, time.Now().Add(drift)) + + return nil + } +} + +// makeBodyFetcher retrieves a block body fetcher associated with a simulated peer. +func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) bodyRequesterFn { + closure := make(map[common.Hash]*types.Block) + for hash, block := range blocks { + closure[hash] = block + } + // Create a function that returns blocks from the closure + return func(hashes []common.Hash) error { + // Gather the block bodies to return + transactions := make([][]*types.Transaction, 0, len(hashes)) + uncles := make([][]*types.Header, 0, len(hashes)) + + for _, hash := range hashes { + if block, ok := closure[hash]; ok { + transactions = append(transactions, block.Transactions()) + uncles = append(uncles, block.Uncles()) + } + } + // Return on a new thread + go f.fetcher.FilterBodies(peer, transactions, uncles, time.Now().Add(drift)) + + return nil + } +} + +// verifyFetchingEvent verifies that one single event arrive on a fetching channel. 
+func verifyFetchingEvent(t *testing.T, fetching chan []common.Hash, arrive bool) { + if arrive { + select { + case <-fetching: + case <-time.After(time.Second): + t.Fatalf("fetching timeout") + } + } else { + select { + case <-fetching: + t.Fatalf("fetching invoked") + case <-time.After(10 * time.Millisecond): + } + } +} + +// verifyCompletingEvent verifies that one single event arrive on an completing channel. +func verifyCompletingEvent(t *testing.T, completing chan []common.Hash, arrive bool) { + if arrive { + select { + case <-completing: + case <-time.After(time.Second): + t.Fatalf("completing timeout") + } + } else { + select { + case <-completing: + t.Fatalf("completing invoked") + case <-time.After(10 * time.Millisecond): + } + } +} + +// verifyImportEvent verifies that one single event arrive on an import channel. +func verifyImportEvent(t *testing.T, imported chan interface{}, arrive bool) { + if arrive { + select { + case <-imported: + case <-time.After(time.Second): + t.Fatalf("import timeout") + } + } else { + select { + case <-imported: + t.Fatalf("import invoked") + case <-time.After(20 * time.Millisecond): + } + } +} + +// verifyImportCount verifies that exactly count number of events arrive on an +// import hook channel. +func verifyImportCount(t *testing.T, imported chan interface{}, count int) { + for i := 0; i < count; i++ { + select { + case <-imported: + case <-time.After(time.Second): + t.Fatalf("block %d: import timeout", i+1) + } + } + verifyImportDone(t, imported) +} + +// verifyImportDone verifies that no more events are arriving on an import channel. +func verifyImportDone(t *testing.T, imported chan interface{}) { + select { + case <-imported: + t.Fatalf("extra block imported") + case <-time.After(50 * time.Millisecond): + } +} + +// verifyChainHeight verifies the chain height is as expected. 
+func verifyChainHeight(t *testing.T, fetcher *fetcherTester, height uint64) { + if fetcher.chainHeight() != height { + t.Fatalf("chain height mismatch, got %d, want %d", fetcher.chainHeight(), height) + } +} + +// Tests that a fetcher accepts block/header announcements and initiates retrievals +// for them, successfully importing into the local chain. +func TestFullSequentialAnnouncements(t *testing.T) { testSequentialAnnouncements(t, false) } +func TestLightSequentialAnnouncements(t *testing.T) { testSequentialAnnouncements(t, true) } + +func testSequentialAnnouncements(t *testing.T, light bool) { + // Create a chain of blocks to import + targetBlocks := 4 * hashLimit + hashes, blocks := makeChain(targetBlocks, 0, genesis) + + tester := newTester(light) + headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) + + // Iteratively announce blocks until all are imported + imported := make(chan interface{}) + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { + if light { + if header == nil { + t.Fatalf("Fetcher try to import empty header") + } + imported <- header + } else { + if block == nil { + t.Fatalf("Fetcher try to import empty block") + } + imported <- block + } + } + for i := len(hashes) - 2; i >= 0; i-- { + tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + verifyImportEvent(t, imported, true) + } + verifyImportDone(t, imported) + verifyChainHeight(t, tester, uint64(len(hashes)-1)) +} + +// Tests that if blocks are announced by multiple peers (or even the same buggy +// peer), they will only get downloaded at most once. 
+func TestFullConcurrentAnnouncements(t *testing.T) { testConcurrentAnnouncements(t, false) } +func TestLightConcurrentAnnouncements(t *testing.T) { testConcurrentAnnouncements(t, true) } + +func testConcurrentAnnouncements(t *testing.T, light bool) { + // Create a chain of blocks to import + targetBlocks := 4 * hashLimit + hashes, blocks := makeChain(targetBlocks, 0, genesis) + + // Assemble a tester with a built in counter for the requests + tester := newTester(light) + firstHeaderFetcher := tester.makeHeaderFetcher("first", blocks, -gatherSlack) + firstBodyFetcher := tester.makeBodyFetcher("first", blocks, 0) + secondHeaderFetcher := tester.makeHeaderFetcher("second", blocks, -gatherSlack) + secondBodyFetcher := tester.makeBodyFetcher("second", blocks, 0) + + counter := uint32(0) + firstHeaderWrapper := func(hash common.Hash) error { + atomic.AddUint32(&counter, 1) + return firstHeaderFetcher(hash) + } + secondHeaderWrapper := func(hash common.Hash) error { + atomic.AddUint32(&counter, 1) + return secondHeaderFetcher(hash) + } + // Iteratively announce blocks until all are imported + imported := make(chan interface{}) + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { + if light { + if header == nil { + t.Fatalf("Fetcher try to import empty header") + } + imported <- header + } else { + if block == nil { + t.Fatalf("Fetcher try to import empty block") + } + imported <- block + } + } + for i := len(hashes) - 2; i >= 0; i-- { + tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), firstHeaderWrapper, firstBodyFetcher) + tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), secondHeaderWrapper, secondBodyFetcher) + tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), secondHeaderWrapper, secondBodyFetcher) + verifyImportEvent(t, imported, true) + } + 
verifyImportDone(t, imported) + + // Make sure no blocks were retrieved twice + if int(counter) != targetBlocks { + t.Fatalf("retrieval count mismatch: have %v, want %v", counter, targetBlocks) + } + verifyChainHeight(t, tester, uint64(len(hashes)-1)) +} + +// Tests that announcements arriving while a previous is being fetched still +// results in a valid import. +func TestFullOverlappingAnnouncements(t *testing.T) { testOverlappingAnnouncements(t, false) } +func TestLightOverlappingAnnouncements(t *testing.T) { testOverlappingAnnouncements(t, true) } + +func testOverlappingAnnouncements(t *testing.T, light bool) { + // Create a chain of blocks to import + targetBlocks := 4 * hashLimit + hashes, blocks := makeChain(targetBlocks, 0, genesis) + + tester := newTester(light) + headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) + + // Iteratively announce blocks, but overlap them continuously + overlap := 16 + imported := make(chan interface{}, len(hashes)-1) + for i := 0; i < overlap; i++ { + imported <- nil + } + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { + if light { + if header == nil { + t.Fatalf("Fetcher try to import empty header") + } + imported <- header + } else { + if block == nil { + t.Fatalf("Fetcher try to import empty block") + } + imported <- block + } + } + + for i := len(hashes) - 2; i >= 0; i-- { + tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + select { + case <-imported: + case <-time.After(time.Second): + t.Fatalf("block %d: import timeout", len(hashes)-i) + } + } + // Wait for all the imports to complete and check count + verifyImportCount(t, imported, overlap) + verifyChainHeight(t, tester, uint64(len(hashes)-1)) +} + +// Tests that announces already being retrieved will not be duplicated. 
+func TestFullPendingDeduplication(t *testing.T) { testPendingDeduplication(t, false) } +func TestLightPendingDeduplication(t *testing.T) { testPendingDeduplication(t, true) } + +func testPendingDeduplication(t *testing.T, light bool) { + // Create a hash and corresponding block + hashes, blocks := makeChain(1, 0, genesis) + + // Assemble a tester with a built in counter and delayed fetcher + tester := newTester(light) + headerFetcher := tester.makeHeaderFetcher("repeater", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("repeater", blocks, 0) + + delay := 50 * time.Millisecond + counter := uint32(0) + headerWrapper := func(hash common.Hash) error { + atomic.AddUint32(&counter, 1) + + // Simulate a long running fetch + go func() { + time.Sleep(delay) + headerFetcher(hash) + }() + return nil + } + checkNonExist := func() bool { + return tester.getBlock(hashes[0]) == nil + } + if light { + checkNonExist = func() bool { + return tester.getHeader(hashes[0]) == nil + } + } + // Announce the same block many times until it's fetched (wait for any pending ops) + for checkNonExist() { + tester.fetcher.Notify("repeater", hashes[0], 1, time.Now().Add(-arriveTimeout), headerWrapper, bodyFetcher) + time.Sleep(time.Millisecond) + } + time.Sleep(delay) + + // Check that all blocks were imported and none fetched twice + if int(counter) != 1 { + t.Fatalf("retrieval count mismatch: have %v, want %v", counter, 1) + } + verifyChainHeight(t, tester, 1) +} + +// Tests that announcements retrieved in a random order are cached and eventually +// imported when all the gaps are filled in. 
+func TestFullRandomArrivalImport(t *testing.T) { testRandomArrivalImport(t, false) } +func TestLightRandomArrivalImport(t *testing.T) { testRandomArrivalImport(t, true) } + +func testRandomArrivalImport(t *testing.T, light bool) { + // Create a chain of blocks to import, and choose one to delay + targetBlocks := maxQueueDist + hashes, blocks := makeChain(targetBlocks, 0, genesis) + skip := targetBlocks / 2 + + tester := newTester(light) + headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) + + // Iteratively announce blocks, skipping one entry + imported := make(chan interface{}, len(hashes)-1) + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { + if light { + if header == nil { + t.Fatalf("Fetcher try to import empty header") + } + imported <- header + } else { + if block == nil { + t.Fatalf("Fetcher try to import empty block") + } + imported <- block + } + } + for i := len(hashes) - 1; i >= 0; i-- { + if i != skip { + tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + time.Sleep(time.Millisecond) + } + } + // Finally announce the skipped entry and check full import + tester.fetcher.Notify("valid", hashes[skip], uint64(len(hashes)-skip-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + verifyImportCount(t, imported, len(hashes)-1) + verifyChainHeight(t, tester, uint64(len(hashes)-1)) +} + +// Tests that direct block enqueues (due to block propagation vs. hash announce) +// are correctly schedule, filling and import queue gaps. 
+func TestQueueGapFill(t *testing.T) { + // Create a chain of blocks to import, and choose one to not announce at all + targetBlocks := maxQueueDist + hashes, blocks := makeChain(targetBlocks, 0, genesis) + skip := targetBlocks / 2 + + tester := newTester(false) + headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) + + // Iteratively announce blocks, skipping one entry + imported := make(chan interface{}, len(hashes)-1) + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block } + + for i := len(hashes) - 1; i >= 0; i-- { + if i != skip { + tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + time.Sleep(time.Millisecond) + } + } + // Fill the missing block directly as if propagated + tester.fetcher.Enqueue("valid", blocks[hashes[skip]]) + verifyImportCount(t, imported, len(hashes)-1) + verifyChainHeight(t, tester, uint64(len(hashes)-1)) +} + +// Tests that blocks arriving from various sources (multiple propagations, hash +// announces, etc) do not get scheduled for import multiple times. 
+func TestImportDeduplication(t *testing.T) { + // Create two blocks to import (one for duplication, the other for stalling) + hashes, blocks := makeChain(2, 0, genesis) + + // Create the tester and wrap the importer with a counter + tester := newTester(false) + headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) + + counter := uint32(0) + tester.fetcher.insertChain = func(blocks types.Blocks) (int, error) { + atomic.AddUint32(&counter, uint32(len(blocks))) + return tester.insertChain(blocks) + } + // Instrument the fetching and imported events + fetching := make(chan []common.Hash) + imported := make(chan interface{}, len(hashes)-1) + tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes } + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block } + + // Announce the duplicating block, wait for retrieval, and also propagate directly + tester.fetcher.Notify("valid", hashes[0], 1, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + <-fetching + + tester.fetcher.Enqueue("valid", blocks[hashes[0]]) + tester.fetcher.Enqueue("valid", blocks[hashes[0]]) + tester.fetcher.Enqueue("valid", blocks[hashes[0]]) + + // Fill the missing block directly as if propagated, and check import uniqueness + tester.fetcher.Enqueue("valid", blocks[hashes[1]]) + verifyImportCount(t, imported, 2) + + if counter != 2 { + t.Fatalf("import invocation count mismatch: have %v, want %v", counter, 2) + } +} + +// Tests that blocks with numbers much lower or higher than out current head get +// discarded to prevent wasting resources on useless blocks from faulty peers. 
+func TestDistantPropagationDiscarding(t *testing.T) { + // Create a long chain to import and define the discard boundaries + hashes, blocks := makeChain(3*maxQueueDist, 0, genesis) + head := hashes[len(hashes)/2] + + low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1 + + // Create a tester and simulate a head block being the middle of the above chain + tester := newTester(false) + + tester.lock.Lock() + tester.hashes = []common.Hash{head} + tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} + tester.lock.Unlock() + + // Ensure that a block with a lower number than the threshold is discarded + tester.fetcher.Enqueue("lower", blocks[hashes[low]]) + time.Sleep(10 * time.Millisecond) + if !tester.fetcher.queue.Empty() { + t.Fatalf("fetcher queued stale block") + } + // Ensure that a block with a higher number than the threshold is discarded + tester.fetcher.Enqueue("higher", blocks[hashes[high]]) + time.Sleep(10 * time.Millisecond) + if !tester.fetcher.queue.Empty() { + t.Fatalf("fetcher queued future block") + } +} + +// Tests that announcements with numbers much lower or higher than out current +// head get discarded to prevent wasting resources on useless blocks from faulty +// peers. 
+func TestFullDistantAnnouncementDiscarding(t *testing.T) { testDistantAnnouncementDiscarding(t, false) } +func TestLightDistantAnnouncementDiscarding(t *testing.T) { testDistantAnnouncementDiscarding(t, true) } + +func testDistantAnnouncementDiscarding(t *testing.T, light bool) { + // Create a long chain to import and define the discard boundaries + hashes, blocks := makeChain(3*maxQueueDist, 0, genesis) + head := hashes[len(hashes)/2] + + low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1 + + // Create a tester and simulate a head block being the middle of the above chain + tester := newTester(light) + + tester.lock.Lock() + tester.hashes = []common.Hash{head} + tester.headers = map[common.Hash]*types.Header{head: blocks[head].Header()} + tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} + tester.lock.Unlock() + + headerFetcher := tester.makeHeaderFetcher("lower", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("lower", blocks, 0) + + fetching := make(chan struct{}, 2) + tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- struct{}{} } + + // Ensure that a block with a lower number than the threshold is discarded + tester.fetcher.Notify("lower", hashes[low], blocks[hashes[low]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + select { + case <-time.After(50 * time.Millisecond): + case <-fetching: + t.Fatalf("fetcher requested stale header") + } + // Ensure that a block with a higher number than the threshold is discarded + tester.fetcher.Notify("higher", hashes[high], blocks[hashes[high]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + select { + case <-time.After(50 * time.Millisecond): + case <-fetching: + t.Fatalf("fetcher requested future header") + } +} + +// Tests that peers announcing blocks with invalid numbers (i.e. not matching +// the headers provided afterwards) get dropped as malicious. 
+func TestFullInvalidNumberAnnouncement(t *testing.T) { testInvalidNumberAnnouncement(t, false) } +func TestLightInvalidNumberAnnouncement(t *testing.T) { testInvalidNumberAnnouncement(t, true) } + +func testInvalidNumberAnnouncement(t *testing.T, light bool) { + // Create a single block to import and check numbers against + hashes, blocks := makeChain(1, 0, genesis) + + tester := newTester(light) + badHeaderFetcher := tester.makeHeaderFetcher("bad", blocks, -gatherSlack) + badBodyFetcher := tester.makeBodyFetcher("bad", blocks, 0) + + imported := make(chan interface{}) + announced := make(chan interface{}) + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { + if light { + if header == nil { + t.Fatalf("Fetcher try to import empty header") + } + imported <- header + } else { + if block == nil { + t.Fatalf("Fetcher try to import empty block") + } + imported <- block + } + } + // Announce a block with a bad number, check for immediate drop + tester.fetcher.announceChangeHook = func(hash common.Hash, b bool) { + announced <- nil + } + tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), badHeaderFetcher, badBodyFetcher) + verifyAnnounce := func() { + for i := 0; i < 2; i++ { + select { + case <-announced: + continue + case <-time.After(1 * time.Second): + t.Fatal("announce timeout") + return + } + } + } + verifyAnnounce() + verifyImportEvent(t, imported, false) + tester.lock.RLock() + dropped := tester.drops["bad"] + tester.lock.RUnlock() + + if !dropped { + t.Fatalf("peer with invalid numbered announcement not dropped") + } + goodHeaderFetcher := tester.makeHeaderFetcher("good", blocks, -gatherSlack) + goodBodyFetcher := tester.makeBodyFetcher("good", blocks, 0) + // Make sure a good announcement passes without a drop + tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), goodHeaderFetcher, goodBodyFetcher) + verifyAnnounce() + verifyImportEvent(t, imported, true) + + tester.lock.RLock() + 
dropped = tester.drops["good"] + tester.lock.RUnlock() + + if dropped { + t.Fatalf("peer with valid numbered announcement dropped") + } + verifyImportDone(t, imported) +} + +// Tests that if a block is empty (i.e. header only), no body request should be +// made, and instead the header should be assembled into a whole block in itself. +func TestEmptyBlockShortCircuit(t *testing.T) { + // Create a chain of blocks to import + hashes, blocks := makeChain(32, 0, genesis) + + tester := newTester(false) + headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) + bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) + + // Add a monitoring hook for all internal events + fetching := make(chan []common.Hash) + tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes } + + completing := make(chan []common.Hash) + tester.fetcher.completingHook = func(hashes []common.Hash) { completing <- hashes } + + imported := make(chan interface{}) + tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { + if block == nil { + t.Fatalf("Fetcher try to import empty block") + } + imported <- block + } + // Iteratively announce blocks until all are imported + for i := len(hashes) - 2; i >= 0; i-- { + tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) + + // All announces should fetch the header + verifyFetchingEvent(t, fetching, true) + + // Only blocks with data contents should request bodies + verifyCompletingEvent(t, completing, len(blocks[hashes[i]].Transactions()) > 0 || len(blocks[hashes[i]].Uncles()) > 0) + + // Irrelevant of the construct, import should succeed + verifyImportEvent(t, imported, true) + } + verifyImportDone(t, imported) +} + +// Tests that a peer is unable to use unbounded memory with sending infinite +// block announcements to a node, but that even in the face of such an attack, +// the fetcher remains operational. 
func TestHashMemoryExhaustionAttack(t *testing.T) {
	// Create a tester with instrumented import hooks
	tester := newTester(false)

	// announces tracks the net number of live announcements: the hook adds one
	// for every announcement added and subtracts one for every one removed.
	imported, announces := make(chan interface{}), int32(0)
	tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block }
	tester.fetcher.announceChangeHook = func(hash common.Hash, added bool) {
		if added {
			atomic.AddInt32(&announces, 1)
		} else {
			atomic.AddInt32(&announces, -1)
		}
	}
	// Create a valid chain and a junk chain rooted at an unknown block
	targetBlocks := hashLimit + 2*maxQueueDist
	hashes, blocks := makeChain(targetBlocks, 0, genesis)
	validHeaderFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
	validBodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)

	// The attacker's fetchers are built over a nil block set, so they answer
	// every request with an empty reply.
	attack, _ := makeChain(targetBlocks, 0, unknownBlock)
	attackerHeaderFetcher := tester.makeHeaderFetcher("attacker", nil, -gatherSlack)
	attackerBodyFetcher := tester.makeBodyFetcher("attacker", nil, 0)

	// Feed the tester a huge hashset from the attacker, and a limited from the valid peer
	for i := 0; i < len(attack); i++ {
		if i < maxQueueDist {
			tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], uint64(i+1), time.Now(), validHeaderFetcher, validBodyFetcher)
		}
		tester.fetcher.Notify("attacker", attack[i], 1 /* don't distance drop */, time.Now(), attackerHeaderFetcher, attackerBodyFetcher)
	}
	// Expect hashLimit live announcements from the attacker plus the
	// maxQueueDist ones made by the valid peer.
	if count := atomic.LoadInt32(&announces); count != hashLimit+maxQueueDist {
		t.Fatalf("queued announce count mismatch: have %d, want %d", count, hashLimit+maxQueueDist)
	}
	// Wait for fetches to complete
	verifyImportCount(t, imported, maxQueueDist)

	// Feed the remaining valid hashes to ensure DOS protection state remains clean
	for i := len(hashes) - maxQueueDist - 2; i >= 0; i-- {
		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), validHeaderFetcher, validBodyFetcher)
		verifyImportEvent(t, imported, true)
	}
	verifyImportDone(t, imported)
}

// Tests that blocks sent to the fetcher (either through propagation or via hash
// announces and retrievals) don't pile up indefinitely, exhausting available
// system memory.
func TestBlockMemoryExhaustionAttack(t *testing.T) {
	// Create a tester with instrumented import hooks
	tester := newTester(false)

	// enqueued tracks the net number of queued blocks: the hook adds one for
	// every block added to the queue and subtracts one for every one removed.
	imported, enqueued := make(chan interface{}), int32(0)
	tester.fetcher.importedHook = func(header *types.Header, block *types.Block) { imported <- block }
	tester.fetcher.queueChangeHook = func(hash common.Hash, added bool) {
		if added {
			atomic.AddInt32(&enqueued, 1)
		} else {
			atomic.AddInt32(&enqueued, -1)
		}
	}
	// Create a valid chain and a batch of dangling (but in range) blocks
	targetBlocks := hashLimit + 2*maxQueueDist
	hashes, blocks := makeChain(targetBlocks, 0, genesis)
	attack := make(map[common.Hash]*types.Block)
	for i := byte(0); len(attack) < blockLimit+2*maxQueueDist; i++ {
		// NOTE: hashes and blocks deliberately shadow the valid chain inside
		// this loop; each iteration builds a fresh short chain on top of the
		// unknown block, using the loop counter as the seed.
		hashes, blocks := makeChain(maxQueueDist-1, i, unknownBlock)
		for _, hash := range hashes[:maxQueueDist-2] {
			attack[hash] = blocks[hash]
		}
	}
	// Try to feed all the attacker blocks make sure only a limited batch is accepted
	for _, block := range attack {
		tester.fetcher.Enqueue("attacker", block)
	}
	time.Sleep(200 * time.Millisecond)
	if queued := atomic.LoadInt32(&enqueued); queued != blockLimit {
		t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit)
	}
	// Queue up a batch of valid blocks, and check that a new peer is allowed to do so.
	// The first block after genesis (hashes[len(hashes)-2]) is held back here.
	for i := 0; i < maxQueueDist-1; i++ {
		tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]])
	}
	time.Sleep(100 * time.Millisecond)
	if queued := atomic.LoadInt32(&enqueued); queued != blockLimit+maxQueueDist-1 {
		t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1)
	}
	// Insert the missing piece (and sanity check the import)
	tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2]])
	verifyImportCount(t, imported, maxQueueDist)

	// Insert the remaining blocks in chunks to ensure clean DOS protection
	for i := maxQueueDist; i < len(hashes)-1; i++ {
		tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2-i]])
		verifyImportEvent(t, imported, true)
	}
	verifyImportDone(t, imported)
}
diff --git a/les/handler_test.go b/les/handler_test.go index bb8ad33829f11..aba45764b3068 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -30,7 +30,7 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/les/downloader" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/params" diff --git a/les/sync.go b/les/sync.go index fa5ef4ff82009..31cd06ca704ab 100644 --- a/les/sync.go +++ b/les/sync.go @@ -23,7 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/les/downloader" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params"