Skip to content

Commit

Permalink
refactor: MultiIPFSProvider fallbacks via decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed May 7, 2024
1 parent a93045b commit 09d4f35
Showing 1 changed file with 52 additions and 33 deletions.
85 changes: 52 additions & 33 deletions src/providers/ipfs/multi.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,65 @@
from typing import Any, Iterable, Sequence
import logging
from abc import ABC
from functools import wraps
from typing import Generic, Iterable, TypeVar

from .cid import CIDv0, CIDv1
from .types import IPFSError, IPFSProvider


logger = logging.getLogger(__name__)


class MultiIPFSProvider:
"""Fallback-driven provider for IPFS"""
T = TypeVar("T")


class MultiProvider(Generic[T], ABC):
"""Base class for working with multiple providers"""

providers: list[T]
current_provider_index: int = 0
last_working_provider_index: int = 0

@property
def provider(self) -> T:
return self.providers[self.current_provider_index]


providers: Sequence[IPFSProvider]
def with_fallback(fn):
@wraps(fn)
def wrapped(self: MultiProvider, *args, **kwargs):
try:
result = fn(self, *args, **kwargs)
except IPFSError:
self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
if self.last_working_provider_index == self.current_provider_index:
logger.error({"msg": "No more IPFS providers left to call"})
raise
return wrapped(self, *args, **kwargs)

_current_provider_index: int = 0
_last_working_provider_index: int = 0
self.last_working_provider_index = self.current_provider_index
return result

return wrapped


class MultiIPFSProvider(IPFSProvider, MultiProvider[IPFSProvider]):
"""Fallback-driven provider for IPFS"""

def __init__(self, providers: Iterable[IPFSProvider]) -> None:
super().__init__()
self.providers = []
for p in providers:
self.providers = list(providers)
assert self.providers
for p in self.providers:
assert isinstance(p, IPFSProvider)
self.providers.append(p)

def __getattribute__(self, name: str, /) -> Any:
if name in ("fetch", "upload", "pin"):
return self._retry_call(name)
return super().__getattribute__(name)

def _retry_call(self, name: str):
def wrapper(*args, **kwargs):
try:
provider = self.providers[self._current_provider_index]
fn = getattr(provider, name)
result = fn(*args, **kwargs)
except IPFSError:
self._current_provider_index = (self._current_provider_index + 1) % len(self.providers)
if self._last_working_provider_index == self._current_provider_index:
logger.error({"msg": "No more IPFS providers left to call"})
raise
return wrapper(*args, **kwargs)

self._last_working_provider_index = self._current_provider_index
return result

return wrapper

@with_fallback
def upload(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
return self.provider.upload(content, name)

@with_fallback
def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
return self.provider.fetch(cid)

@with_fallback
def pin(self, cid: CIDv0 | CIDv1) -> None:
self.provider.pin(cid)

0 comments on commit 09d4f35

Please sign in to comment.