diff --git a/trie/database.go b/trie/database.go index 58ca4e6f3caab..dd0e82cc2aacc 100644 --- a/trie/database.go +++ b/trie/database.go @@ -164,7 +164,7 @@ func (n *cachedNode) rlp() []byte { if node, ok := n.node.(rawNode); ok { return node } - blob, err := rlp.EncodeToBytes(n.node) + blob, err := frlp.EncodeToBytes(n.node) if err != nil { panic(err) } diff --git a/trie/fast_node_encoder.go b/trie/fast_node_encoder.go new file mode 100644 index 0000000000000..4181b12013062 --- /dev/null +++ b/trie/fast_node_encoder.go @@ -0,0 +1,109 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "fmt" + "io" + + "github.com/ethereum/go-ethereum/rlp" +) + +// fastNodeEncoder is the fast node encoder using rlp.EncoderBuffer. +type fastNodeEncoder struct{} + +var frlp fastNodeEncoder + +// Encode writes the RLP encoding of node to w. +func (fastNodeEncoder) Encode(w io.Writer, node node) error { + enc := rlp.NewEncoderBuffer(w) + if err := fastEncodeNode(&enc, node); err != nil { + return err + } + return enc.Flush() +} + +// EncodeToBytes returns the RLP encoding of node. +func (fastNodeEncoder) EncodeToBytes(node node) ([]byte, error) { + enc := rlp.NewEncoderBuffer(nil) + defer enc.Flush() + + if err := fastEncodeNode(&enc, node); err != nil { + return nil, err + } + return enc.ToBytes(), nil +} + +func fastEncodeNode(w *rlp.EncoderBuffer, n node) error { + switch n := n.(type) { + case *fullNode: + offset := w.List() + for _, c := range n.Children { + if c != nil { + if err := fastEncodeNode(w, c); err != nil { + return err + } + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) + case *shortNode: + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + if err := fastEncodeNode(w, n.Val); err != nil { + return err + } + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) + case hashNode: + w.WriteBytes(n) + case valueNode: + w.WriteBytes(n) + case rawFullNode: + offset := w.List() + for _, c := range n { + if c != nil { + if err := fastEncodeNode(w, c); err != nil { + return err + } + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) + case *rawShortNode: + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + if err := fastEncodeNode(w, n.Val); err != nil { + return err + } + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) + case rawNode: + w.Write(n) + default: + return fmt.Errorf("unexpected node type: %T", n) + } + return nil +} diff --git a/trie/hasher.go b/trie/hasher.go index 3a62a2f1199c2..83d3b43982c16 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -20,7 +20,6 @@ import ( "sync" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" ) @@ -154,7 +153,7 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached // If the rlp data is smaller than 32 bytes, `nil` is returned. func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { h.tmp.Reset() - if err := rlp.Encode(&h.tmp, n); err != nil { + if err := frlp.Encode(&h.tmp, n); err != nil { panic("encode error: " + err.Error()) } @@ -169,7 +168,7 @@ func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { h.tmp.Reset() // Generate the RLP encoding of the node - if err := n.EncodeRLP(&h.tmp); err != nil { + if err := frlp.Encode(&h.tmp, n); err != nil { panic("encode error: " + err.Error()) } diff --git a/trie/iterator.go b/trie/iterator.go index 9b7d97a5f58b0..568edaefcef66 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/rlp" ) // Iterator is a key-value trie iterator that traverses a Trie. @@ -214,7 +213,7 @@ func (it *nodeIterator) LeafProof() [][]byte { // Gather nodes that end up as hash nodes (or the root) node, hashed := hasher.proofHash(item.node) if _, ok := hashed.(hashNode); ok || i == 0 { - enc, _ := rlp.EncodeToBytes(node) + enc, _ := frlp.EncodeToBytes(node) proofs = append(proofs, enc) } } diff --git a/trie/node_test.go b/trie/node_test.go index 52720f1c776ee..1533e9241c3e9 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -92,3 +92,29 @@ func TestDecodeFullNode(t *testing.T) { t.Fatalf("decode full node err: %v", err) } } + +func BenchmarkEncodeFullNode(b *testing.B) { + var buf sliceBuffer + var fn fullNode + for i := 0; i < 16; i++ { + fn.Children[i] = hashNode(randBytes(32)) + } + + for i := 0; i < b.N; i++ { + buf.Reset() + rlp.Encode(&buf, &fn) + } +} + +func BenchmarkFastEncodeFullNode(b *testing.B) { + var buf sliceBuffer + var fn fullNode + for i := 0; i < 16; i++ { + fn.Children[i] = hashNode(randBytes(32)) + } + + for i := 0; i < b.N; i++ { + buf.Reset() + frlp.Encode(&buf, &fn) + } +} diff --git a/trie/proof.go b/trie/proof.go index 9be3b62216a80..dd97d101b42c3 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -25,7 +25,6 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/rlp" ) // Prove constructs a merkle proof for key. The result contains all encoded nodes @@ -79,7 +78,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e if hash, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the // root node), it becomes a proof element. - enc, _ := rlp.EncodeToBytes(n) + enc, _ := frlp.EncodeToBytes(n) if !ok { hash = hasher.hashData(enc) } diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 76258c31123c2..33d989ec4ce48 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -376,7 +376,7 @@ func (st *StackTrie) hash() { switch st.nodeType { case branchNode: - var nodes [17]node + var nodes rawFullNode for i, child := range st.children { if child == nil { nodes[i] = nilValueNode @@ -395,7 +395,7 @@ func (st *StackTrie) hash() { h = newHasher(false) defer returnHasherToPool(h) h.tmp.Reset() - if err := rlp.Encode(&h.tmp, nodes); err != nil { + if err := frlp.Encode(&h.tmp, nodes); err != nil { panic(err) } case extNode: @@ -409,14 +409,11 @@ func (st *StackTrie) hash() { } else { valuenode = hashNode(st.children[0].val) } - n := struct { - Key []byte - Val node - }{ + n := &rawShortNode{ Key: hexToCompact(st.key), Val: valuenode, } - if err := rlp.Encode(&h.tmp, n); err != nil { + if err := frlp.Encode(&h.tmp, n); err != nil { panic(err) } returnToPool(st.children[0])