From e4a23316f8b8ff833258a738481098f42cdcc366 Mon Sep 17 00:00:00 2001 From: Randall Leeds Date: Thu, 28 Dec 2023 00:57:50 -0800 Subject: [PATCH] Use plain socket objects instead of wrapper classes Refactor socket creation to remove the socket wrapper classes so that these objects have less surprising behavior when used in worker hooks, worker classes, and custom applications. Close #3013. --- gunicorn/arbiter.py | 4 +- gunicorn/config.py | 4 +- gunicorn/sock.py | 264 ++++++++++++++++---------------------------- tests/test_sock.py | 50 +++++---- 4 files changed, 132 insertions(+), 190 deletions(-) diff --git a/gunicorn/arbiter.py b/gunicorn/arbiter.py index 008a54efe..cee55bd1d 100644 --- a/gunicorn/arbiter.py +++ b/gunicorn/arbiter.py @@ -154,7 +154,7 @@ def start(self): self.LISTENERS = sock.create_sockets(self.cfg, self.log, fds) - listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS]) + listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS]) self.log.debug("Arbiter booted") self.log.info("Listening at: %s (%s)", listeners_str, self.pid) self.log.info("Using worker: %s", self.cfg.worker_class_str) @@ -461,7 +461,7 @@ def reload(self): lnr.close() # init new listeners self.LISTENERS = sock.create_sockets(self.cfg, self.log) - listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS]) + listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS]) self.log.info("Listening at: %s", listeners_str) # do some actions on reload diff --git a/gunicorn/config.py b/gunicorn/config.py index e7e4fac54..cfc2d83f8 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -2076,7 +2076,7 @@ class KeyFile(Setting): section = "SSL" cli = ["--keyfile"] meta = "FILE" - validator = validate_string + validator = validate_file_exists default = None desc = """\ SSL key file @@ -2088,7 +2088,7 @@ class CertFile(Setting): section = "SSL" cli = ["--certfile"] meta = "FILE" - validator = validate_string + validator = validate_file_exists default = None desc = """\ SSL certificate file diff --git a/gunicorn/sock.py b/gunicorn/sock.py index 7700146a8..2e5d88966 100644 --- a/gunicorn/sock.py +++ b/gunicorn/sock.py @@ -14,130 +14,56 @@ from gunicorn import util -class BaseSocket(object): - - def __init__(self, address, conf, log, fd=None): - self.log = log - self.conf = conf - - self.cfg_addr = address - if fd is None: - sock = socket.socket(self.FAMILY, socket.SOCK_STREAM) - bound = False +def _get_socket_family(addr): + if isinstance(addr, tuple): + if util.is_ipv6(addr[0]): + return socket.AF_INET6 else: - sock = socket.fromfd(fd, self.FAMILY, socket.SOCK_STREAM) - os.close(fd) - bound = True - - self.sock = self.set_options(sock, bound=bound) - - def __str__(self): - return "" % self.sock.fileno() - - def __getattr__(self, name): - return getattr(self.sock, name) - - def set_options(self, sock, bound=False): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if (self.conf.reuse_port - and hasattr(socket, 'SO_REUSEPORT')): # pragma: no cover - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - except socket.error as err: - if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL): - raise - if not bound: - self.bind(sock) - sock.setblocking(0) + return socket.AF_INET - # make sure that the socket can be inherited - if hasattr(sock, "set_inheritable"): - sock.set_inheritable(True) + if isinstance(addr, (str, bytes)): + return socket.AF_UNIX - sock.listen(self.conf.backlog) - return sock + raise TypeError("Unable to determine socket family for: %r" % addr) - def bind(self, sock): - sock.bind(self.cfg_addr) - def close(self): - if self.sock is None: - return +def create_socket(conf, log, addr): + family = _get_socket_family(addr) + if family is socket.AF_UNIX: + # remove any existing socket at the given path try: - self.sock.close() - except socket.error as e: - self.log.info("Error while closing socket %s", str(e)) - - self.sock = None - - -class TCPSocket(BaseSocket): - - FAMILY = socket.AF_INET - - def __str__(self): - if self.conf.is_ssl: - scheme = "https" + st = os.stat(addr) + except OSError as e: + if e.args[0] != errno.ENOENT: + raise else: - scheme = "http" - - addr = self.sock.getsockname() - return "%s://%s:%d" % (scheme, addr[0], addr[1]) - - def set_options(self, sock, bound=False): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - return super().set_options(sock, bound=bound) - - -class TCP6Socket(TCPSocket): - - FAMILY = socket.AF_INET6 - - def __str__(self): - (host, port, _, _) = self.sock.getsockname() - return "http://[%s]:%d" % (host, port) - - -class UnixSocket(BaseSocket): - - FAMILY = socket.AF_UNIX - - def __init__(self, addr, conf, log, fd=None): - if fd is None: - try: - st = os.stat(addr) - except OSError as e: - if e.args[0] != errno.ENOENT: - raise + if stat.S_ISSOCK(st.st_mode): + os.remove(addr) else: - if stat.S_ISSOCK(st.st_mode): - os.remove(addr) - else: - raise ValueError("%r is not a socket" % addr) - super().__init__(addr, conf, log, fd=fd) - - def __str__(self): - return "unix:%s" % self.cfg_addr - - def bind(self, sock): - old_umask = os.umask(self.conf.umask) - sock.bind(self.cfg_addr) - util.chown(self.cfg_addr, self.conf.uid, self.conf.gid) - os.umask(old_umask) + raise ValueError("%r is not a socket" % addr) + for i in range(5): + try: + sock = socket.socket(family) + sock.bind(addr) + sock.listen(conf.backlog) + if family is socket.AF_UNIX: + util.chown(addr, conf.uid, conf.gid) + return sock + except socket.error as e: + if e.args[0] == errno.EADDRINUSE: + log.error("Connection in use: %s", str(addr)) + if e.args[0] == errno.EADDRNOTAVAIL: + log.error("Invalid address: %s", str(addr)) + if i < 5: + msg = "connection to {addr} failed: {error}" + log.debug(msg.format(addr=str(addr), error=str(e))) + log.error("Retrying in 1 second.") + time.sleep(1) -def _sock_type(addr): - if isinstance(addr, tuple): - if util.is_ipv6(addr[0]): - sock_type = TCP6Socket - else: - sock_type = TCPSocket - elif isinstance(addr, (str, bytes)): - sock_type = UnixSocket - else: - raise TypeError("Unable to create socket from: %r" % addr) - return sock_type + log.error("Can't connect to %s", str(addr)) + sys.exit(1) def create_sockets(conf, log, fds=None): @@ -150,67 +76,71 @@ def create_sockets(conf, log, fds=None): """ listeners = [] - # get it only once - addr = conf.address - fdaddr = [bind for bind in addr if isinstance(bind, int)] if fds: - fdaddr += list(fds) - laddr = [bind for bind in addr if not isinstance(bind, int)] - - # check ssl config early to raise the error on startup - # only the certfile is needed since it can contains the keyfile - if conf.certfile and not os.path.exists(conf.certfile): - raise ValueError('certfile "%s" does not exist' % conf.certfile) - - if conf.keyfile and not os.path.exists(conf.keyfile): - raise ValueError('keyfile "%s" does not exist' % conf.keyfile) - - # sockets are already bound - if fdaddr: - for fd in fdaddr: - sock = socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) - sock_name = sock.getsockname() - sock_type = _sock_type(sock_name) - listener = sock_type(sock_name, conf, log, fd=fd) - listeners.append(listener) - - return listeners - - # no sockets is bound, first initialization of gunicorn in this env. - for addr in laddr: - sock_type = _sock_type(addr) - sock = None - for i in range(5): - try: - sock = sock_type(addr, conf, log) - except socket.error as e: - if e.args[0] == errno.EADDRINUSE: - log.error("Connection in use: %s", str(addr)) - if e.args[0] == errno.EADDRNOTAVAIL: - log.error("Invalid address: %s", str(addr)) - if i < 5: - msg = "connection to {addr} failed: {error}" - log.debug(msg.format(addr=str(addr), error=str(e))) - log.error("Retrying in 1 second.") - time.sleep(1) - else: - break - - if sock is None: - log.error("Can't connect to %s", str(addr)) - sys.exit(1) - - listeners.append(sock) + # sockets are already bound + listeners = [] + for fd in list(fds) + [a for a in conf.address if isinstance(a, int)]: + sock = socket.socket(fileno=fd) + set_socket_options(conf, sock) + listeners.append(sock) + else: + # first initialization of gunicorn + old_umask = os.umask(conf.umask) + try: + for addr in [bind for bind in conf.address if not isinstance(bind, int)]: + sock = _create_socket(conf, log, addr) + set_socket_options(conf, sock) + listeners.append(sock) + finally: + os.umask(old_umask) return listeners def close_sockets(listeners, unlink=True): for sock in listeners: - sock_name = sock.getsockname() - sock.close() - if unlink and _sock_type(sock_name) is UnixSocket: - os.unlink(sock_name) + try: + if unlink and sock.family is socket.AF_UNIX: + sock_name = sock.getsockname() + os.unlink(sock_name) + finally: + sock.close() + + +def get_uri(listener, is_ssl=False): + addr = listener.getsockname() + family = _get_socket_family(addr) + scheme = "https" if is_ssl else "http" + + if family is socket.AF_INET: + (host, port) = listener.getsockname() + return f"{scheme}://{host}:{port}" + + if family is socket.AF_INET6: + (host, port, _, _) = listener.getsockname() + return f"{scheme}://[{host}]:{port}" + + if family is socket.AF_UNIX: + path = listener.getsockname() + return f"unix://{path}" + + +def set_socket_options(conf, sock): + sock.setblocking(False) + + # make sure that the socket can be inherited + if hasattr(sock, "set_inheritable"): + sock.set_inheritable(True) + + if sock.family in (socket.AF_INET, socket.AF_INET6): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if (conf.reuse_port and hasattr(socket, 'SO_REUSEPORT')): # pragma: no cover + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except socket.error as err: + if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL): + raise def ssl_context(conf): diff --git a/tests/test_sock.py b/tests/test_sock.py index adc348c6f..74ac1e76f 100644 --- a/tests/test_sock.py +++ b/tests/test_sock.py @@ -3,30 +3,42 @@ # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. +import socket from unittest import mock -from gunicorn import sock +import pytest +from gunicorn import sock -@mock.patch('os.stat') -def test_create_sockets_unix_bytes(stat): - conf = mock.Mock(address=[b'127.0.0.1:8000']) - log = mock.Mock() - with mock.patch.object(sock.UnixSocket, '__init__', lambda *args: None): - listeners = sock.create_sockets(conf, log) - assert len(listeners) == 1 - print(type(listeners[0])) - assert isinstance(listeners[0], sock.UnixSocket) +@pytest.fixture(scope='function') +def addr(request, tmp_path): + if isinstance(request.param, str): + return str(tmp_path / request.param) + return request.param -@mock.patch('os.stat') -def test_create_sockets_unix_strings(stat): - conf = mock.Mock(address=['127.0.0.1:8000']) +@pytest.mark.parametrize( + 'addr, family', + [ + ('gunicorn.sock', socket.AF_UNIX), + (('0.0.0.0', 0), socket.AF_INET), + (('::', 0), socket.AF_INET6), + ('gunicorn.sock', socket.AF_UNIX), + ], + indirect=['addr'], +) +@mock.patch('socket.socket') +@mock.patch('gunicorn.util.chown') +def test_create_socket(chown, socket, addr, family): + conf = mock.Mock(address=[addr], umask=0o22) log = mock.Mock() - with mock.patch.object(sock.UnixSocket, '__init__', lambda *args: None): - listeners = sock.create_sockets(conf, log) - assert len(listeners) == 1 - assert isinstance(listeners[0], sock.UnixSocket) + listener = sock.create_socket(conf, log, addr) + assert listener == socket.return_value + socket.assert_called_with(family) + listener.bind.assert_called_with(addr) + listener.listen.assert_called_with(conf.backlog) + if family is socket.AF_UNIX: + chown.assert_called_with(addr, conf.uid, conf.gid) def test_socket_close(): @@ -41,7 +53,7 @@ def test_socket_close(): @mock.patch('os.unlink') def test_unix_socket_close_unlink(unlink): - listener = mock.Mock() + listener = mock.Mock(family=socket.AF_UNIX) listener.getsockname.return_value = '/var/run/test.sock' sock.close_sockets([listener]) listener.close.assert_called_with() @@ -50,7 +62,7 @@ def test_unix_socket_close_unlink(unlink): @mock.patch('os.unlink') def test_unix_socket_close_without_unlink(unlink): - listener = mock.Mock() + listener = mock.Mock(family=socket.AF_UNIX) listener.getsockname.return_value = '/var/run/test.sock' sock.close_sockets([listener], False) listener.close.assert_called_with()