Skip to content

Commit

Permalink
feat: gw3.io provider and MultiIPFSProvider (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed May 10, 2024
1 parent a7e6a78 commit 2f98812
Show file tree
Hide file tree
Showing 14 changed files with 384 additions and 28 deletions.
31 changes: 28 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import cast
from typing import Iterable, cast

from prometheus_client import start_http_server
from web3.middleware import simple_cache_middleware
Expand All @@ -12,7 +12,7 @@
from src.modules.ejector.ejector import Ejector
from src.modules.checks.checks_module import ChecksModule
from src.modules.csm.csm import CSOracle
from src.providers.ipfs import DummyIPFSProvider
from src.providers.ipfs import DummyIPFSProvider, GW3, IPFSProvider, MultiIPFSProvider, Pinata, PublicIPFS
from src.typings import OracleModule
from src.utils.build import get_build_info
from src.web3py.extensions import (
Expand Down Expand Up @@ -79,6 +79,12 @@ def main(module_name: OracleModule):
logger.info({'msg': 'Initialize keys api client.'})
kac = KeysAPIClientModule(variables.KEYS_API_URI, web3)

logger.info({'msg': 'Initialize IPFS providers.'})
ipfs = MultiIPFSProvider(
ipfs_providers(),
retries=variables.HTTP_REQUEST_RETRY_COUNT_IPFS,
)

logger.info({'msg': 'Check configured providers.'})
check_providers_chain_ids(web3, cc, kac)

Expand All @@ -89,7 +95,7 @@ def main(module_name: OracleModule):
'csm': LazyCSM,
'cc': lambda: cc, # type: ignore[dict-item]
'kac': lambda: kac, # type: ignore[dict-item]
'ipfs': DummyIPFSProvider, # TODO: Make a factory.
'ipfs': lambda: ipfs, # type: ignore[dict-item]
})

logger.info({'msg': 'Add metrics middleware for ETH1 requests.'})
Expand Down Expand Up @@ -139,6 +145,25 @@ def check_providers_chain_ids(web3: Web3, cc: ConsensusClientModule, kac: KeysAP
f'Keys API chain id: {keys_api_chain_id}\n')


def ipfs_providers() -> Iterable[IPFSProvider]:
    """Lazily yield the IPFS providers enabled by the current configuration.

    Providers are yielded in a fixed order: GW3 (only when both its access
    key and secret key are set), Pinata (only when a JWT is set), then the
    public IPFS provider and finally the dummy provider, which are both
    yielded unconditionally. Every network-backed provider is constructed
    with the same configured IPFS request timeout.
    """
    request_timeout = variables.HTTP_REQUEST_TIMEOUT_IPFS

    gw3_access_key = variables.GW3_ACCESS_KEY
    gw3_secret_key = variables.GW3_SECRET_KEY
    if gw3_access_key and gw3_secret_key:
        yield GW3(gw3_access_key, gw3_secret_key, timeout=request_timeout)

    pinata_jwt = variables.PINATA_JWT
    if pinata_jwt:
        yield Pinata(pinata_jwt, timeout=request_timeout)

    yield PublicIPFS(timeout=request_timeout)

    yield DummyIPFSProvider()  # FIXME: Remove after migration.


if __name__ == '__main__':
module_name_arg = sys.argv[-1]
if module_name_arg not in iter(OracleModule):
Expand Down
4 changes: 1 addition & 3 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
if cid:
logger.info({"msg": "Fetching tree by CID from IPFS", "cid": cid})
tree = Tree.decode(self.w3.ipfs.fetch(cid))

logger.info({"msg": "Restored tree from IPFS dump", "root": repr(root)})

if tree.root != root:
Expand All @@ -161,8 +160,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
if distributed:
tree = Tree.new(tuple((no_id, amount) for (no_id, amount) in shares.items()))
logger.info({"msg": "New tree built for the report", "root": repr(tree.root)})
cid = self.w3.ipfs.upload(tree.encode())
self.w3.ipfs.pin(cid)
cid = self.w3.ipfs.publish(tree.encode())
root = tree.root

if root == ZERO_HASH:
Expand Down
5 changes: 3 additions & 2 deletions src/modules/csm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from hexbytes import HexBytes

from src.providers.ipfs import CID
from src.typings import SlotNumber

logger = logging.getLogger(__name__)
Expand All @@ -13,7 +14,7 @@ class ReportData:
consensusVersion: int
ref_slot: SlotNumber
tree_root: HexBytes
tree_cid: str
tree_cid: CID
distributed: int

