diff --git a/store/rootmulti/store_test.go b/store/rootmulti/store_test.go index a219593fbb86..1038bdef93db 100644 --- a/store/rootmulti/store_test.go +++ b/store/rootmulti/store_test.go @@ -961,3 +961,76 @@ func TestStateListeners(t *testing.T) { cacheMulti.Write() require.Equal(t, 1, len(listener.stateCache)) } + +type commitKVStoreStub struct { + types.CommitKVStore + Committed int +} + +func (stub *commitKVStoreStub) Commit() types.CommitID { + commitID := stub.CommitKVStore.Commit() + stub.Committed += 1 + return commitID +} + +func prepareStoreMap() map[types.StoreKey]types.CommitKVStore { + var db dbm.DB = dbm.NewMemDB() + store := NewStore(db, log.NewNopLogger()) + store.MountStoreWithDB(types.NewKVStoreKey("iavl1"), types.StoreTypeIAVL, nil) + store.MountStoreWithDB(types.NewKVStoreKey("iavl2"), types.StoreTypeIAVL, nil) + store.MountStoreWithDB(types.NewTransientStoreKey("trans1"), types.StoreTypeTransient, nil) + store.LoadLatestVersion() + return map[types.StoreKey]types.CommitKVStore{ + testStoreKey1: &commitKVStoreStub{ + CommitKVStore: store.GetStoreByName("iavl1").(types.CommitKVStore), + }, + testStoreKey2: &commitKVStoreStub{ + CommitKVStore: store.GetStoreByName("iavl2").(types.CommitKVStore), + }, + testStoreKey3: &commitKVStoreStub{ + CommitKVStore: store.GetStoreByName("trans1").(types.CommitKVStore), + }, + } +} + +func TestCommitStores(t *testing.T) { + testCases := []struct { + name string + committed int + exptectCommit int + }{ + { + "when upgrade not get interrupted", + 0, + 1, + }, + { + "when upgrade get interrupted once", + 1, + 0, + }, + { + "when upgrade get interrupted twice", + 2, + 0, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + storeMap := prepareStoreMap() + store := storeMap[testStoreKey1].(*commitKVStoreStub) + for i := tc.committed; i > 0; i-- { + store.Commit() + } + store.Committed = 0 + var version int64 = 1 + removalMap := map[types.StoreKey]bool{} + res := commitStores(version, storeMap, removalMap) + for _, s := range res.StoreInfos { + require.Equal(t, version, s.CommitId.Version) + } + require.Equal(t, version, res.Version) + require.Equal(t, tc.exptectCommit, store.Committed) + }) + } +}