# Version tag stored in every serialized cache file. `Cache.load` refuses any file whose
# stored version differs from this value.
CURRENT_VERSION: uint16 = uint16(0)


@dataclass(frozen=True)
@streamable
class CacheEntry(Streamable):
    """Per-plot key material kept in the cache.

    `PlotManager` stores one entry per plot id so that, on later refreshes, the
    memo parsing and plot public key derivation can be skipped for known plots.
    """

    # Exactly one of the two pool fields is expected to be set by the producer;
    # the other one is None.
    pool_public_key: Optional[G1Element]
    pool_contract_puzzle_hash: Optional[bytes32]
    plot_public_key: G1Element


@dataclass(frozen=True)
@streamable
class DiskCache(Streamable):
    """Serialized form of the cache: a version tag followed by (plot id, entry) pairs."""

    version: uint16
    data: List[Tuple[bytes32, CacheEntry]]


class Cache:
    """In-memory map of plot id -> `CacheEntry` with best-effort persistence to a file.

    `save` and `load` never raise: failures are logged and the in-memory state is
    left as it was, so callers can always fall back to recomputing entries.
    """

    _changed: bool
    _data: Dict[bytes32, CacheEntry]
    _path: Path

    def __init__(self, path: Path):
        """Create an empty cache backed by `path`, creating its parent directory if missing."""
        self._changed = False
        self._data = {}
        self._path = path
        if not path.parent.exists():
            mkdir(path.parent)

    def __len__(self) -> int:
        """Return the number of cached entries."""
        return len(self._data)

    def update(self, plot_id: bytes32, entry: CacheEntry) -> None:
        """Insert or overwrite the entry for `plot_id` and mark the cache as changed."""
        self._data[plot_id] = entry
        self._changed = True

    def remove(self, cache_keys: List[bytes32]) -> None:
        """Drop every listed key that is present; unknown keys are ignored.

        The cache is only marked as changed if at least one entry was removed.
        """
        for key in cache_keys:
            if key in self._data:
                del self._data[key]
                self._changed = True

    def save(self) -> None:
        """Serialize all entries to the cache file and clear the changed flag.

        The data is first written to a sibling temporary file and then moved over
        the destination with `os.replace`, so a crash or a failing write can not
        leave a truncated cache file behind. Any error is logged, not raised; the
        changed flag then stays set so a later call can retry.
        """
        import os  # local import: keeps this block independent of the file's import list

        tmp_path = self._path.with_name(self._path.name + ".tmp")
        try:
            disk_cache: DiskCache = DiskCache(CURRENT_VERSION, list(self.items()))
            serialized: bytes = bytes(disk_cache)
            tmp_path.write_bytes(serialized)
            os.replace(tmp_path, self._path)
            self._changed = False
            log.info(f"Saved {len(serialized)} bytes of cached data")
        except Exception as e:
            log.error(f"Failed to save cache: {e}, {traceback.format_exc()}")
            # Best-effort cleanup of a partially written temp file; it may not exist.
            try:
                tmp_path.unlink()
            except OSError:
                pass

    def load(self) -> None:
        """Replace the in-memory entries with the content of the cache file.

        A missing file is not an error (logged at debug level). A file that can not
        be parsed, or whose version differs from `CURRENT_VERSION`, is logged as an
        error and the current in-memory entries are kept untouched.
        """
        try:
            serialized = self._path.read_bytes()
            log.info(f"Loaded {len(serialized)} bytes of cached data")
            stored_cache: DiskCache = DiskCache.from_bytes(serialized)
            if stored_cache.version != CURRENT_VERSION:
                # TODO, Migrate or drop current cache if the version changes.
                raise ValueError(f"Invalid cache version {stored_cache.version}. Expected version {CURRENT_VERSION}.")
            self._data = dict(stored_cache.data)
        except FileNotFoundError:
            log.debug(f"Cache {self._path} not found")
        except Exception as e:
            log.error(f"Failed to load cache: {e}, {traceback.format_exc()}")

    def keys(self):
        """Return a view of the cached plot ids."""
        return self._data.keys()

    def items(self):
        """Return a view of the cached (plot id, entry) pairs."""
        return self._data.items()

    def get(self, plot_id: bytes32) -> Optional[CacheEntry]:
        """Return the entry for `plot_id`, or None if it is not cached."""
        return self._data.get(plot_id)

    def changed(self) -> bool:
        """Return True if there are modifications that were not yet written by `save`."""
        return self._changed

    def path(self) -> Path:
        """Return the path of the backing cache file."""
        return self._path
