Skip to content

Commit

Permalink
Get the request.client under case uvicorn is run with a fd or a unix …
Browse files Browse the repository at this point in the history
…socket (#729)

* Revert "Corrected --proxy-headers client ip/host when using a unix socket (#636)"

This reverts commit a796e1d

* Distinguish case fd/unix socket to return correctly client

* Handle windows case

* Added test for AF_UNIX socket type
Modified MockSocket peername to pass tuples instead of list because socket.getpeername() and socket.getsockname() return tuples

* Black

* Removed test, black works locally but not in CI....

* Same deal on the server side of things

* Test on AF_UNIX only if it is in socket

* Simpler handling

* Removed debug leftovers
  • Loading branch information
euri10 committed Jul 31, 2020
1 parent 8f35b6d commit 918722a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 24 deletions.
24 changes: 18 additions & 6 deletions tests/protocols/test_utils.py
Expand Up @@ -29,42 +29,54 @@ def test_get_local_addr_with_socket():
assert get_local_addr(transport) is None

transport = MockTransport(
{"socket": MockSocket(family=socket.AF_INET6, sockname=["::1", 123])}
{"socket": MockSocket(family=socket.AF_INET6, sockname=("::1", 123))}
)
assert get_local_addr(transport) == ("::1", 123)

transport = MockTransport(
{"socket": MockSocket(family=socket.AF_INET, sockname=["123.45.6.7", 123])}
{"socket": MockSocket(family=socket.AF_INET, sockname=("123.45.6.7", 123))}
)
assert get_local_addr(transport) == ("123.45.6.7", 123)

if hasattr(socket, "AF_UNIX"):
transport = MockTransport(
{"socket": MockSocket(family=socket.AF_UNIX, sockname=("127.0.0.1", 8000))}
)
assert get_local_addr(transport) == ("127.0.0.1", 8000)


def test_get_remote_addr_with_socket():
transport = MockTransport({"socket": MockSocket(family=socket.AF_IPX)})
assert get_remote_addr(transport) is None

transport = MockTransport(
{"socket": MockSocket(family=socket.AF_INET6, peername=["::1", 123])}
{"socket": MockSocket(family=socket.AF_INET6, peername=("::1", 123))}
)
assert get_remote_addr(transport) == ("::1", 123)

transport = MockTransport(
{"socket": MockSocket(family=socket.AF_INET, peername=["123.45.6.7", 123])}
{"socket": MockSocket(family=socket.AF_INET, peername=("123.45.6.7", 123))}
)
assert get_remote_addr(transport) == ("123.45.6.7", 123)

if hasattr(socket, "AF_UNIX"):
transport = MockTransport(
{"socket": MockSocket(family=socket.AF_UNIX, peername=("127.0.0.1", 8000))}
)
assert get_remote_addr(transport) == ("127.0.0.1", 8000)


def test_get_local_addr():
transport = MockTransport({"sockname": "path/to/unix-domain-socket"})
assert get_local_addr(transport) is None

transport = MockTransport({"sockname": ["123.45.6.7", 123]})
transport = MockTransport({"sockname": ("123.45.6.7", 123)})
assert get_local_addr(transport) == ("123.45.6.7", 123)


def test_get_remote_addr():
transport = MockTransport({"peername": None})
assert get_remote_addr(transport) is None

transport = MockTransport({"peername": ["123.45.6.7", 123]})
transport = MockTransport({"peername": ("123.45.6.7", 123)})
assert get_remote_addr(transport) == ("123.45.6.7", 123)
22 changes: 4 additions & 18 deletions uvicorn/protocols/utils.py
@@ -1,29 +1,17 @@
import socket
import urllib

if hasattr(socket, "AF_UNIX"):
SUPPORTED_SOCKET_FAMILIES = (socket.AF_INET, socket.AF_INET6, socket.AF_UNIX)
else:
SUPPORTED_SOCKET_FAMILIES = (socket.AF_INET, socket.AF_INET6)


def get_remote_addr(transport):
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
try:
info = socket_info.getpeername()
return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
except OSError:
# This case appears to inconsistently occur with uvloop
# bound to a unix domain socket.
family = None
info = None
else:
family = socket_info.family

if family in SUPPORTED_SOCKET_FAMILIES:
return (str(info[0]), int(info[1]))
return None

return None
info = transport.get_extra_info("peername")
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
return (str(info[0]), int(info[1]))
Expand All @@ -34,10 +22,8 @@ def get_local_addr(transport):
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
info = socket_info.getsockname()
family = socket_info.family
if family in SUPPORTED_SOCKET_FAMILIES:
return (str(info[0]), int(info[1]))
return None

return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
info = transport.get_extra_info("sockname")
if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
return (str(info[0]), int(info[1]))
Expand Down

0 comments on commit 918722a

Please sign in to comment.