def as_tuple(self):
Expand All @@ -22,6 +23,6 @@ def as_tuple(self):
self.consensusVersion,
self.ref_slot,
self.tree_root,
self.tree_cid,
str(self.tree_cid),
self.distributed,
)
4 changes: 4 additions & 0 deletions src/modules/submodules/oracle_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from src.metrics.prometheus.basic import ORACLE_BLOCK_NUMBER, ORACLE_SLOT_NUMBER
from src.modules.submodules.exceptions import IsNotMemberException, IncompatibleContractVersion
from src.providers.http_provider import NotOkResponse
from src.providers.ipfs import IPFSError
from src.providers.keys.client import KeysOutdatedException
from src.utils.cache import clear_global_cache
from src.web3py.extensions.lido_validators import CountOfKeysDiffersException
Expand Down Expand Up @@ -86,6 +87,7 @@ def _receive_last_finalized_slot(self) -> BlockStamp:
return bs

def run_cycle(self, blockstamp: BlockStamp) -> ModuleExecuteDelay:
# pylint: disable=too-many-branches
logger.info({'msg': 'Execute module.', 'value': blockstamp})

try:
Expand All @@ -112,6 +114,8 @@ def run_cycle(self, blockstamp: BlockStamp) -> ModuleExecuteDelay:
logger.error({'msg': 'Keys API service returned incorrect number of keys.', 'error': str(error)})
except Web3Exception as error:
logger.error({'msg': 'Web3py exception.', 'error': str(error)})
except IPFSError as error:
logger.error({'msg': 'IPFS provider error.', 'error': str(error)})
except ValueError as error:
logger.error({'msg': 'Unexpected error.', 'error': str(error)})
else:
Expand Down
6 changes: 5 additions & 1 deletion src/providers/ipfs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .cid import *
from .dummy import *
from .gw3 import *
from .multi import *
from .pinata import *
from .public import *
from .types import *
from .utils import *
19 changes: 19 additions & 0 deletions src/providers/ipfs/cid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections import UserString


class CID(UserString):
    """String-like wrapper for an IPFS content identifier.

    Behaves like the wrapped string (via ``UserString``) but renders in
    ``repr()`` as ``<ClassName>(<value>)`` so subclasses are distinguishable
    in logs.
    """

    def __repr__(self):
        kind = type(self).__name__
        return "{}({})".format(kind, self.data)


class CIDv0(CID):
    """Tag type for a version 0 CID; adds no behaviour on top of ``CID``."""


class CIDv1(CID):
    """Tag type for a version 1 CID; adds no behaviour on top of ``CID``."""


# @see https://github.com/multiformats/cid/blob/master/README.md#decoding-algorithm
def is_cid_v0(cid: str) -> bool:
    """Return True when *cid* looks like a version 0 CID.

    The check is purely textual: the value must start with ``"Qm"`` and be
    exactly 46 characters long. No decoding or validation is performed.
    """
    if not cid.startswith("Qm"):
        return False
    return len(cid) == 46
20 changes: 10 additions & 10 deletions src/providers/ipfs/dummy.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
import hashlib

from web3 import Web3
from web3.module import Module
from .types import CIDv0, CIDv1, IPFSError, IPFSProvider, FetchError

from .types import CIDv0, CIDv1, IPFSProvider, NotFound


class DummyIPFSProvider(IPFSProvider, Module):
class DummyIPFSProvider(IPFSProvider):
"""Dummy IPFS provider which using the local filesystem as a backend"""

# pylint: disable=unreachable

mempool: dict[CIDv0 | CIDv1, bytes]

def __init__(self, w3: Web3) -> None:
super().__init__(w3)
def __init__(self) -> None:
self.mempool = {}

def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
try:
return self.mempool[cid]
except KeyError:
try:
with open(cid, mode="r") as f:
with open(str(cid), mode="r") as f:
return f.read().encode("utf-8")
except Exception as e:
raise NotFound() from e
raise FetchError(cid) from e