self.log.debug(f"refresh_plots: Sleep {batch_sleep} milliseconds") time.sleep(float(batch_sleep) / 1000.0) + # Cleanup unused cache + available_ids = set([plot_info.prover.get_id() for plot_info in self.plots.values()]) + invalid_cache_keys = [plot_id for plot_id in self.cache.keys() if plot_id not in available_ids] + self.cache.remove(invalid_cache_keys) + self.log.debug(f"_refresh_task: cached entries removed: {len(invalid_cache_keys)}") + + if self.cache.changed(): + self.cache.save() + + self.last_refresh_time = time.time() + self.log.debug( f"_refresh_task: total_result.loaded_plots {total_result.loaded_plots}, " f"total_result.removed_plots {total_result.removed_plots}, " @@ -206,43 +306,52 @@ def process_file(file_path: Path) -> Dict: ) return new_provers - ( - pool_public_key_or_puzzle_hash, - farmer_public_key, - local_master_sk, - ) = parse_plot_info(prover.get_memo()) - - # Only use plots that correct keys associated with them - if self.farmer_public_keys is not None and farmer_public_key not in self.farmer_public_keys: - log.warning(f"Plot {file_path} has a farmer public key that is not in the farmer's pk list.") - self.no_key_filenames.add(file_path) - if not self.open_no_key_filenames: - return new_provers + cache_entry = self.cache.get(prover.get_id()) + if cache_entry is None: + ( + pool_public_key_or_puzzle_hash, + farmer_public_key, + local_master_sk, + ) = parse_plot_info(prover.get_memo()) - if isinstance(pool_public_key_or_puzzle_hash, G1Element): - pool_public_key = pool_public_key_or_puzzle_hash - pool_contract_puzzle_hash = None - else: - assert isinstance(pool_public_key_or_puzzle_hash, bytes32) - pool_public_key = None - pool_contract_puzzle_hash = pool_public_key_or_puzzle_hash - - if ( - self.pool_public_keys is not None - and pool_public_key is not None - and pool_public_key not in self.pool_public_keys - ): - log.warning(f"Plot {file_path} has a pool public key that is not in the farmer's pool pk list.") - 
self.no_key_filenames.add(file_path) - if not self.open_no_key_filenames: - return new_provers + # Only use plots that correct keys associated with them + if self.farmer_public_keys is not None and farmer_public_key not in self.farmer_public_keys: + log.warning( + f"Plot {file_path} has a farmer public key that is not in the farmer's pk list." + ) + self.no_key_filenames.add(file_path) + if not self.open_no_key_filenames: + return new_provers + + pool_public_key: Optional[G1Element] = None + pool_contract_puzzle_hash: Optional[bytes32] = None + if isinstance(pool_public_key_or_puzzle_hash, G1Element): + pool_public_key = pool_public_key_or_puzzle_hash + else: + assert isinstance(pool_public_key_or_puzzle_hash, bytes32) + pool_contract_puzzle_hash = pool_public_key_or_puzzle_hash + + if ( + self.pool_public_keys is not None + and pool_public_key is not None + and pool_public_key not in self.pool_public_keys + ): + log.warning( + f"Plot {file_path} has a pool public key that is not in the farmer's pool pk list." 
+ ) + self.no_key_filenames.add(file_path) + if not self.open_no_key_filenames: + return new_provers - stat_info = file_path.stat() - local_sk = master_sk_to_local_sk(local_master_sk) + stat_info = file_path.stat() + local_sk = master_sk_to_local_sk(local_master_sk) - plot_public_key: G1Element = ProofOfSpace.generate_plot_public_key( - local_sk.get_g1(), farmer_public_key, pool_contract_puzzle_hash is not None - ) + plot_public_key: G1Element = ProofOfSpace.generate_plot_public_key( + local_sk.get_g1(), farmer_public_key, pool_contract_puzzle_hash is not None + ) + + cache_entry = CacheEntry(pool_public_key, pool_contract_puzzle_hash, plot_public_key) + self.cache.update(prover.get_id(), cache_entry) with self.plot_filename_paths_lock: if file_path.name not in self.plot_filename_paths: @@ -258,9 +367,9 @@ def process_file(file_path: Path) -> Dict: new_provers[file_path] = PlotInfo( prover, - pool_public_key, - pool_contract_puzzle_hash, - plot_public_key, + cache_entry.pool_public_key, + cache_entry.pool_contract_puzzle_hash, + cache_entry.plot_public_key, stat_info.st_size, stat_info.st_mtime, ) diff --git a/tests/core/test_farmer_harvester_rpc.py b/tests/core/test_farmer_harvester_rpc.py index cd28f1af8352..f291e770279c 100644 --- a/tests/core/test_farmer_harvester_rpc.py +++ b/tests/core/test_farmer_harvester_rpc.py @@ -1,5 +1,6 @@ # flake8: noqa: E501 import logging +from os import unlink from pathlib import Path from secrets import token_bytes from shutil import copy, move @@ -10,6 +11,7 @@ from chia.consensus.coinbase import create_puzzlehash_for_pk from chia.plotting.util import stream_plot_info_ph, stream_plot_info_pk, PlotRefreshResult +from chia.plotting.manager import PlotManager from chia.protocols import farmer_protocol from chia.rpc.farmer_rpc_api import FarmerRpcApi from chia.rpc.farmer_rpc_client import FarmerRpcClient @@ -200,6 +202,7 @@ async def test_case( await time_out_assert(5, harvester.plot_manager.needs_refresh, value=False) result = await 
client_2.get_plots() assert len(result["plots"]) == expect_total_plots + assert len(harvester.plot_manager.cache) == expect_total_plots assert len(harvester.plot_manager.failed_to_open_filenames) == 0 # Add plot_dir with two new plots @@ -293,6 +296,68 @@ async def test_case( expected_directories=1, expect_total_plots=0, ) + # Recover the plots to test caching + # First make sure cache gets written if required and new plots are loaded + await test_case( + client_2.add_plot_directory(str(get_plot_dir())), + expect_loaded=20, + expect_removed=0, + expect_processed=20, + expected_directories=2, + expect_total_plots=20, + ) + assert harvester.plot_manager.cache.path().exists() + unlink(harvester.plot_manager.cache.path()) + # Should not write the cache again on shutdown because it didn't change + assert not harvester.plot_manager.cache.path().exists() + harvester.plot_manager.stop_refreshing() + assert not harvester.plot_manager.cache.path().exists() + # Manually trigger `save_cache` and make sure it creates a new cache file + harvester.plot_manager.cache.save() + assert harvester.plot_manager.cache.path().exists() + + expected_result.loaded_plots = 20 + expected_result.removed_plots = 0 + expected_result.processed_files = 20 + expected_result.remaining_files = 0 + plot_manager: PlotManager = PlotManager(harvester.root_path, test_refresh_callback) + plot_manager.start_refreshing() + assert len(harvester.plot_manager.cache) == len(plot_manager.cache) + await time_out_assert(5, plot_manager.needs_refresh, value=False) + for path, plot_info in harvester.plot_manager.plots.items(): + assert path in plot_manager.plots + assert plot_manager.plots[path].prover.get_filename() == plot_info.prover.get_filename() + assert plot_manager.plots[path].prover.get_id() == plot_info.prover.get_id() + assert plot_manager.plots[path].prover.get_memo() == plot_info.prover.get_memo() + assert plot_manager.plots[path].prover.get_size() == plot_info.prover.get_size() + assert 
plot_manager.plots[path].pool_public_key == plot_info.pool_public_key + assert plot_manager.plots[path].pool_contract_puzzle_hash == plot_info.pool_contract_puzzle_hash + assert plot_manager.plots[path].plot_public_key == plot_info.plot_public_key + assert plot_manager.plots[path].file_size == plot_info.file_size + assert plot_manager.plots[path].time_modified == plot_info.time_modified + + assert harvester.plot_manager.plot_filename_paths == plot_manager.plot_filename_paths + assert harvester.plot_manager.failed_to_open_filenames == plot_manager.failed_to_open_filenames + assert harvester.plot_manager.no_key_filenames == plot_manager.no_key_filenames + plot_manager.stop_refreshing() + # Modify the content of the plot_manager.dat + with open(harvester.plot_manager.cache.path(), "r+b") as file: + file.write(b"\xff\xff") # Sets Cache.version to 65535 + # Make sure it just loads the plots normally if it fails to load the cache + plot_manager = PlotManager(harvester.root_path, test_refresh_callback) + plot_manager.cache.load() + assert len(plot_manager.cache) == 0 + plot_manager.set_public_keys( + harvester.plot_manager.farmer_public_keys, harvester.plot_manager.pool_public_keys + ) + expected_result.loaded_plots = 20 + expected_result.removed_plots = 0 + expected_result.processed_files = 20 + expected_result.remaining_files = 0 + plot_manager.start_refreshing() + await time_out_assert(5, plot_manager.needs_refresh, value=False) + assert len(plot_manager.plots) == len(harvester.plot_manager.plots) + plot_manager.stop_refreshing() # Test re-trying if processing a plot failed # First save the plot