Skip to content

Commit

Permalink
fix: private state properties
Browse files Browse the repository at this point in the history
  • Loading branch information
vgorkavenko committed May 9, 2024
1 parent 8388a26 commit e37b757
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/modules/csm/checkpoint.py
Expand Up @@ -179,7 +179,7 @@ def _process_epoch(
)
if duty_epoch not in self.state.unprocessed_epochs:
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
self.state.processed_epochs.add(duty_epoch)
self.state.add_processed_epoch(duty_epoch)
self.state.commit()
self.state.status()

Expand Down
42 changes: 20 additions & 22 deletions src/modules/csm/state.py
Expand Up @@ -31,9 +31,9 @@ def perf(self) -> float:
@dataclass
class State(UserDict[ValidatorIndex, AttestationsAggregate]):
"""Tracks processing state of CSM performance oracle frame"""
data: dict[ValidatorIndex, AttestationsAggregate] = field(default_factory=dict)
epochs_to_process: set[EpochNumber] = field(default_factory=set)
processed_epochs: set[EpochNumber] = field(default_factory=set)
_data: dict[ValidatorIndex, AttestationsAggregate] = field(default_factory=dict)
_epochs_to_process: set[EpochNumber] = field(default_factory=set)
_processed_epochs: set[EpochNumber] = field(default_factory=set)

EXTENSION = ".pkl"

Expand Down Expand Up @@ -64,9 +64,9 @@ def commit(self) -> None:
os.replace(self.buffer, self.file())

def clear(self) -> None:
self.data = {}
self.epochs_to_process.clear()
self.processed_epochs.clear()
self._data = {}
self._epochs_to_process.clear()
self._processed_epochs.clear()
assert self.is_empty

def inc(self, key: ValidatorIndex, included: bool) -> None:
Expand All @@ -75,49 +75,47 @@ def inc(self, key: ValidatorIndex, included: bool) -> None:
perf.included += 1 if included else 0
self[key] = perf

def add_processed_epoch(self, epoch: EpochNumber) -> None:
self._processed_epochs.add(epoch)

def status(self) -> None:
assigned, included = reduce(
lambda acc, aggr: (acc[0] + aggr.assigned, acc[1] + aggr.included), self.values(), (0, 0)
)

logger.info(
{
"msg": f"Processed {len(self.processed_epochs)} of {len(self.epochs_to_process)} epochs",
"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs",
"assigned": assigned,
"included": included,
"avg_perf": self.avg_perf,
}
)

def validate_for_report(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
for epoch in self.epochs_to_process:
if l_epoch <= epoch <= r_epoch:
continue
if not self.is_fulfilled:
raise InvalidState()

for epoch in self.processed_epochs:
for epoch in self._processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
raise InvalidState()

if not self.is_fulfilled:
raise InvalidState()

for epoch in sequence(l_epoch, r_epoch):
if epoch not in self.processed_epochs:
if epoch not in self._processed_epochs:
raise InvalidState()

def validate_for_collect(self, l_epoch: EpochNumber, r_epoch: EpochNumber):

invalidated = False

for epoch in self.epochs_to_process:
for epoch in self._epochs_to_process:
if l_epoch <= epoch <= r_epoch:
continue
invalidated = True
break

for epoch in self.processed_epochs:
for epoch in self._processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
invalidated = True
Expand All @@ -128,19 +126,19 @@ def validate_for_collect(self, l_epoch: EpochNumber, r_epoch: EpochNumber):
self.clear()
self.commit()

if self.is_empty or r_epoch > max(self.epochs_to_process):
self.epochs_to_process.update(sequence(l_epoch, r_epoch))
if self.is_empty or r_epoch > max(self._epochs_to_process):
self._epochs_to_process.update(sequence(l_epoch, r_epoch))
self.commit()

@property
def is_empty(self) -> bool:
return not self.data and not self.epochs_to_process and not self.processed_epochs
return not self._data and not self._epochs_to_process and not self._processed_epochs

@property
def unprocessed_epochs(self) -> set[EpochNumber]:
if not self.epochs_to_process:
if not self._epochs_to_process:
raise ValueError("Epochs to process are not set")
return self.epochs_to_process - self.processed_epochs
return self._epochs_to_process - self._processed_epochs

@property
def is_fulfilled(self) -> bool:
Expand Down

0 comments on commit e37b757

Please sign in to comment.