diff --git a/storage/statedb/committer.go b/storage/statedb/committer.go new file mode 100644 index 0000000000..bc319aaa04 --- /dev/null +++ b/storage/statedb/committer.go @@ -0,0 +1,268 @@ +// Modifications Copyright 2022 The klaytn Authors +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . +// +// This file is derived from trie/comitter.go (2022/11/14). +// Modified and improved for the klaytn development. + +package statedb + +import ( + "errors" + "fmt" + "sync" + + "github.com/klaytn/klaytn/common" + "golang.org/x/crypto/sha3" +) + +// leafChanSize is the size of the leafCh. It's a pretty arbitrary number, to allow +// some paralellism but not incur too much memory overhead. +const leafChanSize = 200 + +// leaf represents a trie leaf value +type leaf struct { + size int // size of the rlp data (estimate) + hash common.Hash // hash of rlp data + node node // the node to commit + vnodes bool // set to true if the node (possibly) contains a valueNode +} + +// committer is a type used for the trie Commit operation. A committer has some +// internal preallocated temp space, and also a callback that is invoked when +// leaves are committed. The leafs are passed through the `leafCh`, to allow +// some level of paralellism. +// By 'some level' of parallelism, it's still the case that all leaves will be +// processed sequentially - onleaf will never be called in parallel or out of order. +type committer struct { + sha KeccakState + + onleaf LeafCallback + leafCh chan *leaf +} + +// committers live in a global sync.Pool +var committerPool = sync.Pool{ + New: func() interface{} { + return &committer{ + sha: sha3.NewLegacyKeccak256().(KeccakState), + } + }, +} + +// newCommitter creates a new committer or picks one from the pool. +func newCommitter() *committer { + return committerPool.Get().(*committer) +} + +func returnCommitterToPool(h *committer) { + h.onleaf = nil + h.leafCh = nil + committerPool.Put(h) +} + +// commitNeeded returns 'false' if the given node is already in sync with db +func (c *committer) commitNeeded(n node) bool { + hash, dirty := n.cache() + return hash == nil || dirty +} + +// commit collapses a node down into a hash node and inserts it into the database +func (c *committer) Commit(n node, db *Database) (hashNode, error) { + if db == nil { + return nil, errors.New("no db provided") + } + h, err := c.commit(n, db, true) + if err != nil { + return nil, err + } + return h.(hashNode), nil +} + +// commit collapses a node down into a hash node and inserts it into the database +func (c *committer) commit(n node, db *Database, force bool) (node, error) { + // if this path is clean, use available cached data + hash, dirty := n.cache() + if hash != nil && !dirty { + return hash, nil + } + // Commit children, then parent, and remove remove the dirty flag. + switch cn := n.(type) { + case *shortNode: + // Commit child + collapsed := cn.copy() + if _, ok := cn.Val.(valueNode); !ok { + if childV, err := c.commit(cn.Val, db, false); err != nil { + return nil, err + } else { + collapsed.Val = childV + } + } + // The key needs to be copied, since we're delivering it to database + collapsed.Key = hexToCompact(cn.Key) + hashedNode := c.store(collapsed, db, force, true) + if hn, ok := hashedNode.(hashNode); ok { + return hn, nil + } else { + return collapsed, nil + } + case *fullNode: + hashedKids, hasVnodes, err := c.commitChildren(cn, db, force) + if err != nil { + return nil, err + } + collapsed := cn.copy() + collapsed.Children = hashedKids + + hashedNode := c.store(collapsed, db, force, hasVnodes) + if hn, ok := hashedNode.(hashNode); ok { + return hn, nil + } else { + return collapsed, nil + } + case valueNode: + return c.store(cn, db, force, false), nil + // hashnodes aren't stored + case hashNode: + return cn, nil + } + return hash, nil +} + +// commitChildren commits the children of the given fullnode +func (c *committer) commitChildren(n *fullNode, db *Database, force bool) ([17]node, bool, error) { + var children [17]node + hasValueNodeChildren := false + for i, child := range n.Children { + if child == nil { + continue + } + hnode, err := c.commit(child, db, false) + if err != nil { + return children, false, err + } + children[i] = hnode + if _, ok := hnode.(valueNode); ok { + hasValueNodeChildren = true + } + } + return children, hasValueNodeChildren, nil +} + +// store hashes the node n and if we have a storage layer specified, it writes +// the key/value pair to it and tracks any node->child references as well as any +// node->external trie references. +func (c *committer) store(n node, db *Database, force bool, hasVnodeChildren bool) node { + // Larger nodes are replaced by their hash and stored in the database. + var ( + hash, _ = n.cache() + size int + ) + if hash == nil { + // This was not generated - must be a small node stored in the parent + // No need to do anything here + return n + } + // We have the hash already, estimate the RLP encoding-size of the node. + // The size is used for mem tracking, does not need to be exact + size = estimateSize(n) + + // If we're using channel-based leaf-reporting, send to channel. + // The leaf channel will be active only when there an active leaf-callback + if c.leafCh != nil { + c.leafCh <- &leaf{ + size: size, + hash: common.BytesToHash(hash), + node: n, + vnodes: hasVnodeChildren, + } + } else if db != nil { + // No leaf-callback used, but there's still a database. Do serial + // insertion + db.lock.Lock() + db.insert(common.BytesToHash(hash), uint16(size), n) + db.lock.Unlock() + } + return hash +} + +// commitLoop does the actual insert + leaf callback for nodes +func (c *committer) commitLoop(db *Database) { + for item := range c.leafCh { + var ( + hash = item.hash + size = item.size + n = item.node + hasVnodes = item.vnodes + ) + // We are pooling the trie nodes into an intermediate memory cache + db.lock.Lock() + db.insert(hash, uint16(size), n) + db.lock.Unlock() + if c.onleaf != nil && hasVnodes { + switch n := n.(type) { + case *shortNode: + if child, ok := n.Val.(valueNode); ok { + c.onleaf(nil, nil, child, hash, 0) + } + case *fullNode: + // For children in range [0, 15], it's impossible + // to contain valueNode. Only check the 17th child. + if n.Children[16] != nil { + c.onleaf(nil, nil, n.Children[16].(valueNode), hash, 0) + } + } + } + } +} + +func (c *committer) makeHashNode(data []byte) hashNode { + n := make(hashNode, c.sha.Size()) + c.sha.Reset() + c.sha.Write(data) + c.sha.Read(n) + return n +} + +// estimateSize estimates the size of an rlp-encoded node, without actually +// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie +// with 1000 leafs, the only errors above 1% are on small shortnodes, where this +// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35) +func estimateSize(n node) int { + switch n := n.(type) { + case *shortNode: + // A short node contains a compacted key, and a value. + return 3 + len(n.Key) + estimateSize(n.Val) + case *fullNode: + // A full node contains up to 16 hashes (some nils), and a key + s := 3 + for i := 0; i < 16; i++ { + if child := n.Children[i]; child != nil { + s += estimateSize(child) + } else { + s += 1 + } + } + return s + case valueNode: + return 1 + len(n) + case hashNode: + return 1 + len(n) + default: + panic(fmt.Sprintf("node type %T", n)) + + } +} diff --git a/storage/statedb/database.go b/storage/statedb/database.go index 042dc2e734..c619f087a1 100644 --- a/storage/statedb/database.go +++ b/storage/statedb/database.go @@ -146,16 +146,9 @@ func (n rawFullNode) fstring(ind string) string { panic("this should never e func (n rawFullNode) lenEncoded() uint16 { panic("this should never end up in a live trie") } func (n rawFullNode) EncodeRLP(w io.Writer) error { - var nodes [17]node - - for i, child := range n { - if child != nil { - nodes[i] = child - } else { - nodes[i] = nilValueNode - } - } - return rlp.Encode(w, nodes) + encodeByte := rlp.NewEncoderBuffer(w) + n.encode(encodeByte) + return encodeByte.Flush() } // rawShortNode represents only the useful data content of a short node, with the @@ -193,11 +186,7 @@ func (n *cachedNode) rlp() []byte { if node, ok := n.node.(rawNode); ok { return node } - blob, err := rlp.EncodeToBytes(n.node) - if err != nil { - panic(err) - } - return blob + return nodeToBytes(n.node) } // obj returns the decoded and expanded trie node, either directly from the cache, diff --git a/storage/statedb/hasher.go b/storage/statedb/hasher.go index 9733b403e0..c9421e5025 100644 --- a/storage/statedb/hasher.go +++ b/storage/statedb/hasher.go @@ -24,17 +24,10 @@ import ( "hash" "sync" - "github.com/klaytn/klaytn/common" - "github.com/klaytn/klaytn/crypto/sha3" "github.com/klaytn/klaytn/rlp" + "golang.org/x/crypto/sha3" ) -type hasher struct { - tmp sliceBuffer - sha KeccakState - onleaf LeafCallback -} - // KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports // Read to get a variable amount of data from the hash state. Read is faster than Sum // because it doesn't copy the internal state, but also modifies the internal state. @@ -43,30 +36,29 @@ type KeccakState interface { Read([]byte) (int, error) } -type sliceBuffer []byte - -func (b *sliceBuffer) Write(data []byte) (n int, err error) { - *b = append(*b, data...) - return len(data), nil -} - -func (b *sliceBuffer) Reset() { - *b = (*b)[:0] +// hasher is a type used for the trie Hash operation. A hasher has some +// internal preallocated temp space +type hasher struct { + sha KeccakState + tmp []byte + encBuf rlp.EncoderBuffer + parallel bool } -// hashers live in a global db. +// hasherPool holds pureHashers var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ - tmp: make(sliceBuffer, 0, 550), // cap is as large as a full fullNode. - sha: sha3.NewKeccak256().(KeccakState), + tmp: make([]byte, 0, 550), // cap is as large as a full fullNode. + sha: sha3.NewLegacyKeccak256().(KeccakState), + encBuf: rlp.NewEncoderBuffer(nil), } }, } -func newHasher(onleaf LeafCallback) *hasher { +func newHasher(parallel bool) *hasher { h := hasherPool.Get().(*hasher) - h.onleaf = onleaf + h.parallel = parallel return h } @@ -76,226 +68,155 @@ func returnHasherToPool(h *hasher) { // hash collapses a node down into a hash node, also returning a copy of the // original node initialized with the computed hash to replace the original one. -func (h *hasher) hash(n node, db *Database, force bool) (node, node) { - // If we're not storing the node, just hashing, use available cached data - if hash, dirty := n.cache(); hash != nil { - if db == nil { - return hash, n - } - if !dirty { - switch n.(type) { - case *fullNode, *shortNode: - return hash, hash - default: - return hash, n - } - } +func (h *hasher) hash(n node, force bool) (hashed node, cached node) { + // We're not storing the node, just hashing, use available cached data + if hash, _ := n.cache(); hash != nil { + return hash, n } // Trie not processed yet or needs storage, walk the children - collapsed, cached := h.hashChildren(n, db) - hashed, lenEncoded := h.store(collapsed, db, force) - // Cache the hash of the node for later reuse and remove - // the dirty flag in commit mode. It's fine to assign these values directly - // without copying the node first because hashChildren copies it. - cachedHash, _ := hashed.(hashNode) - switch cn := cached.(type) { + switch n := n.(type) { case *shortNode: - cn.flags.hash = cachedHash - cn.flags.lenEncoded = lenEncoded - if db != nil { - cn.flags.dirty = false - } + collapsed, cached := h.hashShortNodeChildren(n) + hashed := h.shortnodeToHash(collapsed, force) + // We need to retain the possibly _not_ hashed node, in case it was too + // small to be hashed + if hn, ok := hashed.(hashNode); ok { + cached.flags.hash = hn + } else { + cached.flags.hash = nil + } + return hashed, cached case *fullNode: - cn.flags.hash = cachedHash - cn.flags.lenEncoded = lenEncoded - if db != nil { - cn.flags.dirty = false - } + collapsed, cached := h.hashFullNodeChildren(n) + hashed = h.fullnodeToHash(collapsed, force) + if hn, ok := hashed.(hashNode); ok { + cached.flags.hash = hn + } else { + cached.flags.hash = nil + } + return hashed, cached + default: + // Value and hash nodes don't have children so they're left as were + return n, n } - return hashed, cached } -func (h *hasher) hashRoot(n node, db *Database, force bool) (node, node) { - // If we're not storing the node, just hashing, use available cached data - if hash, dirty := n.cache(); hash != nil { - if db == nil { - return hash, n - } - if !dirty { - switch n.(type) { - case *fullNode, *shortNode: - return hash, hash - default: - return hash, n - } - } +// hashShortNodeChildren collapses the short node. The returned collapsed node +// holds a live reference to the Key, and must not be modified. +// The cached +func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) { + // Hash the short node's child, caching the newly hashed subtree + collapsed, cached = n.copy(), n.copy() + // Previously, we did copy this one. We don't seem to need to actually + // do that, since we don't overwrite/reuse keys + // cached.Key = common.CopyBytes(n.Key) + collapsed.Key = hexToCompact(n.Key) + // Unless the child is a valuenode or hashnode, hash it + switch n.Val.(type) { + case *fullNode, *shortNode: + collapsed.Val, cached.Val = h.hash(n.Val, false) } - // Trie not processed yet or needs storage, walk the children - collapsed, cached := h.hashChildrenFromRoot(n, db) - hashed, lenEncoded := h.store(collapsed, db, force) - // Cache the hash of the node for later reuse and remove - // the dirty flag in commit mode. It's fine to assign these values directly - // without copying the node first because hashChildren copies it. - cachedHash, _ := hashed.(hashNode) - switch cn := cached.(type) { - case *shortNode: - cn.flags.hash = cachedHash - cn.flags.lenEncoded = lenEncoded - if db != nil { - cn.flags.dirty = false - } - case *fullNode: - cn.flags.hash = cachedHash - cn.flags.lenEncoded = lenEncoded - if db != nil { - cn.flags.dirty = false - } - } - return hashed, cached + return collapsed, cached } -// hashChildren replaces the children of a node with their hashes if the encoded -// size of the child is larger than a hash, returning the collapsed node as well -// as a replacement for the original node with the child hashes cached in. -func (h *hasher) hashChildren(original node, db *Database) (node, node) { - switch n := original.(type) { - case *shortNode: - // Hash the short node's child, caching the newly hashed subtree - collapsed, cached := n.copy(), n.copy() - collapsed.Key = hexToCompact(n.Key) - cached.Key = common.CopyBytes(n.Key) - - if _, ok := n.Val.(valueNode); !ok { - collapsed.Val, cached.Val = h.hash(n.Val, db, false) +func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached *fullNode) { + // Hash the full node's children, caching the newly hashed subtrees + cached = n.copy() + collapsed = n.copy() + if h.parallel { + var wg sync.WaitGroup + wg.Add(16) + for i := 0; i < 16; i++ { + go func(i int) { + hasher := newHasher(false) + if child := n.Children[i]; child != nil { + collapsed.Children[i], cached.Children[i] = hasher.hash(child, false) + } else { + collapsed.Children[i] = nilValueNode + } + returnHasherToPool(hasher) + wg.Done() + }(i) } - return collapsed, cached - - case *fullNode: - // Hash the full node's children, caching the newly hashed subtrees - collapsed, cached := n.copy(), n.copy() - + wg.Wait() + } else { for i := 0; i < 16; i++ { - if n.Children[i] != nil { - collapsed.Children[i], cached.Children[i] = h.hash(n.Children[i], db, false) + if child := n.Children[i]; child != nil { + collapsed.Children[i], cached.Children[i] = h.hash(child, false) + } else { + collapsed.Children[i] = nilValueNode } } - cached.Children[16] = n.Children[16] - return collapsed, cached - - default: - // Value and hash nodes don't have children so they're left as were - return n, original } + cached.Children[16] = n.Children[16] + return collapsed, cached } -type hashResult struct { - index int - collapsed node - cached node -} +// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode +// should have hex-type Key, which will be converted (without modification) +// into compact form for RLP encoding. +// If the rlp data is smaller than 32 bytes, `nil` is returned. +func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { + n.encode(h.encBuf) + encResultBytes := h.encodedBytes() -func (h *hasher) hashChildrenFromRoot(original node, db *Database) (node, node) { - switch n := original.(type) { - case *shortNode: - // Hash the short node's child, caching the newly hashed subtree - collapsed, cached := n.copy(), n.copy() - collapsed.Key = hexToCompact(n.Key) - cached.Key = common.CopyBytes(n.Key) - - if _, ok := n.Val.(valueNode); !ok { - collapsed.Val, cached.Val = h.hash(n.Val, db, false) - } - return collapsed, cached - - case *fullNode: - // Hash the full node's children, caching the newly hashed subtrees - collapsed, cached := n.copy(), n.copy() - - hashResultCh := make(chan hashResult, 16) - numRootChildren := 0 - for i := 0; i < 16; i++ { - if n.Children[i] != nil { - numRootChildren++ - go func(i int, n node) { - childHasher := newHasher(h.onleaf) - defer returnHasherToPool(childHasher) - collapsedFromChild, cachedFromChild := childHasher.hash(n, db, false) - hashResultCh <- hashResult{i, collapsedFromChild, cachedFromChild} - }(i, n.Children[i]) - } - } - - for i := 0; i < numRootChildren; i++ { - hashResult := <-hashResultCh - idx := hashResult.index - collapsed.Children[idx], cached.Children[idx] = hashResult.collapsed, hashResult.cached - } - - cached.Children[16] = n.Children[16] - return collapsed, cached - - default: - // Value and hash nodes don't have children so they're left as were - return n, original + if len(encResultBytes) < 32 && !force { + return n // Nodes smaller than 32 bytes are stored inside their parent } + return h.hashData(encResultBytes) } -// store hashes the node n and if we have a storage layer specified, it writes -// the key/value pair to it and tracks any node->child references as well as any -// node->external trie references. -func (h *hasher) store(n node, db *Database, force bool) (node, uint16) { - // Don't store hashes or empty nodes. - if _, isHash := n.(hashNode); n == nil || isHash { - return n, 0 - } - hash, _ := n.cache() - lenEncoded := n.lenEncoded() - if hash == nil || lenEncoded == 0 { - // Generate the RLP encoding of the node - h.tmp.Reset() - if err := rlp.Encode(&h.tmp, n); err != nil { - panic("encode error: " + err.Error()) - } +// shortnodeToHash is used to creates a hashNode from a set of hashNodes, (which +// may contain nil values) +func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { + n.encode(h.encBuf) + encResultBytes := h.encodedBytes() - lenEncoded = uint16(len(h.tmp)) - } - if lenEncoded < 32 && !force { - return n, lenEncoded // Nodes smaller than 32 bytes are stored inside their parent - } - if hash == nil { - hash = h.makeHashNode(h.tmp) + if len(encResultBytes) < 32 && !force { + return n // Nodes smaller than 32 bytes are stored inside their parent } - if db != nil { - // We are pooling the trie nodes into an intermediate memory cache - hash := common.BytesToHash(hash) - - db.lock.Lock() - db.insert(hash, lenEncoded, n) - db.lock.Unlock() + return h.hashData(encResultBytes) +} - // Track external references from account->storage trie - if h.onleaf != nil { - switch n := n.(type) { - case *shortNode: - if child, ok := n.Val.(valueNode); ok { - h.onleaf(nil, nil, child, hash, 0) - } - case *fullNode: - for i := 0; i < 16; i++ { - if child, ok := n.Children[i].(valueNode); ok { - h.onleaf(nil, nil, child, hash, 0) - } - } - } - } - } - return hash, lenEncoded +// encodedBytes returns the result of the last encoding operation on h.encBuf. +// This also resets the encoder buffer. +// +// All node encoding must be done like this: +// +// node.encode(h.encBuf) +// enc := h.encodedBytes() +// +// This convention exists because node.encode can only be inlined/escape-analyzed when +// called on a concrete receiver type. +func (h *hasher) encodedBytes() []byte { + h.tmp = h.encBuf.AppendToBytes(h.tmp[:0]) + h.encBuf.Reset(nil) + return h.tmp } -func (h *hasher) makeHashNode(data []byte) hashNode { - n := make(hashNode, h.sha.Size()) +// hashData hashes the provided data +func (h *hasher) hashData(data []byte) hashNode { + n := make(hashNode, 32) h.sha.Reset() h.sha.Write(data) h.sha.Read(n) return n } + +// proofHash is used to construct trie proofs, and returns the 'collapsed' +// node (for later RLP encoding) aswell as the hashed node -- unless the +// node is smaller than 32 bytes, in which case it will be returned as is. +// This method does not do anything on value- or hash-nodes. +func (h *hasher) proofHash(original node) (collapsed, hashed node) { + switch n := original.(type) { + case *shortNode: + sn, _ := h.hashShortNodeChildren(n) + return sn, h.shortnodeToHash(sn, false) + case *fullNode: + fn, _ := h.hashFullNodeChildren(n) + return fn, h.fullnodeToHash(fn, false) + default: + // Value and hash nodes don't have children so they're left as were + return n, n + } +} diff --git a/storage/statedb/iterator.go b/storage/statedb/iterator.go index 95a4d33b19..db1665650c 100644 --- a/storage/statedb/iterator.go +++ b/storage/statedb/iterator.go @@ -28,7 +28,6 @@ import ( "github.com/klaytn/klaytn/storage/database" "github.com/klaytn/klaytn/common" - "github.com/klaytn/klaytn/rlp" ) // Iterator is a key-value trie iterator that traverses a Trie. @@ -195,18 +194,16 @@ func (it *nodeIterator) LeafKey() []byte { func (it *nodeIterator) LeafProof() [][]byte { if len(it.stack) > 0 { if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { - hasher := newHasher(nil) + hasher := newHasher(false) defer returnHasherToPool(hasher) proofs := make([][]byte, 0, len(it.stack)) for i, item := range it.stack[:len(it.stack)-1] { // Gather nodes that end up as hash nodes (or the root) - node, _ := hasher.hashChildren(item.node, nil) - hashed, _ := hasher.store(node, nil, false) + node, hashed := hasher.proofHash(item.node) if _, ok := hashed.(hashNode); ok || i == 0 { - enc, _ := rlp.EncodeToBytes(node) - proofs = append(proofs, enc) + proofs = append(proofs, nodeToBytes(node)) } } return proofs diff --git a/storage/statedb/node.go b/storage/statedb/node.go index acb3a9fb8f..6f9bbeac39 100644 --- a/storage/statedb/node.go +++ b/storage/statedb/node.go @@ -35,6 +35,7 @@ type node interface { fstring(string) string cache() (hashNode, bool) lenEncoded() uint16 + encode(w rlp.EncoderBuffer) } type ( @@ -57,16 +58,9 @@ var nilValueNode = valueNode(nil) // EncodeRLP encodes a full node into the consensus RLP format. func (n *fullNode) EncodeRLP(w io.Writer) error { - var nodes [17]node - - for i, child := range &n.Children { - if child != nil { - nodes[i] = child - } else { - nodes[i] = nilValueNode - } - } - return rlp.Encode(w, nodes) + encodeByte := rlp.NewEncoderBuffer(w) + n.encode(encodeByte) + return encodeByte.Flush() } func (n *fullNode) copy() *fullNode { copy := *n; return © } @@ -127,8 +121,29 @@ func mustDecodeNode(hash, buf []byte) node { return n } -// decodeNode parses the RLP encoding of a trie node. +// mustDecodeNodeUnsafe is a wrapper of decodeNodeUnsafe and panic if any error is +// encountered. +func mustDecodeNodeUnsafe(hash, buf []byte) node { + n, err := decodeNodeUnsafe(hash, buf) + if err != nil { + panic(fmt.Sprintf("node %x: %v", hash, err)) + } + return n +} + +// decodeNode parses the RLP encoding of a trie node. It will deep-copy the passed +// byte slice for decoding, so it's safe to modify the byte slice afterwards. The- +// decode performance of this function is not optimal, but it is suitable for most +// scenarios with low performance requirements and hard to determine whether the +// byte slice be modified or not. func decodeNode(hash, buf []byte) (node, error) { + return decodeNodeUnsafe(hash, common.CopyBytes(buf)) +} + +// decodeNodeUnsafe parses the RLP encoding of a trie node. The passed byte slice +// will be directly referenced by node without bytes deep copy, so the input MUST +// not be changed after. +func decodeNodeUnsafe(hash, buf []byte) (node, error) { if len(buf) == 0 { return nil, io.ErrUnexpectedEOF } diff --git a/storage/statedb/node_enc.go b/storage/statedb/node_enc.go new file mode 100644 index 0000000000..8b6d468eb7 --- /dev/null +++ b/storage/statedb/node_enc.go @@ -0,0 +1,85 @@ +// Copyright 2022 The klaytn Authors +// This file is part of the klaytn library. +// +// The klaytn library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The klaytn library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the klaytn library. If not, see . + +package statedb + +import "github.com/klaytn/klaytn/rlp" + +func nodeToBytes(n node) []byte { + w := rlp.NewEncoderBuffer(nil) + n.encode(w) + result := w.ToBytes() + w.Flush() + return result +} + +func (n *fullNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + for _, c := range n.Children { + if c != nil { + c.encode(w) + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) +} + +func (n *shortNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + n.Val.encode(w) + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) +} + +func (n hashNode) encode(w rlp.EncoderBuffer) { + w.WriteBytes(n) +} + +func (n valueNode) encode(w rlp.EncoderBuffer) { + w.WriteBytes(n) +} + +func (n rawFullNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + for _, c := range n { + if c != nil { + c.encode(w) + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) +} + +func (n *rawShortNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + n.Val.encode(w) + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) +} + +func (n rawNode) encode(w rlp.EncoderBuffer) { + w.Write(n) +} diff --git a/storage/statedb/node_test.go b/storage/statedb/node_test.go new file mode 100644 index 0000000000..7ffc55dff7 --- /dev/null +++ b/storage/statedb/node_test.go @@ -0,0 +1,215 @@ +// Copyright 2022 The klaytn Authors +// This file is part of the klaytn library. +// +// The klaytn library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The klaytn library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the klaytn library. If not, see . + +package statedb + +import ( + "bytes" + "testing" + + "github.com/klaytn/klaytn/crypto" + "github.com/klaytn/klaytn/rlp" +) + +func newTestFullNode(v []byte) []interface{} { + fullNodeData := []interface{}{} + for i := 0; i < 16; i++ { + k := bytes.Repeat([]byte{byte(i + 1)}, 32) + fullNodeData = append(fullNodeData, k) + } + fullNodeData = append(fullNodeData, v) + return fullNodeData +} + +func TestDecodeNestedNode(t *testing.T) { + fullNodeData := newTestFullNode([]byte("fullnode")) + + data := [][]byte{} + for i := 0; i < 16; i++ { + data = append(data, nil) + } + data = append(data, []byte("subnode")) + fullNodeData[15] = data + + buf := bytes.NewBuffer([]byte{}) + rlp.Encode(buf, fullNodeData) + + if _, err := decodeNode([]byte("testdecode"), buf.Bytes()); err != nil { + t.Fatalf("decode nested full node err: %v", err) + } +} + +func TestDecodeFullNodeWrongSizeChild(t *testing.T) { + fullNodeData := newTestFullNode([]byte("wrongsizechild")) + fullNodeData[0] = []byte("00") + buf := bytes.NewBuffer([]byte{}) + rlp.Encode(buf, fullNodeData) + + _, err := decodeNode([]byte("testdecode"), buf.Bytes()) + if _, ok := err.(*decodeError); !ok { + t.Fatalf("decodeNode returned wrong err: %v", err) + } +} + +func TestDecodeFullNodeWrongNestedFullNode(t *testing.T) { + fullNodeData := newTestFullNode([]byte("fullnode")) + + data := [][]byte{} + for i := 0; i < 16; i++ { + data = append(data, []byte("123456")) + } + data = append(data, []byte("subnode")) + fullNodeData[15] = data + + buf := bytes.NewBuffer([]byte{}) + rlp.Encode(buf, fullNodeData) + + _, err := decodeNode([]byte("testdecode"), buf.Bytes()) + if _, ok := err.(*decodeError); !ok { + t.Fatalf("decodeNode returned wrong err: %v", err) + } +} + +func TestDecodeFullNode(t *testing.T) { + fullNodeData := newTestFullNode([]byte("decodefullnode")) + buf := bytes.NewBuffer([]byte{}) + rlp.Encode(buf, fullNodeData) + + _, err := decodeNode([]byte("testdecode"), buf.Bytes()) + if err != nil { + t.Fatalf("decode full node err: %v", err) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkEncodeShortNode +// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op +func BenchmarkEncodeShortNode(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + nodeToBytes(node) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkEncodeFullNode +// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op +func BenchmarkEncodeFullNode(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + nodeToBytes(node) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeShortNode +// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op +func BenchmarkDecodeShortNode(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNode(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeShortNodeUnsafe +// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op +func BenchmarkDecodeShortNodeUnsafe(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNodeUnsafe(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeFullNode +// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op +func BenchmarkDecodeFullNode(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNode(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeFullNodeUnsafe +// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op +func BenchmarkDecodeFullNodeUnsafe(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNodeUnsafe(hash, blob) + } +} diff --git a/storage/statedb/proof.go b/storage/statedb/proof.go index 3a24ea8867..90694dfa93 100644 --- a/storage/statedb/proof.go +++ b/storage/statedb/proof.go @@ -27,7 +27,6 @@ import ( "github.com/klaytn/klaytn/common" "github.com/klaytn/klaytn/crypto" - "github.com/klaytn/klaytn/rlp" "github.com/klaytn/klaytn/storage/database" ) @@ -77,21 +76,21 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDB ProofDBWriter) error { panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - hasher := newHasher(nil) + hasher := newHasher(false) defer returnHasherToPool(hasher) for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. - n, _ = hasher.hashChildren(n, nil) - hn, _ := hasher.store(n, nil, false) + var hn node + n, hn = hasher.proofHash(n) if hash, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the // root node), it becomes a proof element. if fromLevel > 0 { fromLevel-- } else { - enc, _ := rlp.EncodeToBytes(n) + enc := nodeToBytes(n) if !ok { hash = crypto.Keccak256(enc) } diff --git a/storage/statedb/secure_trie.go b/storage/statedb/secure_trie.go index f9fdba0a54..b06d9f7eb8 100644 --- a/storage/statedb/secure_trie.go +++ b/storage/statedb/secure_trie.go @@ -202,7 +202,7 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { // The caller must not hold onto the return value because it will become // invalid on the next call to hashKey or secKey. func (t *SecureTrie) hashKey(key []byte) []byte { - h := newHasher(nil) + h := newHasher(false) h.sha.Reset() h.sha.Write(key) buf := h.sha.Sum(t.hashKeyBuf[:0]) diff --git a/storage/statedb/stacktrie.go b/storage/statedb/stacktrie.go index b5b52c0b4a..1b09fcd67c 100644 --- a/storage/statedb/stacktrie.go +++ b/storage/statedb/stacktrie.go @@ -30,7 +30,6 @@ import ( "sync" "github.com/klaytn/klaytn/common" - "github.com/klaytn/klaytn/rlp" "github.com/klaytn/klaytn/storage/database" ) @@ -57,12 +56,11 @@ func returnToPool(st *StackTrie) { // in order. Once it determines that a subtree will no longer be inserted // into, it will hash it and free up the memory it uses. type StackTrie struct { - nodeType uint8 // node type (as in branch, ext, leaf) - val []byte // value contained by this node if it's a leaf - key []byte // key chunk covered by this (full|ext) node - keyOffset int // offset of the key chunk inside a full key - children [16]*StackTrie // list of children (for fullnodes and exts) - db database.DBManager // Pointer to the commit db, can be nil + nodeType uint8 // node type (as in branch, ext, leaf) + val []byte // value contained by this node if it's a leaf + key []byte // key chunk covered by this (leaf|ext) node + children [16]*StackTrie // list of children (for branch and exts) + db database.DBManager // Pointer to the commit db, can be nil } // NewStackTrie allocates and initializes an empty trie. @@ -93,13 +91,11 @@ func (st *StackTrie) MarshalBinary() (data []byte, err error) { w = bufio.NewWriter(&b) ) if err := gob.NewEncoder(w).Encode(struct { - Nodetype uint8 - KeyOffset uint8 - Val []byte - Key []byte + Nodetype uint8 + Val []byte + Key []byte }{ st.nodeType, - uint8(st.keyOffset), st.val, st.key, }); err != nil { @@ -129,17 +125,14 @@ func (st *StackTrie) UnmarshalBinary(data []byte) error { func (st *StackTrie) unmarshalBinary(r io.Reader) error { var dec struct { - Nodetype uint8 - KeyOffset uint8 - Val []byte - Key []byte + Nodetype uint8 + Val []byte + Key []byte } gob.NewDecoder(r).Decode(&dec) st.nodeType = dec.Nodetype st.val = dec.Val st.key = dec.Key - st.keyOffset = int(dec.KeyOffset) - hasChild := make([]byte, 1) for i := range st.children { if _, err := r.Read(hasChild); err != nil { @@ -163,20 +156,18 @@ func (st *StackTrie) setDb(db database.DBManager) { } } -func newLeaf(ko int, key, val []byte, db database.DBManager) *StackTrie { +func newLeaf(key, val []byte, db database.DBManager) *StackTrie { st := stackTrieFromPool(db) st.nodeType = leafNode - st.keyOffset = ko - st.key = append(st.key, key[ko:]...) + st.key = append(st.key, key...) st.val = val return st } -func newExt(ko int, key []byte, child *StackTrie, db database.DBManager) *StackTrie { +func newExt(key []byte, child *StackTrie, db database.DBManager) *StackTrie { st := stackTrieFromPool(db) st.nodeType = extNode - st.keyOffset = ko - st.key = append(st.key, key[ko:]...) + st.key = append(st.key, key...) st.children[0] = child return st } @@ -214,17 +205,18 @@ func (st *StackTrie) Reset() { st.children[i] = nil } st.nodeType = emptyNode - st.keyOffset = 0 } // Helper function that, given a full key, determines the index // at which the chunk pointed by st.keyOffset is different from // the same chunk in the full key. func (st *StackTrie) getDiffIndex(key []byte) int { - diffindex := 0 - for ; diffindex < len(st.key) && st.key[diffindex] == key[st.keyOffset+diffindex]; diffindex++ { + for idx, nibble := range st.key { + if nibble != key[idx] { + return idx + } } - return diffindex + return len(st.key) } // Helper function to that inserts a (key, value) pair into @@ -232,7 +224,8 @@ func (st *StackTrie) getDiffIndex(key []byte) int { func (st *StackTrie) insert(key, value []byte) { switch st.nodeType { case branchNode: /* Branch */ - idx := int(key[st.keyOffset]) + idx := int(key[0]) + // Unresolve elder siblings for i := idx - 1; i >= 0; i-- { if st.children[i] != nil { @@ -242,12 +235,14 @@ func (st *StackTrie) insert(key, value []byte) { break } } + // Add new child if st.children[idx] == nil { - st.children[idx] = stackTrieFromPool(st.db) - st.children[idx].keyOffset = st.keyOffset + 1 + st.children[idx] = newLeaf(key[1:], value, st.db) + } else { + st.children[idx].insert(key[1:], value) } - st.children[idx].insert(key, value) + case extNode: /* Ext */ // Compare both key chunks and see where they differ diffidx := st.getDiffIndex(key) @@ -260,7 +255,7 @@ func (st *StackTrie) insert(key, value []byte) { if diffidx == len(st.key) { // Ext key and key segment are identical, recurse into // the child node. - st.children[0].insert(key, value) + st.children[0].insert(key[diffidx:], value) return } // Save the original part. Depending if the break is @@ -269,7 +264,7 @@ func (st *StackTrie) insert(key, value []byte) { // node directly. var n *StackTrie if diffidx < len(st.key)-1 { - n = newExt(diffidx+1, st.key, st.children[0], st.db) + n = newExt(st.key[diffidx+1:], st.children[0], st.db) } else { // Break on the last byte, no need to insert // an extension node: reuse the current node @@ -291,15 +286,13 @@ func (st *StackTrie) insert(key, value []byte) { // node. st.children[0] = stackTrieFromPool(st.db) st.children[0].nodeType = branchNode - st.children[0].keyOffset = st.keyOffset + diffidx p = st.children[0] } // Create a leaf for the inserted part - o := newLeaf(st.keyOffset+diffidx+1, key, value, st.db) - + o := newLeaf(key[diffidx+1:], value, st.db) // Insert both child leaves where they belong: origIdx := st.key[diffidx] - newIdx := key[diffidx+st.keyOffset] + newIdx := key[diffidx] p.children[origIdx] = n p.children[newIdx] = o st.key = st.key[:diffidx] @@ -333,38 +326,37 @@ func (st *StackTrie) insert(key, value []byte) { st.nodeType = extNode st.children[0] = NewStackTrie(st.db) st.children[0].nodeType = branchNode - st.children[0].keyOffset = st.keyOffset + diffidx p = st.children[0] } - // Create the two child leaves: the one containing the - // original value and the one containing the new value - // The child leave will be hashed directly in order to - // free up some memory. + // Create the two child leaves: one containing the original + // value and another containing the new value. The child leaf + // is hashed directly in order to free up some memory. origIdx := st.key[diffidx] - p.children[origIdx] = newLeaf(diffidx+1, st.key, st.val, st.db) + p.children[origIdx] = newLeaf(st.key[diffidx+1:], st.val, st.db) p.children[origIdx].hash() - - newIdx := key[diffidx+st.keyOffset] - p.children[newIdx] = newLeaf(p.keyOffset+1, key, value, st.db) - + newIdx := key[diffidx] + p.children[newIdx] = newLeaf(key[diffidx+1:], value, st.db) // Finally, cut off the key part that has been passed // over to the children. st.key = st.key[:diffidx] st.val = nil + case emptyNode: /* Empty */ st.nodeType = leafNode - st.key = key[st.keyOffset:] + st.key = key st.val = value + case hashedNode: panic("trying to insert into hash") + default: panic("invalid type") } } -// hash() hashes the node 'st' and converts it into 'hashedNode', if possible. -// Possible outcomes: +// hash converts st into a 'hashedNode', if possible. Possible outcomes: +// // 1. The rlp-encoded value was >= 32 bytes: // - Then the 32-byte `hash` will be accessible in `st.val`. // - And the 'st.type' will be 'hashedNode' @@ -372,119 +364,116 @@ func (st *StackTrie) insert(key, value []byte) { // - Then the <32 byte rlp-encoded value will be accessible in 'st.val'. // - And the 'st.type' will be 'hashedNode' AGAIN // -// This method will also: -// set 'st.type' to hashedNode -// clear 'st.key' +// This method also sets 'st.type' to hashedNode, and clears 'st.key'. func (st *StackTrie) hash() { - /* Shortcut if node is already hashed */ - if st.nodeType == hashedNode { - return - } - // The 'hasher' is taken from a pool, but we don't actually - // claim an instance until all children are done with their hashing, - // and we actually need one - var h *hasher + h := newHasher(false) + defer returnHasherToPool(h) + + st.hashRec(h) +} + +func (st *StackTrie) hashRec(hasher *hasher) { + // The switch below sets this to the RLP-encoding of this node. + var encodedNode []byte switch st.nodeType { + case hashedNode: + return + + case emptyNode: + st.val = emptyRoot.Bytes() + st.key = st.key[:0] + st.nodeType = hashedNode + return + case branchNode: - var nodes [17]node + var nodes rawFullNode for i, child := range st.children { if child == nil { nodes[i] = nilValueNode continue } - child.hash() + + child.hashRec(hasher) if len(child.val) < 32 { nodes[i] = rawNode(child.val) } else { nodes[i] = hashNode(child.val) } - st.children[i] = nil // Reclaim mem from subtree + + // Release child back to pool. + st.children[i] = nil returnToPool(child) } - nodes[16] = nilValueNode - h = newHasher(nil) - defer returnHasherToPool(h) - h.tmp.Reset() - if err := rlp.Encode(&h.tmp, nodes); err != nil { - panic(err) - } + + nodes.encode(hasher.encBuf) + encodedNode = hasher.encodedBytes() + case extNode: - st.children[0].hash() - h = newHasher(nil) - defer returnHasherToPool(h) - h.tmp.Reset() - var valuenode node + st.children[0].hashRec(hasher) + + sz := hexToCompactInPlace(st.key) + n := rawShortNode{Key: st.key[:sz]} if len(st.children[0].val) < 32 { - valuenode = rawNode(st.children[0].val) + n.Val = rawNode(st.children[0].val) } else { - valuenode = hashNode(st.children[0].val) - } - n := struct { - Key []byte - Val node - }{ - Key: hexToCompact(st.key), - Val: valuenode, - } - if err := rlp.Encode(&h.tmp, n); err != nil { - panic(err) + n.Val = hashNode(st.children[0].val) } + + n.encode(hasher.encBuf) + encodedNode = hasher.encodedBytes() + + // Release child back to pool. returnToPool(st.children[0]) - st.children[0] = nil // Reclaim mem from subtree + st.children[0] = nil + case leafNode: - h = newHasher(nil) - defer returnHasherToPool(h) - h.tmp.Reset() st.key = append(st.key, byte(16)) sz := hexToCompactInPlace(st.key) - n := [][]byte{st.key[:sz], st.val} - if err := rlp.Encode(&h.tmp, n); err != nil { - panic(err) - } - case emptyNode: - st.val = emptyRoot.Bytes() - st.key = st.key[:0] - st.nodeType = hashedNode - return + n := rawShortNode{Key: st.key[:sz], Val: valueNode(st.val)} + + n.encode(hasher.encBuf) + encodedNode = hasher.encodedBytes() + default: - panic("Invalid node type") + panic("invalid node type") } - st.key = st.key[:0] + st.nodeType = hashedNode - if len(h.tmp) < 32 { - st.val = common.CopyBytes(h.tmp) + st.key = st.key[:0] + if len(encodedNode) < 32 { + st.val = common.CopyBytes(encodedNode) return } + // Write the hash to the 'val'. We allocate a new val here to not mutate // input values - st.val = make([]byte, 32) - h.sha.Reset() - h.sha.Write(h.tmp) - h.sha.Read(st.val) + st.val = hasher.hashData(encodedNode) if st.db != nil { // TODO! Is it safe to Put the slice here? // Do all db implementations copy the value provided? - st.db.GetStateTrieDB().Put(st.val, h.tmp) + st.db.GetStateTrieDB().Put(st.val, encodedNode) } } -// Hash returns the hash of the current node +// Hash returns the hash of the current node. func (st *StackTrie) Hash() (h common.Hash) { - st.hash() - if len(st.val) != 32 { - // If the node's RLP isn't 32 bytes long, the node will not - // be hashed, and instead contain the rlp-encoding of the - // node. For the top level node, we need to force the hashing. - ret := make([]byte, 32) - h := newHasher(nil) - defer returnHasherToPool(h) - h.sha.Reset() - h.sha.Write(st.val) - h.sha.Read(ret) - return common.BytesToHash(ret) + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + st.hashRec(hasher) + if len(st.val) == 32 { + copy(h[:], st.val) + return h } - return common.BytesToHash(st.val) + + // If the node's RLP isn't 32 bytes long, the node will not + // be hashed, and instead contain the rlp-encoding of the + // node. For the top level node, we need to force the hashing. + hasher.sha.Reset() + hasher.sha.Write(st.val) + hasher.sha.Read(h[:]) + return h } // Commit will firstly hash the entrie trie if it's still not hashed @@ -494,23 +483,26 @@ func (st *StackTrie) Hash() (h common.Hash) { // // The associated database is expected, otherwise the whole commit // functionality should be disabled. -func (st *StackTrie) Commit() (common.Hash, error) { +func (st *StackTrie) Commit() (h common.Hash, err error) { if st.db == nil { return common.Hash{}, ErrCommitDisabled } - st.hash() - if len(st.val) != 32 { - // If the node's RLP isn't 32 bytes long, the node will not - // be hashed (and committed), and instead contain the rlp-encoding of the - // node. For the top level node, we need to force the hashing+commit. - ret := make([]byte, 32) - h := newHasher(nil) - defer returnHasherToPool(h) - h.sha.Reset() - h.sha.Write(st.val) - h.sha.Read(ret) - st.db.GetStateTrieDB().Put(ret, st.val) - return common.BytesToHash(ret), nil + + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + st.hashRec(hasher) + if len(st.val) == 32 { + copy(h[:], st.val) + return h, nil } - return common.BytesToHash(st.val), nil + + // If the node's RLP isn't 32 bytes long, the node will not + // be hashed (and committed), and instead contain the rlp-encoding of the + // node. For the top level node, we need to force the hashing+commit. + hasher.sha.Reset() + hasher.sha.Write(st.val) + hasher.sha.Read(h[:]) + st.db.GetStateTrieDB().Put(h[:], st.val) + return h, nil } diff --git a/storage/statedb/trie.go b/storage/statedb/trie.go index b0e056279e..8a48f62e39 100644 --- a/storage/statedb/trie.go +++ b/storage/statedb/trie.go @@ -24,6 +24,7 @@ import ( "bytes" "errors" "fmt" + "sync" "github.com/klaytn/klaytn/common" "github.com/klaytn/klaytn/crypto" @@ -55,14 +56,17 @@ type LeafCallback func(paths [][]byte, hexpath []byte, leaf []byte, parent commo // Trie is a Merkle Patricia Trie. // The zero value is an empty trie with no database. -// Use NewTrie to create a trie that sits on top of a database. +// Use New to create a trie that sits on top of a database. // // Trie is not safe for concurrent use. type Trie struct { - db *Database - root node - originalRoot common.Hash - prefetching bool + db *Database + root node + prefetching bool + // Keep track of the number leafs which have been inserted since the last + // hashing operation. This number will not directly map to the number of + // actually unhashed nodes + unhashed int } // newFlag returns the cache flag value for a newly created node. @@ -78,13 +82,12 @@ func (t *Trie) newFlag() nodeFlag { // not exist in the database. Accessing the trie loads nodes from db on demand. func NewTrie(root common.Hash, db *Database) (*Trie, error) { if db == nil { - panic("statedb.NewTrie called without a database") + panic("trie.New called without a database") } trie := &Trie{ - db: db, - originalRoot: root, + db: db, } - if (root != common.Hash{}) && root != emptyRoot { + if root != (common.Hash{}) && root != emptyRoot { rootnode, err := trie.resolveHash(root[:], nil) if err != nil { return nil, err @@ -114,7 +117,7 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator { func (t *Trie) Get(key []byte) []byte { res, err := t.TryGet(key) if err != nil { - logger.Error("Unhandled trie error in Trie.Get", "err", err) + logger.Error(fmt.Sprintf("Unhandled trie error: %v", err)) } return res } @@ -252,7 +255,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new // stored in the trie. func (t *Trie) Update(key, value []byte) { if err := t.TryUpdate(key, value); err != nil { - logger.Error("Unhandled trie error in Trie.Update", "err", err) + logger.Error(fmt.Sprintf("Unhandled trie error: %v", err)) } } @@ -265,6 +268,7 @@ func (t *Trie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryUpdate(key, value []byte) error { + t.unhashed++ hexKey := keybytesToHex(key) return t.TryUpdateWithHexKey(hexKey, value) } @@ -360,13 +364,14 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error // Delete removes any existing value for key from the trie. func (t *Trie) Delete(key []byte) { if err := t.TryDelete(key); err != nil { - logger.Error("Unhandled trie error in Trie.Delete", "err", err) + logger.Error(fmt.Sprintf("Unhandled trie error: %v", err)) } } // TryDelete removes any existing value for key from the trie. // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryDelete(key []byte) error { + t.unhashed++ k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) if err != nil { @@ -504,11 +509,7 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { hash := common.BytesToHash(n) - node, fromDB := t.db.node(hash) - if t.prefetching && fromDB { - memcacheCleanPrefetchMissMeter.Mark(1) - } - if node != nil { + if node, _ := t.db.node(hash); node != nil { return node, nil } return nil, &MissingNodeError{NodeHash: hash, Path: prefix} @@ -517,7 +518,7 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { - hash, cached := t.hashRoot(nil, nil) + hash, cached, _ := t.hashRoot(nil, nil) t.root = cached return common.BytesToHash(hash.(hashNode)) } @@ -528,23 +529,60 @@ func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) { if t.db == nil { panic("commit called on trie with nil database") } - hash, cached := t.hashRoot(t.db, onleaf) - t.root = cached - return common.BytesToHash(hash.(hashNode)), nil + if t.root == nil { + return emptyRoot, nil + } + rootHash := t.Hash() + h := newCommitter() + defer returnCommitterToPool(h) + // Do a quick check if we really need to commit, before we spin + // up goroutines. This can happen e.g. if we load a trie for reading storage + // values, but don't write to it. + if !h.commitNeeded(t.root) { + return rootHash, nil + } + var wg sync.WaitGroup + if onleaf != nil { + h.onleaf = onleaf + h.leafCh = make(chan *leaf, leafChanSize) + wg.Add(1) + go func() { + defer wg.Done() + h.commitLoop(t.db) + }() + } + var newRoot hashNode + newRoot, err = h.Commit(t.root, t.db) + if onleaf != nil { + // The leafch is created in newCommitter if there was an onleaf callback + // provided. The commitLoop only _reads_ from it, and the commit + // operation was the sole writer. Therefore, it's safe to close this + // channel here. + close(h.leafCh) + wg.Wait() + } + if err != nil { + return common.Hash{}, err + } + t.root = newRoot + return rootHash, nil } -func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node) { +// hashRoot calculates the root hash of the given trie +func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) { if t.root == nil { - return hashNode(emptyRoot.Bytes()), nil + return hashNode(emptyRoot.Bytes()), nil, nil } - h := newHasher(onleaf) + // If the number of changes is below 100, we let one thread handle it + h := newHasher(t.unhashed >= 100) defer returnHasherToPool(h) - return h.hashRoot(t.root, db, true) + hashed, cached := h.hash(t.root, true) + return hashed, cached, nil } func GetHashAndHexKey(key []byte) ([]byte, []byte) { var hashKeyBuf [common.HashLength]byte - h := newHasher(nil) + h := newHasher(false) h.sha.Reset() h.sha.Write(key) hashKey := h.sha.Sum(hashKeyBuf[:0])