Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix typing of from_dict and other classmethods #408

Merged
merged 6 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [[Unreleased]]
### Added:
- Add function to parse the final account balances from a transaction's metadata
- Function to parse the final account balances from a transaction's metadata

### Fixed:
- Typing for factory classmethods on models

## [1.6.0] - 2022-06-02
### Added:
- Support for dynamic fee calculation
- Add function to parse account balances from a transaction's metadata
- Function to parse account balances from a transaction's metadata
- Better error handling for invalid client URL
- Exported SubscribeBook

Expand Down
2 changes: 1 addition & 1 deletion xrpl/asyncio/transaction/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def safe_sign_transaction(
serialized_bytes = bytes.fromhex(serialized_for_signing)
signature = sign(serialized_bytes, wallet.private_key)
transaction_json["TxnSignature"] = signature
return cast(Transaction, Transaction.from_xrpl(transaction_json))
return Transaction.from_xrpl(transaction_json)


async def safe_sign_and_autofill_transaction(
Expand Down
16 changes: 8 additions & 8 deletions xrpl/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABC
from dataclasses import fields
from enum import Enum
from typing import Any, Dict, List, Pattern, Type, Union, cast, get_type_hints
from typing import Any, Dict, List, Pattern, Type, TypeVar, Union, cast, get_type_hints

from typing_extensions import Final, get_args, get_origin

Expand Down Expand Up @@ -35,6 +35,8 @@
# used for converting special substrings inside CamelCase fields
SPECIAL_CAMELCASE_STRINGS = ["NFToken"]

BM = TypeVar("BM", bound="BaseModel") # any type inherited from BaseModel


def _key_to_json(field: str) -> str:
"""
Expand Down Expand Up @@ -65,7 +67,7 @@ class BaseModel(ABC):
"""The base class for all model types."""

@classmethod
def is_dict_of_model(cls: Type[BaseModel], dictionary: Any) -> bool:
def is_dict_of_model(cls: Type[BM], dictionary: Any) -> bool:
"""
Checks whether the provided ``dictionary`` is a dictionary representation
of this class.
Expand Down Expand Up @@ -94,7 +96,7 @@ def is_dict_of_model(cls: Type[BaseModel], dictionary: Any) -> bool:
)

@classmethod
def from_dict(cls: Type[BaseModel], value: Dict[str, XRPL_VALUE_TYPE]) -> BaseModel:
def from_dict(cls: Type[BM], value: Dict[str, XRPL_VALUE_TYPE]) -> BM:
"""
Construct a new BaseModel from a dictionary of parameters.

Expand Down Expand Up @@ -126,7 +128,7 @@ def from_dict(cls: Type[BaseModel], value: Dict[str, XRPL_VALUE_TYPE]) -> BaseMo

@classmethod
def _from_dict_single_param(
cls: Type[BaseModel],
cls: Type[BM],
param: str,
param_type: Type[Any],
param_value: Union[int, str, bool, BaseModel, Enum, List[Any], Dict[str, Any]],
Expand Down Expand Up @@ -195,15 +197,13 @@ def _from_dict_single_param(
raise XRPLModelException(error_message)

@classmethod
def _get_only_init_args(
cls: Type[BaseModel], args: Dict[str, Any]
) -> Dict[str, Any]:
def _get_only_init_args(cls: Type[BM], args: Dict[str, Any]) -> Dict[str, Any]:
init_keys = {field.name for field in fields(cls) if field.init is True}
valid_args = {key: value for key, value in args.items() if key in init_keys}
return valid_args

@classmethod
def from_xrpl(cls: Type[BaseModel], value: Union[str, Dict[str, Any]]) -> BaseModel:
def from_xrpl(cls: Type[BM], value: Union[str, Dict[str, Any]]) -> BM:
"""
Creates a BaseModel object based on a JSON-like dictionary of keys in the JSON
format used by the binary codec, or an actual JSON string representing the same
Expand Down
11 changes: 7 additions & 4 deletions xrpl/models/requests/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Type, Union, cast
from typing import Any, Dict, Optional, Type, TypeVar, Union, cast

import xrpl.models.requests # bare import to get around circular dependency
from xrpl.models.base_model import BaseModel
Expand Down Expand Up @@ -81,6 +81,9 @@ class RequestMethod(str, Enum):
GENERIC_REQUEST = "zzgeneric_request"


R = TypeVar("R", bound="Request")


@require_kwargs_on_init
@dataclass(frozen=True)
class Request(BaseModel):
Expand All @@ -99,7 +102,7 @@ class Request(BaseModel):
id: Optional[Union[str, int]] = None

@classmethod
def from_dict(cls: Type[Request], value: Dict[str, Any]) -> Request:
def from_dict(cls: Type[R], value: Dict[str, Any]) -> R:
"""
Construct a new Request from a dictionary of parameters.

Expand All @@ -116,7 +119,7 @@ def from_dict(cls: Type[Request], value: Dict[str, Any]) -> Request:
if "method" not in value:
raise XRPLModelException("Request does not include method.")
correct_type = cls.get_method(value["method"])
return correct_type.from_dict(value)
return correct_type.from_dict(value) # type: ignore

if "method" in value:
method = value["method"]
Expand All @@ -135,7 +138,7 @@ def from_dict(cls: Type[Request], value: Dict[str, Any]) -> Request:
value = {**value}
del value["method"]

return cast(Request, super(Request, cls).from_dict(value))
return super(Request, cls).from_dict(value)

@classmethod
def get_method(cls: Type[Request], method: str) -> Type[Request]:
Expand Down
4 changes: 2 additions & 2 deletions xrpl/models/requests/sign.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Type, cast
from typing import Any, Dict, Optional, Type

from xrpl.constants import CryptoAlgorithm
from xrpl.models.requests.request import Request, RequestMethod
Expand Down Expand Up @@ -81,7 +81,7 @@ def from_dict(cls: Type[Sign], value: Dict[str, Any]) -> Sign:
del fixed_value["tx_json"]
else:
fixed_value = value
return cast(Sign, super(Sign, cls).from_dict(fixed_value))
return super(Sign, cls).from_dict(fixed_value)

def to_dict(self: Sign) -> Dict[str, Any]:
"""
Expand Down
4 changes: 2 additions & 2 deletions xrpl/models/requests/sign_and_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Type, cast
from typing import Any, Dict, Optional, Type

from xrpl.constants import CryptoAlgorithm
from xrpl.models.requests.submit import Submit
Expand Down Expand Up @@ -90,7 +90,7 @@ def from_dict(cls: Type[SignAndSubmit], value: Dict[str, Any]) -> SignAndSubmit:
del fixed_value["tx_json"]
else:
fixed_value = value
return cast(SignAndSubmit, super(SignAndSubmit, cls).from_dict(fixed_value))
return super(SignAndSubmit, cls).from_dict(fixed_value)

def to_dict(self: SignAndSubmit) -> Dict[str, Any]:
"""
Expand Down
4 changes: 2 additions & 2 deletions xrpl/models/requests/sign_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Type, cast
from typing import Any, Dict, Optional, Type

from xrpl.constants import CryptoAlgorithm
from xrpl.models.requests.request import Request, RequestMethod
Expand Down Expand Up @@ -72,7 +72,7 @@ def from_dict(cls: Type[SignFor], value: Dict[str, Any]) -> SignFor:
del fixed_value["tx_json"]
else:
fixed_value = value
return cast(SignFor, super(SignFor, cls).from_dict(fixed_value))
return super(SignFor, cls).from_dict(fixed_value)

def to_dict(self: SignFor) -> Dict[str, Any]:
"""
Expand Down
12 changes: 7 additions & 5 deletions xrpl/models/requests/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Type, cast
from typing import Any, Dict, Type, TypeVar

from xrpl.models.requests.request import Request, RequestMethod
from xrpl.models.utils import require_kwargs_on_init

S = TypeVar("S", bound="Submit") # any type inherited from Submit


@require_kwargs_on_init
@dataclass(frozen=True)
Expand Down Expand Up @@ -60,7 +62,7 @@ class Submit(Request):
method: RequestMethod = field(default=RequestMethod.SUBMIT, init=False)

@classmethod
def from_dict(cls: Type[Submit], value: Dict[str, Any]) -> Submit:
def from_dict(cls: Type[S], value: Dict[str, Any]) -> S:
"""
Construct a new Submit from a dictionary of parameters.

Expand All @@ -77,6 +79,6 @@ def from_dict(cls: Type[Submit], value: Dict[str, Any]) -> Submit:

if cls.__name__ == "Submit":
if "tx_blob" in value:
return SubmitOnly.from_dict(value)
return SignAndSubmit.from_dict(value)
return cast(Submit, super(Submit, cls).from_dict(value))
return SubmitOnly.from_dict(value) # type: ignore
return SignAndSubmit.from_dict(value) # type: ignore
return super(Submit, cls).from_dict(value)
6 changes: 2 additions & 4 deletions xrpl/models/requests/submit_multisigned.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Type, cast
from typing import Any, Dict, Type

from xrpl.models.requests.request import Request, RequestMethod
from xrpl.models.required import REQUIRED
Expand Down Expand Up @@ -57,9 +57,7 @@ def from_dict(
fixed_value = {**value}
if "TransactionType" in fixed_value["tx_json"]: # xrpl format
fixed_value["tx_json"] = Transaction.from_xrpl(fixed_value["tx_json"])
return cast(
SubmitMultisigned, super(SubmitMultisigned, cls).from_dict(fixed_value)
)
return super(SubmitMultisigned, cls).from_dict(fixed_value)

def to_dict(self: SubmitMultisigned) -> Dict[str, Any]:
"""
Expand Down
8 changes: 3 additions & 5 deletions xrpl/models/transactions/signer_list_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Dict, List, Optional, Type

from xrpl.models.base_model import BaseModel
from xrpl.models.required import REQUIRED
Expand Down Expand Up @@ -66,10 +66,8 @@ def from_dict(cls: Type[SignerEntry], value: Dict[str, Any]) -> SignerEntry:
A new SignerEntry object, constructed using the given parameters.
"""
if len(value) == 1 and "signer_entry" in value:
return cast(
SignerEntry, super(SignerEntry, cls).from_dict(value["signer_entry"])
)
return cast(SignerEntry, super(SignerEntry, cls).from_dict(value))
return super(SignerEntry, cls).from_dict(value["signer_entry"])
return super(SignerEntry, cls).from_dict(value)

def to_dict(self: SignerEntry) -> Dict[str, Any]:
"""
Expand Down
19 changes: 11 additions & 8 deletions xrpl/models/transactions/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dataclasses import dataclass
from hashlib import sha512
from typing import Any, Dict, List, Optional, Type, Union, cast
from typing import Any, Dict, List, Optional, Type, TypeVar, Union

from typing_extensions import Final

Expand Down Expand Up @@ -136,8 +136,8 @@ def from_dict(cls: Type[Memo], value: Dict[str, Any]) -> Memo:
XRPLModelException: If the dictionary provided is invalid.
"""
if "memo" not in value:
return cast(Memo, super(Memo, cls).from_dict(value))
return cast(Memo, super(Memo, cls).from_dict(value["memo"]))
return super(Memo, cls).from_dict(value)
return super(Memo, cls).from_dict(value["memo"])

def to_dict(self: Memo) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -222,8 +222,8 @@ def from_dict(cls: Type[Signer], value: Dict[str, Any]) -> Signer:
XRPLModelException: If the dictionary provided is invalid.
"""
if "signer" not in value:
return cast(Signer, super(Signer, cls).from_dict(value))
return cast(Signer, super(Signer, cls).from_dict(value["signer"]))
return super(Signer, cls).from_dict(value)
return super(Signer, cls).from_dict(value["signer"])

def to_dict(self: Signer) -> Dict[str, Any]:
"""
Expand All @@ -235,6 +235,9 @@ def to_dict(self: Signer) -> Dict[str, Any]:
return {"signer": super().to_dict()}


T = TypeVar("T", bound="Transaction") # any type inherited from Transaction


@require_kwargs_on_init
@dataclass(frozen=True)
class Transaction(BaseModel):
Expand Down Expand Up @@ -373,7 +376,7 @@ def to_xrpl(self: Transaction) -> Dict[str, Any]:
return transaction_json_to_binary_codec_form(self.to_dict())

@classmethod
def from_dict(cls: Type[Transaction], value: Dict[str, Any]) -> Transaction:
def from_dict(cls: Type[T], value: Dict[str, Any]) -> T:
"""
Construct a new Transaction from a dictionary of parameters.

Expand All @@ -393,7 +396,7 @@ def from_dict(cls: Type[Transaction], value: Dict[str, Any]) -> Transaction:
"Transaction does not include transaction_type."
)
correct_type = cls.get_transaction_type(value["transaction_type"])
return correct_type.from_dict(value)
return correct_type.from_dict(value) # type: ignore
else:
if "transaction_type" in value:
if value["transaction_type"] != cls.__name__:
Expand All @@ -404,7 +407,7 @@ def from_dict(cls: Type[Transaction], value: Dict[str, Any]) -> Transaction:
)
value = {**value}
del value["transaction_type"]
return cast(Transaction, super(Transaction, cls).from_dict(value))
return super(Transaction, cls).from_dict(value)

def has_flag(self: Transaction, flag: int) -> bool:
"""
Expand Down
4 changes: 1 addition & 3 deletions xrpl/utils/sidechain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Sidechain-related helper util functions."""

from typing import cast

from xrpl.constants import XRPLException
from xrpl.models import Memo, Payment
from xrpl.utils.str_conversions import str_to_hex
Expand Down Expand Up @@ -42,4 +40,4 @@ def create_cross_chain_payment(payment: Payment, dest_account: str) -> Payment:
payment_dict["memos"] = new_memos
if "txn_signature" in payment_dict:
del payment_dict["txn_signature"]
return cast(Payment, Payment.from_dict(payment_dict))
return Payment.from_dict(payment_dict)