Skip to content

Commit

Permalink
feat: retries for the MultiIPFSProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed May 9, 2024
1 parent b75513b commit 3cf55c3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
36 changes: 34 additions & 2 deletions src/providers/ipfs/multi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from abc import ABC
from functools import wraps
from typing import Generic, Iterable, TypeVar
from typing import Generic, Iterable, TypeVar, Protocol

from .cid import CIDv0, CIDv1
from .types import IPFSError, IPFSProvider
Expand All @@ -24,6 +24,32 @@ def provider(self) -> T:
return self.providers[self.current_provider_index]


class SupportsRetries(Protocol):

@property
def retries(self) -> int: ...


class MaxRetryError(IPFSError):
...


def retry(fn):
@wraps(fn)
def wrapped(self: SupportsRetries, *args, **kwargs):
retries_left = self.retries
while retries_left:
try:
return fn(self, *args, **kwargs)
except IPFSError as ex:
retries_left -= 1
if not retries_left:
raise MaxRetryError from ex
logger.warning({"msg": f"Retrying a failed call of {fn.__name__}, {retries_left=}", "error": str(ex)})

return wrapped


def with_fallback(fn):
@wraps(fn)
def wrapped(self: MultiProvider, *args, **kwargs):
Expand All @@ -45,18 +71,23 @@ def wrapped(self: MultiProvider, *args, **kwargs):
class MultiIPFSProvider(IPFSProvider, MultiProvider[IPFSProvider]):
"""Fallback-driven provider for IPFS"""

def __init__(self, providers: Iterable[IPFSProvider]) -> None:
# NOTE: The provider is NOT thread-safe.

def __init__(self, providers: Iterable[IPFSProvider], * ,retries: int = 3) -> None:
super().__init__()
self.retries = retries
self.providers = list(providers)
assert self.providers
for p in self.providers:
assert isinstance(p, IPFSProvider)

@with_fallback
@retry
def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
return self.provider.fetch(cid)

@with_fallback
@retry
def publish(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
# If the current provider fails to upload or pin a file, it makes sense
# to try to both upload and to pin via a different provider.
Expand All @@ -68,5 +99,6 @@ def upload(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
raise NotImplementedError

@with_fallback
@retry
def pin(self, cid: CIDv0 | CIDv1) -> None:
return self.provider.pin(cid)
7 changes: 7 additions & 0 deletions src/providers/ipfs/pinata.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging
import requests

from .cid import CIDv0, CIDv1, is_cid_v0
from .types import FetchError, IPFSProvider, PinError, UploadError


logger = logging.getLogger(__name__)


class Pinata(IPFSProvider):
"""pinata.cloud IPFS provider"""

Expand All @@ -22,6 +26,7 @@ def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
resp = requests.get(url, timeout=self.timeout)
resp.raise_for_status()
except requests.RequestException as ex:
logger.error({"msg": "Request has been failed", "error": str(ex)})
raise FetchError(cid) from ex
return resp.content

Expand All @@ -32,6 +37,7 @@ def upload(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
resp = s.post(url, files={"file": content})
resp.raise_for_status()
except requests.RequestException as ex:
logger.error({"msg": "Request has been failed", "error": str(ex)})
raise UploadError from ex
cid = resp.json()["IpfsHash"]
return CIDv0(cid) if is_cid_v0(cid) else CIDv1(cid)
Expand All @@ -43,4 +49,5 @@ def pin(self, cid: CIDv0 | CIDv1) -> None:
resp = s.post(url, json={"hashToPin": str(cid)})
resp.raise_for_status()
except requests.RequestException as ex:
logger.error({"msg": "Request has been failed", "error": str(ex)})
raise PinError(cid) from ex
1 change: 0 additions & 1 deletion src/providers/ipfs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class IPFSProvider(ABC):
def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
...

@abstractmethod
def publish(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
cid = self.upload(content, name)
self.pin(cid)
Expand Down

0 comments on commit 3cf55c3

Please sign in to comment.