Skip to content

Commit

Permalink
core/state/snapshot: fix data race in layer flattening (ethereum#23628)
Browse files Browse the repository at this point in the history
* core/state/snapshot: fix data race in layer flattening

* core/state/snapshot: fix typo
  • Loading branch information
rjl493456442 authored and jagdeep sidhu committed Oct 15, 2021
1 parent 194ff3e commit 65a3fec
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
16 changes: 13 additions & 3 deletions core/state/snapshot/snapshot.go
Expand Up @@ -163,6 +163,9 @@ type Tree struct {
cache int // Megabytes permitted to use for read caches
layers map[common.Hash]snapshot // Collection of all known layers
lock sync.RWMutex

// Test hooks
onFlatten func() // Hook invoked when the bottom most diff layers are flattened
}

// New attempts to load an already existing snapshot from a persistent key-value
Expand Down Expand Up @@ -463,14 +466,21 @@ func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer {
return nil

case *diffLayer:
// Hold the write lock until the flattened parent is linked correctly.
// Otherwise, the stale layer may be accessed by external reads in the
// meantime.
diff.lock.Lock()
defer diff.lock.Unlock()

// Flatten the parent into the grandparent. The flattening internally obtains a
// write lock on grandparent.
flattened := parent.flatten().(*diffLayer)
t.layers[flattened.root] = flattened

diff.lock.Lock()
defer diff.lock.Unlock()

// Invoke the hook if it's registered. Ugly hack.
if t.onFlatten != nil {
t.onFlatten()
}
diff.parent = flattened
if flattened.memory < aggregatorMemoryLimit {
// Accumulator layer is smaller than the limit, so we can abort, unless
Expand Down
63 changes: 62 additions & 1 deletion core/state/snapshot/snapshot_test.go
Expand Up @@ -22,6 +22,7 @@ import (
"math/big"
"math/rand"
"testing"
"time"

"github.com/VictoriaMetrics/fastcache"
"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -324,7 +325,7 @@ func TestPostCapBasicDataAccess(t *testing.T) {
}
}

// TestSnaphots tests the functionality for retrieveing the snapshot
// TestSnaphots tests the functionality for retrieving the snapshot
// with given head root and the desired depth.
func TestSnaphots(t *testing.T) {
// setAccount is a helper to construct a random account entry and assign it to
Expand Down Expand Up @@ -423,3 +424,63 @@ func TestSnaphots(t *testing.T) {
}
}
}

// TestReadStateDuringFlattening tests the scenario that, during the
// bottom diff layers are merging which tags these as stale, the read
// happens via a pre-created top snapshot layer which tries to access
// the state in these stale layers. Ensure this read can retrieve the
// right state back(block until the flattening is finished) instead of
// an unexpected error(snapshot layer is stale).
func TestReadStateDuringFlattening(t *testing.T) {
// setAccount is a helper to construct a random account entry and assign it to
// an account slot in a snapshot
setAccount := func(accKey string) map[common.Hash][]byte {
return map[common.Hash][]byte{
common.HexToHash(accKey): randomAccount(),
}
}
// Create a starting base layer and a snapshot tree out of it
base := &diskLayer{
diskdb: rawdb.NewMemoryDatabase(),
root: common.HexToHash("0x01"),
cache: fastcache.New(1024 * 500),
}
snaps := &Tree{
layers: map[common.Hash]snapshot{
base.root: base,
},
}
// 4 layers in total, 3 diff layers and 1 disk layers
snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil)
snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil)
snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil)

// Obtain the topmost snapshot handler for state accessing
snap := snaps.Snapshot(common.HexToHash("0xa3"))

// Register the testing hook to access the state after flattening
var result = make(chan *Account)
snaps.onFlatten = func() {
// Spin up a thread to read the account from the pre-created
// snapshot handler. It's expected to be blocked.
go func() {
account, _ := snap.Account(common.HexToHash("0xa1"))
result <- account
}()
select {
case res := <-result:
t.Fatalf("Unexpected return %v", res)
case <-time.NewTimer(time.Millisecond * 300).C:
}
}
// Cap the snap tree, which will mark the bottom-most layer as stale.
snaps.Cap(common.HexToHash("0xa3"), 1)
select {
case account := <-result:
if account == nil {
t.Fatal("Failed to retrieve account")
}
case <-time.NewTimer(time.Millisecond * 300).C:
t.Fatal("Unexpected blocker")
}
}

0 comments on commit 65a3fec

Please sign in to comment.