Skip to content

Commit

Permalink
tagged unions: leave tag key unless forbid_extra_keys
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinche committed Apr 18, 2024
1 parent 898e59c commit a13fa2e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 21 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ can now be used as decorators and have gained new features.
([#463](https://github.com/python-attrs/cattrs/pull/463))
- `cattrs.gen` generators now attach metadata to the generated functions, making them introspectable.
([#472](https://github.com/python-attrs/cattrs/pull/472))
- The [tagged union strategy](https://catt.rs/en/stable/strategies.html#tagged-unions-strategy) now leaves the tags in the payload unless `forbid_extra_keys` is set.
([#533](https://github.com/python-attrs/cattrs/issues/533) [#534](https://github.com/python-attrs/cattrs/pull/534))
- More robust support for `Annotated` and `NotRequired` in TypedDicts.
([#450](https://github.com/python-attrs/cattrs/pull/450))
- `typing_extensions.Literal` is now automatically structured, just like `typing.Literal`.
Expand Down
10 changes: 6 additions & 4 deletions docs/strategies.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ The payload can be interpreted as about a dozen different messages, based on the

To keep the example simple we define two classes, one for the `REFUND` event and one for everything else.

```python
```{testcode} apple
@define
class Refund:
Expand All @@ -92,7 +92,9 @@ Next, we use the _tagged unions_ strategy to prepare our converter.
The tag value for the `Refund` event is `REFUND`, and we can let the `OtherAppleNotification` class handle all the other cases.
The `tag_generator` parameter is a callable, so we can give it the `get` method of a dictionary.

```python
```{doctest} apple
>>> from cattrs.strategies import configure_tagged_union
>>> c = Converter()
>>> configure_tagged_union(
Expand All @@ -107,7 +109,7 @@ The `tag_generator` parameter is a callable, so we can give it the `get` method

The converter is now ready to start structuring Apple notifications.

```python
```{doctest} apple
>>> payload = {"notificationType": "REFUND", "originalTransactionId": "1"}
>>> notification = c.structure(payload, AppleNotification)
Expand All @@ -117,7 +119,7 @@ The converter is now ready to start structuring Apple notifications.
... print(f"Refund for {txn_id}!")
... case OtherAppleNotification(not_type):
... print("Can't handle this yet")

Refund for 1!
```

```{versionadded} 23.1.0
Expand Down
57 changes: 40 additions & 17 deletions src/cattrs/strategies/_unions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,27 +84,50 @@ def unstructure_tagged_union(
return res

if default is NOTHING:
if getattr(converter, "forbid_extra_keys", False):

def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
val = val.copy()
return _tag_to_cl[val.pop(_tag_name)](val)
def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
val = val.copy()
return _tag_to_cl[val.pop(_tag_name)](val)

else:

def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
return _tag_to_cl[val[_tag_name]](val)

else:
if getattr(converter, "forbid_extra_keys", False):

def structure_tagged_union(
val: dict,
_,
_tag_to_hook=tag_to_hook,
_tag_name=tag_name,
_dh=default_handler,
_default=default,
) -> union:
if _tag_name in val:
val = val.copy()
return _tag_to_hook[val.pop(_tag_name)](val)
return _dh(val, _default)

def structure_tagged_union(
val: dict,
_,
_tag_to_hook=tag_to_hook,
_tag_name=tag_name,
_dh=default_handler,
_default=default,
) -> union:
if _tag_name in val:
val = val.copy()
return _tag_to_hook[val.pop(_tag_name)](val)
return _dh(val, _default)
else:

def structure_tagged_union(
val: dict,
_,
_tag_to_hook=tag_to_hook,
_tag_name=tag_name,
_dh=default_handler,
_default=default,
) -> union:
if _tag_name in val:
return _tag_to_hook[val[_tag_name]](val)
return _dh(val, _default)

converter.register_unstructure_hook(union, unstructure_tagged_union)
converter.register_structure_hook(union, structure_tagged_union)
Expand Down
22 changes: 22 additions & 0 deletions tests/strategies/test_tagged_unions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,28 @@ def test_default_member(converter: BaseConverter) -> None:
assert converter.structure({"_type": "B", "a": 1}, union) == B("1")


def test_default_member_with_tag(converter: BaseConverter) -> None:
"""Members can access the tags, if not `forbid_extra_keys`."""

@define
class C:
_type: str = ""

union = Union[A, B, C]
configure_tagged_union(union, converter, default=C)
assert converter.unstructure(A(1), union) == {"_type": "A", "a": 1}
assert converter.unstructure(B("1"), union) == {"_type": "B", "a": "1"}

# No tag, so should structure as C.
assert converter.structure({"a": 1}, union) == C()
# Wrong tag, so should again structure as C.
assert converter.structure({"_type": "D", "a": 1}, union) == C("D")

assert converter.structure({"_type": "A", "a": 1}, union) == A(1)
assert converter.structure({"_type": "B", "a": 1}, union) == B("1")
assert converter.structure({"_type": "C", "a": 1}, union) == C("C")


def test_default_member_validation(converter: BaseConverter) -> None:
"""Default members are structured properly.."""
union = Union[A, B]
Expand Down

0 comments on commit a13fa2e

Please sign in to comment.