Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reset the __eq__ and __hash__ of HTTPConnection to allow WebSockets to be added to … #1039

Merged
merged 3 commits into from Jun 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions starlette/requests.py
Expand Up @@ -74,6 +74,12 @@ def __iter__(self) -> typing.Iterator[str]:
def __len__(self) -> int:
return len(self.scope)

# Don't use the `abc.Mapping.__eq__` implementation.
# Connection instances should never be considered equal
# unless `self is other`.
__eq__ = object.__eq__
graingert marked this conversation as resolved.
Show resolved Hide resolved
__hash__ = object.__hash__
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the hash method on here already inherit object.__hash__? What am I missing?

Copy link
Member Author

@graingert graingert Jun 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mapping overrides __eq__ which disables the default __hash__

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mapping overrides eq

Ah gotcha, thanks.
Does that mean we actually only need to override __eq__ then?

What do we think to something like the following?...

    # Don't use the `abc.Mapping.__eq__` implementation.
    # Connection instances should never be considered equal
    # unless `self is other`.
    __eq__ = object.__eq__

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to re-enable __hash__ but I'll double check

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomchristie yep you need both:

from __future__ import annotations

import collections.abc
import pytest

def test_default():
    class HTTPConnection(collections.abc.Mapping):
        def __getitem__(self, *args, **kwargs):
            return None

        def __iter__(self, *args, **kwargs):
            return self

        def __next__(self, *args, **kwargs):
            raise StopIteration

        def __len__(self, *args, **kwargs):
            return 0


    assert HTTPConnection() == HTTPConnection()
    with pytest.raises(TypeError, match=r"unhashable type: 'HTTPConnection'"):
        connections = {HTTPConnection()}


def test_eq():
    class HTTPConnection(collections.abc.Mapping):

        __eq__ = object.__eq__

        def __getitem__(self, *args, **kwargs):
            return None

        def __iter__(self, *args, **kwargs):
            return self

        def __next__(self, *args, **kwargs):
            raise StopIteration

        def __len__(self, *args, **kwargs):
            return 0


    assert HTTPConnection() != HTTPConnection()
    h = HTTPConnection()
    assert h == h

    with pytest.raises(TypeError, match=r"unhashable type: 'HTTPConnection'"):
        connections = {HTTPConnection()}



def test_eq_and_hash():
    class HTTPConnection(collections.abc.Mapping):

        __eq__ = object.__eq__
        __hash__ = object.__hash__

        def __getitem__(self, *args, **kwargs):
            return None

        def __iter__(self, *args, **kwargs):
            return self

        def __next__(self, *args, **kwargs):
            raise StopIteration

        def __len__(self, *args, **kwargs):
            return 0


    assert HTTPConnection() != HTTPConnection()
    h = HTTPConnection()
    assert h == h

    assert {HTTPConnection()} != {HTTPConnection()}
    assert h in {h}
    assert {h} == {h}


@property
def app(self) -> typing.Any:
return self.scope["app"]
Expand Down
10 changes: 10 additions & 0 deletions tests/test_websockets.py
Expand Up @@ -368,3 +368,13 @@ async def mock_send(message):
assert websocket["type"] == "websocket"
assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
assert len(websocket) == 3

# check __eq__ and __hash__
assert websocket != WebSocket(
{"type": "websocket", "path": "/abc/", "headers": []},
receive=mock_receive,
send=mock_send,
)
assert websocket == websocket
assert websocket in {websocket}
assert {websocket} == {websocket}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that's all neat enough. 👍