Skip to content

Commit

Permalink
refactor: pack MultiIPFSProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed May 10, 2024
1 parent f9d53e3 commit 28b214e
Showing 1 changed file with 12 additions and 26 deletions.
38 changes: 12 additions & 26 deletions src/providers/ipfs/multi.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,26 @@
import logging
from abc import ABC
from functools import wraps
from typing import Generic, Iterable, TypeVar, Protocol
from typing import Iterable

from .cid import CIDv0, CIDv1
from .types import IPFSError, IPFSProvider

logger = logging.getLogger(__name__)


T = TypeVar("T")


class MultiProvider(Generic[T], ABC):
"""Base class for working with multiple providers"""

providers: list[T]
current_provider_index: int = 0
last_working_provider_index: int = 0

@property
def provider(self) -> T:
return self.providers[self.current_provider_index]


class SupportsRetries(Protocol):
@property
def retries(self) -> int:
...


class MaxRetryError(IPFSError):
...


class MultiIPFSProvider(IPFSProvider, MultiProvider[IPFSProvider]):
class MultiIPFSProvider(IPFSProvider):
"""Fallback-driven provider for IPFS"""

# NOTE: The provider is NOT thread-safe.

providers: list[IPFSProvider]
current_provider_index: int = 0
last_working_provider_index: int = 0

def __init__(self, providers: Iterable[IPFSProvider], *, retries: int = 3) -> None:
super().__init__()
self.retries = retries
Expand All @@ -50,7 +32,7 @@ def __init__(self, providers: Iterable[IPFSProvider], *, retries: int = 3) -> No
@staticmethod
def with_fallback(fn):
@wraps(fn)
def wrapped(self: MultiProvider, *args, **kwargs):
def wrapped(self: "MultiIPFSProvider", *args, **kwargs):
try:
result = fn(self, *args, **kwargs)
except IPFSError:
Expand All @@ -68,7 +50,7 @@ def wrapped(self: MultiProvider, *args, **kwargs):
@staticmethod
def retry(fn):
@wraps(fn)
def wrapped(self: SupportsRetries, *args, **kwargs):
def wrapped(self: "MultiIPFSProvider", *args, **kwargs):
retries_left = self.retries
while retries_left:
try:
Expand All @@ -84,6 +66,10 @@ def wrapped(self: SupportsRetries, *args, **kwargs):

return wrapped

@property
def provider(self) -> IPFSProvider:
return self.providers[self.current_provider_index]

@with_fallback
@retry
def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
Expand Down

0 comments on commit 28b214e

Please sign in to comment.