def upload(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
raise IPFSError # FIXME: Remove after migration
cid = CIDv0("Qm" + hashlib.sha256(content).hexdigest()) # XXX: Dummy.
self.mempool[cid] = content
return cid

def pin(self, cid: CIDv0 | CIDv1) -> None:
raise IPFSError # FIXME: Remove after migration
content = self.fetch(cid)

with open(cid, mode="w", encoding="utf-8") as f:
with open(str(cid), mode="w", encoding="utf-8") as f:
f.write(content.decode())
85 changes: 85 additions & 0 deletions src/providers/ipfs/gw3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import base64
import hashlib
import hmac
import logging
import time
from urllib.parse import urlencode, urlparse

import requests

from src.providers.ipfs.cid import CIDv0, CIDv1, is_cid_v0

from .types import IPFSError, IPFSProvider, FetchError, PinError, UploadError

logger = logging.getLogger(__name__)


class GW3(IPFSProvider):
    """gw3.io client.

    Every request to the gateway is signed with an HMAC-SHA256 signature
    built from the HTTP method, the URL path and the query string (see
    ``_signed_req``). All transport and response-shape failures are surfaced
    as ``IPFSError`` subclasses (``FetchError``, ``UploadError``,
    ``PinError``).
    """

    ENDPOINT = "https://gw3.io"

    def __init__(self, access_key: str, access_secret: str, *, timeout: int) -> None:
        """Store the credentials and the per-request timeout.

        Args:
            access_key: Value sent in the ``X-Access-Key`` header.
            access_secret: URL-safe base64 encoded secret; it is decoded
                here and the raw bytes are used as the HMAC key.
            timeout: Timeout passed to every HTTP request.
        """
        super().__init__()
        self.access_key = access_key
        self.access_secret = base64.urlsafe_b64decode(access_secret)
        self.timeout = timeout

    def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
        """Return the raw content stored under *cid*.

        Raises:
            FetchError: If the signed GET request fails.
        """
        try:
            resp = self._send("GET", f"{self.ENDPOINT}/ipfs/{cid}")
        except IPFSError as ex:
            raise FetchError(cid) from ex
        return resp.content

    def upload(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
        """Upload *content* and return the CID reported by the gateway.

        The upload is a two-step flow: first an upload URL is requested for
        the given content size, then the content is POSTed to that URL and
        the CID is read from the ``IPFS-Hash`` response header.

        Args:
            content: Bytes to upload.
            name: Not used by this provider.

        Raises:
            UploadError: If either step fails or the response carries no
                ``IPFS-Hash`` header.
        """
        url = self._auth_upload(len(content))
        try:
            response = requests.post(url, data=content, timeout=self.timeout)
            # Fail on 4xx/5xx explicitly, the same way `_send` does.
            response.raise_for_status()
            cid = response.headers["IPFS-Hash"]
        except requests.RequestException as ex:
            # `requests.post` raises RequestException, never IPFSError, so it
            # has to be translated here to stay inside the IPFSError hierarchy.
            logger.error({"msg": "Request has been failed", "error": str(ex)})
            raise UploadError from ex
        except KeyError as ex:
            raise UploadError from ex

        return CIDv0(cid) if is_cid_v0(cid) else CIDv1(cid)

    def pin(self, cid: CIDv0 | CIDv1) -> None:
        """Ask the gateway to pin *cid*.

        Raises:
            PinError: If the signed POST request fails.
        """
        try:
            self._send("POST", f"{self.ENDPOINT}/api/v0/pin/add", {"arg": str(cid)})
        except IPFSError as ex:
            raise PinError(cid) from ex

    def _auth_upload(self, size: int) -> str:
        """Request an upload URL for a payload of *size* bytes.

        Raises:
            UploadError: If the request fails or the response body is not
                JSON of the shape ``{"data": {"url": ...}}``.
        """
        try:
            response = self._send("POST", f"{self.ENDPOINT}/ipfs/", {"size": size})
            return response.json()["data"]["url"]
        except IPFSError as ex:
            raise UploadError from ex
        except (KeyError, TypeError, ValueError) as ex:
            # ValueError covers a non-JSON body; TypeError covers a JSON body
            # whose "data" is not a mapping.
            raise UploadError from ex

    def _send(self, method: str, url: str, params: dict | None = None) -> requests.Response:
        """Send a signed request and return the successful response.

        Raises:
            IPFSError: On any ``requests`` failure, including a 4xx/5xx status.
        """
        req = self._signed_req(method, url, params)
        try:
            # Use the session as a context manager so its connections are
            # released instead of being left to garbage collection.
            with requests.Session() as session:
                response = session.send(req, timeout=self.timeout)
            response.raise_for_status()
        except requests.RequestException as ex:
            logger.error({"msg": "Request has been failed", "error": str(ex)})
            raise IPFSError from ex
        return response

    def _signed_req(self, method: str, url: str, params: dict | None = None) -> requests.PreparedRequest:
        """Build a prepared request carrying the access key and signature.

        A ``ts`` query parameter with the current Unix time is appended, and
        the string ``"<method>\\n<path>\\n<query>"`` is signed with
        HMAC-SHA256 using the decoded secret. The signature is sent URL-safe
        base64 encoded in ``X-Access-Signature``.
        """
        # Copy instead of mutating the caller's dict; "ts" stays the last key
        # so the signed query string keeps the same parameter order.
        params = {**(params or {}), "ts": str(int(time.time()))}
        query = urlencode(params, doseq=True)

        parsed_url = urlparse(url)
        data = "\n".join((method, parsed_url.path, query)).encode("utf-8")
        mac = hmac.new(self.access_secret, data, hashlib.sha256)
        sign = base64.urlsafe_b64encode(mac.digest())

        req = requests.Request(method=method, url=url, params=params)
        req.headers["X-Access-Key"] = self.access_key
        req.headers["X-Access-Signature"] = sign.decode("utf-8")
        return req.prepare()
93 changes: 93 additions & 0 deletions src/providers/ipfs/multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging
from functools import wraps
from typing import Iterable

from .cid import CIDv0, CIDv1
from .types import IPFSError, IPFSProvider

logger = logging.getLogger(__name__)


class MaxRetryError(IPFSError):
    """Raised when a retried provider call has used up all of its attempts."""


class MultiIPFSProvider(IPFSProvider):
    """Fallback-driven provider for IPFS.

    Wraps an ordered list of providers. Each public operation is first
    retried against the current provider (``retry``) and, when all attempts
    fail, handed over to the next provider in the list (``with_fallback``)
    until every provider has been tried once.
    """

    # NOTE: The provider is NOT thread-safe.

    providers: list[IPFSProvider]
    # Index of the provider used for the next call.
    current_provider_index: int = 0
    # Index of the provider that served the most recent successful call.
    last_working_provider_index: int = 0

    def __init__(self, providers: Iterable[IPFSProvider], *, retries: int = 3) -> None:
        """Materialise *providers* and remember the per-provider attempt count.

        Args:
            providers: Providers in order of preference; must be non-empty
                and contain only ``IPFSProvider`` instances.
            retries: Maximum number of attempts made against one provider
                before falling back to the next one.
        """
        super().__init__()
        self.retries = retries
        self.providers = list(providers)
        assert self.providers
        for p in self.providers:
            assert isinstance(p, IPFSProvider)

    @staticmethod
    def with_fallback(fn):
        """Decorate a method so an ``IPFSError`` moves on to the next provider.

        On failure the current provider index is advanced (wrapping around)
        and the call is repeated. Once the index arrives back at the last
        working provider, every provider has been tried and the error is
        re-raised.
        """

        @wraps(fn)
        def wrapped(self: "MultiIPFSProvider", *args, **kwargs):
            try:
                result = fn(self, *args, **kwargs)
            except IPFSError:
                self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
                if self.last_working_provider_index == self.current_provider_index:
                    logger.error({"msg": "No more IPFS providers left to call"})
                    raise
                return wrapped(self, *args, **kwargs)

            self.last_working_provider_index = self.current_provider_index
            return result

        return wrapped

    @staticmethod
    def retry(fn):
        """Decorate a method so it is attempted up to ``self.retries`` times.

        A warning is logged after each failed attempt that will be followed
        by another one; the final failure is raised as ``MaxRetryError``
        chained to the last ``IPFSError``. A non-positive ``self.retries``
        raises ``MaxRetryError`` without calling the method.
        """

        @wraps(fn)
        def wrapped(self: "MultiIPFSProvider", *args, **kwargs):
            # Bounded countdown: a negative `retries` yields an empty range
            # instead of a loop that can never reach zero.
            for retries_left in range(self.retries - 1, -1, -1):
                try:
                    return fn(self, *args, **kwargs)
                except IPFSError as ex:
                    if not retries_left:
                        raise MaxRetryError from ex
                    logger.warning(
                        {"msg": f"Retrying a failed call of {fn.__name__}, {retries_left=}", "error": str(ex)}
                    )
            raise MaxRetryError

        return wrapped

    @property
    def provider(self) -> IPFSProvider:
        """The provider selected by ``current_provider_index``."""
        return self.providers[self.current_provider_index]

    @with_fallback
    @retry
    def fetch(self, cid: CIDv0 | CIDv1) -> bytes:
        """Fetch the content for *cid* via the current provider."""
        return self.provider.fetch(cid)

    @with_fallback
    @retry
    def publish(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
        """Publish *content* via the current provider and return its CID."""
        # If the current provider fails to upload or pin a file, it makes sense
        # to try to both upload and to pin via a different provider.
        return self.provider.publish(content, name)

    def upload(self, content: bytes, name: str | None = None) -> CIDv0 | CIDv1:
        """Not supported on the multi-provider; use ``publish`` instead.

        Raises:
            NotImplementedError: Always.
        """
        # It doesn't make sense to upload a file to a different provider's
        # network without a guarantee that the file will be available via
        # another one.
        raise NotImplementedError

    @with_fallback
    @retry
    def pin(self, cid: CIDv0 | CIDv1) -> None:
        """Pin *cid* via the current provider."""
        return self.provider.pin(cid)

0 comments on commit 2f98812

Please sign in to comment.