diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go index 3611093423bc2..4cc42af8c267d 100644 --- a/consensus/beacon/consensus.go +++ b/consensus/beacon/consensus.go @@ -112,10 +112,12 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [ break } } + // All the headers have passed the transition point, use new rules. if len(preHeaders) == 0 { return beacon.verifyHeaders(chain, headers, nil) } + // The transition point exists in the middle, separate the headers // into two batches and apply different verification rules for them. var ( @@ -130,6 +132,14 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [ oldDone, oldResult = beacon.ethone.VerifyHeaders(chain, preHeaders, preSeals) newDone, newResult = beacon.verifyHeaders(chain, postHeaders, preHeaders[len(preHeaders)-1]) ) + // Verify that pre-merge headers don't overflow the TTD + if index, err := verifyTerminalPoWBlock(chain, preHeaders); err != nil { + // Mark all subsequent pow headers with the error. + for i := index; i < len(preHeaders); i++ { + errors[i], done[i] = err, true + } + } + // Collect the results for { for ; done[out]; out++ { results <- errors[out] @@ -139,7 +149,9 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [ } select { case err := <-oldResult: - errors[old], done[old] = err, true + if !done[old] { // skip TTD-verified failures + errors[old], done[old] = err, true + } old++ case err := <-newResult: errors[new], done[new] = err, true @@ -154,6 +166,32 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [ return abort, results } +// verifyTerminalPoWBlock verifies that the preHeaders confirm to the specification +// wrt. their total difficulty. +// It expects: +// - preHeaders to be at least 1 element +// - the parent of the header element to be stored in the chain correctly +// - the preHeaders to have a set difficulty +// - the last element to be the terminal block +func verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) (int, error) { + td := chain.GetTd(preHeaders[0].ParentHash, preHeaders[0].Number.Uint64()-1) + if td == nil { + return 0, consensus.ErrUnknownAncestor + } + // Check that all blocks before the last one are below the TTD + for i, head := range preHeaders { + if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 { + return i, consensus.ErrInvalidTerminalBlock + } + td.Add(td, head.Difficulty) + } + // Check that the last block is the terminal block + if td.Cmp(chain.Config().TerminalTotalDifficulty) < 0 { + return len(preHeaders) - 1, consensus.ErrInvalidTerminalBlock + } + return 0, nil +} + // VerifyUncles verifies that the given block's uncles conform to the consensus // rules of the Ethereum consensus engine. func (beacon *Beacon) VerifyUncles(chain consensus.ChainReader, block *types.Block) error { diff --git a/consensus/beacon/consensus_test.go b/consensus/beacon/consensus_test.go new file mode 100644 index 0000000000000..09c0b27c42562 --- /dev/null +++ b/consensus/beacon/consensus_test.go @@ -0,0 +1,137 @@ +package beacon + +import ( + "fmt" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/params" +) + +type mockChain struct { + config *params.ChainConfig + tds map[uint64]*big.Int +} + +func newMockChain() *mockChain { + return &mockChain{ + config: new(params.ChainConfig), + tds: make(map[uint64]*big.Int), + } +} + +func (m *mockChain) Config() *params.ChainConfig { + return m.config +} + +func (m *mockChain) CurrentHeader() *types.Header { panic("not implemented") } + +func (m *mockChain) GetHeader(hash common.Hash, number uint64) *types.Header { + panic("not implemented") +} + +func (m *mockChain) GetHeaderByNumber(number uint64) *types.Header { panic("not implemented") } + +func (m *mockChain) GetHeaderByHash(hash common.Hash) *types.Header { panic("not implemented") } + +func (m *mockChain) GetTd(hash common.Hash, number uint64) *big.Int { + num, ok := m.tds[number] + if ok { + return new(big.Int).Set(num) + } + return nil +} + +func TestVerifyTerminalBlock(t *testing.T) { + chain := newMockChain() + chain.tds[0] = big.NewInt(10) + chain.config.TerminalTotalDifficulty = big.NewInt(50) + + tests := []struct { + preHeaders []*types.Header + ttd *big.Int + err error + index int + }{ + // valid ttd + { + preHeaders: []*types.Header{ + {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(3), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, + }, + ttd: big.NewInt(50), + }, + // last block doesn't reach ttd + { + preHeaders: []*types.Header{ + {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(3), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(4), Difficulty: big.NewInt(9)}, + }, + ttd: big.NewInt(50), + err: consensus.ErrInvalidTerminalBlock, + index: 3, + }, + // two blocks reach ttd + { + preHeaders: []*types.Header{ + {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(3), Difficulty: big.NewInt(20)}, + {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, + }, + ttd: big.NewInt(50), + err: consensus.ErrInvalidTerminalBlock, + index: 3, + }, + // three blocks reach ttd + { + preHeaders: []*types.Header{ + {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(3), Difficulty: big.NewInt(20)}, + {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, + {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, + }, + ttd: big.NewInt(50), + err: consensus.ErrInvalidTerminalBlock, + index: 3, + }, + // parent reached ttd + { + preHeaders: []*types.Header{ + {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, + }, + ttd: big.NewInt(9), + err: consensus.ErrInvalidTerminalBlock, + index: 0, + }, + // unknown parent + { + preHeaders: []*types.Header{ + {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, + }, + ttd: big.NewInt(9), + err: consensus.ErrUnknownAncestor, + index: 0, + }, + } + + for i, test := range tests { + fmt.Printf("Test: %v\n", i) + chain.config.TerminalTotalDifficulty = test.ttd + index, err := verifyTerminalPoWBlock(chain, test.preHeaders) + if err != test.err { + t.Fatalf("Invalid error encountered, expected %v got %v", test.err, err) + } + if index != test.index { + t.Fatalf("Invalid index, expected %v got %v", test.index, index) + } + } +} diff --git a/consensus/errors.go b/consensus/errors.go index 55d0a16e73edd..f2c515dcc2951 100644 --- a/consensus/errors.go +++ b/consensus/errors.go @@ -37,4 +37,8 @@ var ( // ErrUnauthorized is returned if a block's minerNodeId or minerNodeSig is invalid. ErrUnauthorized = errors.New("unauthorized block") + + // ErrInvalidTerminalBlock is returned if a block is invalid wrt. the terminal + // total difficulty. + ErrInvalidTerminalBlock = errors.New("invalid terminal block") ) diff --git a/core/block_validator_test.go b/core/block_validator_test.go index 0f183ba527784..8dee8d576070b 100644 --- a/core/block_validator_test.go +++ b/core/block_validator_test.go @@ -107,7 +107,8 @@ func testHeaderVerificationForMerging(t *testing.T, isClique bool) { Alloc: map[common.Address]GenesisAccount{ addr: {Balance: big.NewInt(1)}, }, - BaseFee: big.NewInt(params.InitialBaseFee), + BaseFee: big.NewInt(params.InitialBaseFee), + Difficulty: new(big.Int), } copy(genspec.ExtraData[32:], addr[:]) genesis := genspec.MustCommit(testdb)