Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add encoder: Decimal to Decimal128 #31

Merged
merged 9 commits into from Oct 7, 2020
52 changes: 42 additions & 10 deletions pydantic_odm/encoders/mongodb.py
Expand Up @@ -2,6 +2,8 @@
from __future__ import annotations

import abc
from bson.decimal128 import Decimal128
from decimal import Decimal
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, List, Union, cast

Expand All @@ -18,14 +20,10 @@ def __call__(self, data: "DictStrAny") -> "DictStrAny":
raise NotImplementedError()


def _convert_enums(
data: Union["DictStrAny", List[Any]]
def _recursive_iterator(
data: Union["DictStrAny", List[Any]], transform_func: Callable[[Any], Any]
) -> Union["DictStrAny", List[Any]]:
"""
Convert Enum to Enum.value for mongo query

Note: May be this solution not good
"""
# Cast append func type
append: Callable[[Union[str, int], Any], None]
# Convert in list
Expand All @@ -41,21 +39,55 @@ def _convert_enums(
append = lambda k, v: _data.update({k: v}) # noqa: E731
# Iterate passed data
for key, value in iterator:
# Replace enum object to enum value
if isinstance(value, Enum):
value = value.value
# Convert and replace value with transform function
value = transform_func(value)
# Recursive call if find sequence
if isinstance(value, (list, dict)):
value = _convert_enums(value)
value = _recursive_iterator(value, transform_func)
# Update new data with update method (append for list and update for dict)
append(key, value)
# Return new data
return _data


def _convert_enums(
data: Union["DictStrAny", List[Any]]
) -> Union["DictStrAny", List[Any]]:
"""
Convert Enum to Enum.value for mongo query

Note: May be this solution not good
"""

def enum_to_value(value: Union[Any, Enum]) -> Any:
if isinstance(value, Enum):
value = value.value
return value

return _recursive_iterator(data, enum_to_value)


def _convert_decimals(
data: Union["DictStrAny", List[Any]]
) -> Union["DictStrAny", List[Any]]:
"""
Convert decimal.Decimal to bson.decimal128.Decimal128
"""

def python_decimal_to_bson_decimal(
value: Union[Any, Decimal]
) -> Union[Any, Decimal128]:
if isinstance(value, Decimal):
value = Decimal128(value)
return value

return _recursive_iterator(data, python_decimal_to_bson_decimal)


class BaseMongoDBEncoder(AbstractMongoDBEncoder):
"""Base MongoDB encoder"""

def __call__(self, data: "DictStrAny") -> "DictStrAny":
data = cast("DictStrAny", _convert_enums(data))
data = cast("DictStrAny", _convert_decimals(data))
return data
38 changes: 38 additions & 0 deletions tests/encoders/mongodb.py
@@ -1,5 +1,7 @@
"""Tests for mongodb encoders"""
import pytest
from bson.decimal128 import Decimal128
from decimal import Decimal
dhensen marked this conversation as resolved.
Show resolved Hide resolved
from enum import Enum

from pydantic_odm.encoders import mongodb as mongodb_encoders
Expand Down Expand Up @@ -80,6 +82,42 @@ async def test__convert_enums(self, data, expected):
assert mongodb_encoders._convert_enums(data) == expected


class EncodeDecimalsTestCase:
@pytest.mark.parametrize(
"data, expected",
[
pytest.param(
{"money_amount": Decimal("13.37")},
{"money_amount": Decimal128("13.37")},
id="simple_dict",
),
pytest.param(
[{"money_amount": Decimal("13.37")}],
[{"money_amount": Decimal128("13.37")}],
id="simple_list",
),
pytest.param(
{"author": {"money_amount": Decimal("13.37")}},
{"author": {"money_amount": Decimal128("13.37")}},
id="nested",
),
pytest.param(
{
"title": "test",
"contributors": [{"money_amount": Decimal("13.37")}],
},
{
"title": "test",
"contributors": [{"money_amount": Decimal128("13.37")}],
},
id="list_in_nested",
),
],
)
async def test__convert_decimals(self, data, expected):
assert mongodb_encoders._convert_decimals(data) == expected


class BaseMongoDBEncoderTestCase:
@pytest.mark.parametrize(
"data, expected",
Expand Down