plotting|tests: Cache plot data in PlotManager
xdustinface committed Sep 1, 2021
1 parent 0fb0dd4 commit eb8a244
Showing 2 changed files with 212 additions and 38 deletions.
185 changes: 147 additions & 38 deletions chia/plotting/manager.py
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import logging
import threading
import time
@@ -21,12 +22,98 @@
stream_plot_info_pk,
stream_plot_info_ph,
)
from chia.util.ints import uint16
from chia.util.path import mkdir
from chia.util.streamable import Streamable, streamable
from chia.types.blockchain_format.proof_of_space import ProofOfSpace
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.wallet.derive_keys import master_sk_to_local_sk

log = logging.getLogger(__name__)

CURRENT_VERSION: uint16 = uint16(0)


@dataclass(frozen=True)
@streamable
class CacheEntry(Streamable):
pool_public_key: Optional[G1Element]
pool_contract_puzzle_hash: Optional[bytes32]
plot_public_key: G1Element


@dataclass(frozen=True)
@streamable
class DiskCache(Streamable):
version: uint16
data: List[Tuple[bytes32, CacheEntry]]


class Cache:
_changed: bool
_data: Dict[bytes32, CacheEntry]

def __init__(self, path: Path):
self._changed = False
self._data = {}
self._path = path
if not path.parent.exists():
mkdir(path.parent)

def __len__(self):
return len(self._data)

def update(self, plot_id: bytes32, entry: CacheEntry):
self._data[plot_id] = entry
self._changed = True

def remove(self, cache_keys: List[bytes32]):
for key in cache_keys:
if key in self._data:
del self._data[key]
self._changed = True

def save(self):
try:
disk_cache: DiskCache = DiskCache(
CURRENT_VERSION, [(plot_id, cache_entry) for plot_id, cache_entry in self.items()]
)
serialized: bytes = bytes(disk_cache)
self._path.write_bytes(serialized)
self._changed = False
log.info(f"Saved {len(serialized)} bytes of cached data")
except Exception as e:
log.error(f"Failed to save cache: {e}, {traceback.format_exc()}")

def load(self):
try:
serialized = self._path.read_bytes()
log.info(f"Loaded {len(serialized)} bytes of cached data")
stored_cache: DiskCache = DiskCache.from_bytes(serialized)
if stored_cache.version != CURRENT_VERSION:
# TODO: Migrate or drop the current cache if the version changes.
raise ValueError(f"Invalid cache version {stored_cache.version}. Expected version {CURRENT_VERSION}.")
self._data = {plot_id: cache_entry for plot_id, cache_entry in stored_cache.data}
except FileNotFoundError:
log.debug(f"Cache {self._path} not found")
except Exception as e:
log.error(f"Failed to load cache: {e}, {traceback.format_exc()}")

def keys(self):
return self._data.keys()

def items(self):
return self._data.items()

def get(self, plot_id):
return self._data.get(plot_id)

def changed(self):
return self._changed

def path(self):
return self._path


class PlotManager:
plots: Dict[Path, PlotInfo]
@@ -36,6 +123,7 @@ class PlotManager:
no_key_filenames: Set[Path]
farmer_public_keys: List[G1Element]
pool_public_keys: List[G1Element]
cache: Cache
match_str: Optional[str]
show_memo: bool
open_no_key_filenames: bool
@@ -64,6 +152,7 @@ def __init__(
self.no_key_filenames = set()
self.farmer_public_keys = []
self.pool_public_keys = []
self.cache = Cache(self.root_path.resolve() / "cache" / "plot_manager.dat")
self.match_str = match_str
self.show_memo = show_memo
self.open_no_key_filenames = open_no_key_filenames
@@ -101,6 +190,7 @@ def needs_refresh(self) -> bool:
def start_refreshing(self):
self._refreshing_enabled = True
if self._refresh_thread is None or not self._refresh_thread.is_alive():
self.cache.load()
self._refresh_thread = threading.Thread(target=self._refresh_task)
self._refresh_thread.start()

@@ -126,12 +216,22 @@ def _refresh_task(self):
total_result += batch_result
self._refresh_callback(batch_result)
if batch_result.remaining_files == 0:
self.last_refresh_time = time.time()
break
batch_sleep = self.refresh_parameter.batch_sleep_milliseconds
self.log.debug(f"refresh_plots: Sleep {batch_sleep} milliseconds")
time.sleep(float(batch_sleep) / 1000.0)

# Cleanup unused cache
available_ids = set([plot_info.prover.get_id() for plot_info in self.plots.values()])
invalid_cache_keys = [plot_id for plot_id in self.cache.keys() if plot_id not in available_ids]
self.cache.remove(invalid_cache_keys)
self.log.debug(f"_refresh_task: cached entries removed: {len(invalid_cache_keys)}")

if self.cache.changed():
self.cache.save()

self.last_refresh_time = time.time()

self.log.debug(
f"_refresh_task: total_result.loaded_plots {total_result.loaded_plots}, "
f"total_result.removed_plots {total_result.removed_plots}, "
@@ -206,43 +306,52 @@ def process_file(file_path: Path) -> Dict:
)
return new_provers

(
pool_public_key_or_puzzle_hash,
farmer_public_key,
local_master_sk,
) = parse_plot_info(prover.get_memo())

# Only use plots that have the correct keys associated with them
if self.farmer_public_keys is not None and farmer_public_key not in self.farmer_public_keys:
log.warning(f"Plot {file_path} has a farmer public key that is not in the farmer's pk list.")
self.no_key_filenames.add(file_path)
if not self.open_no_key_filenames:
return new_provers
cache_entry = self.cache.get(prover.get_id())
if cache_entry is None:
(
pool_public_key_or_puzzle_hash,
farmer_public_key,
local_master_sk,
) = parse_plot_info(prover.get_memo())

if isinstance(pool_public_key_or_puzzle_hash, G1Element):
pool_public_key = pool_public_key_or_puzzle_hash
pool_contract_puzzle_hash = None
else:
assert isinstance(pool_public_key_or_puzzle_hash, bytes32)
pool_public_key = None
pool_contract_puzzle_hash = pool_public_key_or_puzzle_hash

if (
self.pool_public_keys is not None
and pool_public_key is not None
and pool_public_key not in self.pool_public_keys
):
log.warning(f"Plot {file_path} has a pool public key that is not in the farmer's pool pk list.")
self.no_key_filenames.add(file_path)
if not self.open_no_key_filenames:
return new_provers
# Only use plots that have the correct keys associated with them
if self.farmer_public_keys is not None and farmer_public_key not in self.farmer_public_keys:
log.warning(
f"Plot {file_path} has a farmer public key that is not in the farmer's pk list."
)
self.no_key_filenames.add(file_path)
if not self.open_no_key_filenames:
return new_provers

pool_public_key: Optional[G1Element] = None
pool_contract_puzzle_hash: Optional[bytes32] = None
if isinstance(pool_public_key_or_puzzle_hash, G1Element):
pool_public_key = pool_public_key_or_puzzle_hash
else:
assert isinstance(pool_public_key_or_puzzle_hash, bytes32)
pool_contract_puzzle_hash = pool_public_key_or_puzzle_hash

if (
self.pool_public_keys is not None
and pool_public_key is not None
and pool_public_key not in self.pool_public_keys
):
log.warning(
f"Plot {file_path} has a pool public key that is not in the farmer's pool pk list."
)
self.no_key_filenames.add(file_path)
if not self.open_no_key_filenames:
return new_provers

stat_info = file_path.stat()
local_sk = master_sk_to_local_sk(local_master_sk)
stat_info = file_path.stat()
local_sk = master_sk_to_local_sk(local_master_sk)

plot_public_key: G1Element = ProofOfSpace.generate_plot_public_key(
local_sk.get_g1(), farmer_public_key, pool_contract_puzzle_hash is not None
)
plot_public_key: G1Element = ProofOfSpace.generate_plot_public_key(
local_sk.get_g1(), farmer_public_key, pool_contract_puzzle_hash is not None
)

cache_entry = CacheEntry(pool_public_key, pool_contract_puzzle_hash, plot_public_key)
self.cache.update(prover.get_id(), cache_entry)

with self.plot_filename_paths_lock:
if file_path.name not in self.plot_filename_paths:
Expand All @@ -258,9 +367,9 @@ def process_file(file_path: Path) -> Dict:

new_provers[file_path] = PlotInfo(
prover,
pool_public_key,
pool_contract_puzzle_hash,
plot_public_key,
cache_entry.pool_public_key,
cache_entry.pool_contract_puzzle_hash,
cache_entry.plot_public_key,
stat_info.st_size,
stat_info.st_mtime,
)
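For orientation, here is a minimal round-trip sketch of the new Cache class above. The temporary directory, plot id, and key seed are hypothetical; it assumes the chia package from this commit and blspy are importable:

from pathlib import Path
from tempfile import TemporaryDirectory

from blspy import AugSchemeMPL

from chia.plotting.manager import Cache, CacheEntry
from chia.types.blockchain_format.sized_bytes import bytes32

with TemporaryDirectory() as tmp_dir:
    cache_file = Path(tmp_dir) / "plot_manager.dat"
    plot_id = bytes32(b"\x11" * 32)  # hypothetical plot id
    plot_public_key = AugSchemeMPL.key_gen(b"\x22" * 32).get_g1()  # hypothetical key

    cache = Cache(cache_file)
    cache.update(plot_id, CacheEntry(None, None, plot_public_key))
    assert cache.changed()
    cache.save()  # writes a serialized DiskCache(CURRENT_VERSION, [...]) to cache_file

    restored = Cache(cache_file)
    restored.load()
    assert restored.get(plot_id).plot_public_key == plot_public_key

The version field is what lets a future format change invalidate stale caches: load() logs an error and leaves the cache empty when the stored version differs from CURRENT_VERSION, so PlotManager simply falls back to parsing every plot again.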
65 changes: 65 additions & 0 deletions tests/core/test_farmer_harvester_rpc.py
@@ -1,5 +1,6 @@
# flake8: noqa: E501
import logging
from os import unlink
from pathlib import Path
from secrets import token_bytes
from shutil import copy, move
@@ -10,6 +11,7 @@

from chia.consensus.coinbase import create_puzzlehash_for_pk
from chia.plotting.util import stream_plot_info_ph, stream_plot_info_pk, PlotRefreshResult
from chia.plotting.manager import PlotManager
from chia.protocols import farmer_protocol
from chia.rpc.farmer_rpc_api import FarmerRpcApi
from chia.rpc.farmer_rpc_client import FarmerRpcClient
@@ -200,6 +202,7 @@ async def test_case(
await time_out_assert(5, harvester.plot_manager.needs_refresh, value=False)
result = await client_2.get_plots()
assert len(result["plots"]) == expect_total_plots
assert len(harvester.plot_manager.cache) == expect_total_plots
assert len(harvester.plot_manager.failed_to_open_filenames) == 0

# Add plot_dir with two new plots
@@ -293,6 +296,68 @@ async def test_case(
expected_directories=1,
expect_total_plots=0,
)
# Recover the plots to test caching
# First make sure cache gets written if required and new plots are loaded
await test_case(
client_2.add_plot_directory(str(get_plot_dir())),
expect_loaded=20,
expect_removed=0,
expect_processed=20,
expected_directories=2,
expect_total_plots=20,
)
assert harvester.plot_manager.cache.path().exists()
unlink(harvester.plot_manager.cache.path())
# Should not write the cache again on shutdown because it didn't change
assert not harvester.plot_manager.cache.path().exists()
harvester.plot_manager.stop_refreshing()
assert not harvester.plot_manager.cache.path().exists()
# Manually trigger `Cache.save` and make sure it creates a new cache file
harvester.plot_manager.cache.save()
assert harvester.plot_manager.cache.path().exists()

expected_result.loaded_plots = 20
expected_result.removed_plots = 0
expected_result.processed_files = 20
expected_result.remaining_files = 0
plot_manager: PlotManager = PlotManager(harvester.root_path, test_refresh_callback)
plot_manager.start_refreshing()
assert len(harvester.plot_manager.cache) == len(plot_manager.cache)
await time_out_assert(5, plot_manager.needs_refresh, value=False)
for path, plot_info in harvester.plot_manager.plots.items():
assert path in plot_manager.plots
assert plot_manager.plots[path].prover.get_filename() == plot_info.prover.get_filename()
assert plot_manager.plots[path].prover.get_id() == plot_info.prover.get_id()
assert plot_manager.plots[path].prover.get_memo() == plot_info.prover.get_memo()
assert plot_manager.plots[path].prover.get_size() == plot_info.prover.get_size()
assert plot_manager.plots[path].pool_public_key == plot_info.pool_public_key
assert plot_manager.plots[path].pool_contract_puzzle_hash == plot_info.pool_contract_puzzle_hash
assert plot_manager.plots[path].plot_public_key == plot_info.plot_public_key
assert plot_manager.plots[path].file_size == plot_info.file_size
assert plot_manager.plots[path].time_modified == plot_info.time_modified

assert harvester.plot_manager.plot_filename_paths == plot_manager.plot_filename_paths
assert harvester.plot_manager.failed_to_open_filenames == plot_manager.failed_to_open_filenames
assert harvester.plot_manager.no_key_filenames == plot_manager.no_key_filenames
plot_manager.stop_refreshing()
# Modify the content of the plot_manager.dat
with open(harvester.plot_manager.cache.path(), "r+b") as file:
file.write(b"\xff\xff") # Sets Cache.version to 65535
# Make sure it just loads the plots normally if it fails to load the cache
plot_manager = PlotManager(harvester.root_path, test_refresh_callback)
plot_manager.cache.load()
assert len(plot_manager.cache) == 0
plot_manager.set_public_keys(
harvester.plot_manager.farmer_public_keys, harvester.plot_manager.pool_public_keys
)
expected_result.loaded_plots = 20
expected_result.removed_plots = 0
expected_result.processed_files = 20
expected_result.remaining_files = 0
plot_manager.start_refreshing()
await time_out_assert(5, plot_manager.needs_refresh, value=False)
assert len(plot_manager.plots) == len(harvester.plot_manager.plots)
plot_manager.stop_refreshing()

# Test re-trying if processing a plot failed
# First save the plot
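The corruption step in the test above works because chia's Streamable encoding writes the uint16 version as the first two bytes of plot_manager.dat (big-endian, so b"\xff\xff" reads back as 65535), which forces the version-mismatch path in Cache.load. A minimal sketch of reading that field back, with a hypothetical helper name and assuming that layout:

import struct
from pathlib import Path

def read_cache_version(cache_path: Path) -> int:
    # The serialized DiskCache starts with its uint16 version field.
    return struct.unpack(">H", cache_path.read_bytes()[:2])[0]

After the two-byte overwrite in the test, this helper would return 65535, and a subsequent Cache.load() ignores the file's contents and logs the version error.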
