-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
sock.py
162 lines (130 loc) · 4.95 KB
/
sock.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# -*- coding: utf-8 -*-
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
import errno
import os
import socket
import ssl
import stat
import sys
import time
from gunicorn import util
def _get_socket_family(addr):
    """Return the socket address family implied by a bind address.

    ``str`` and ``bytes`` addresses map to ``socket.AF_UNIX``. A tuple maps
    to ``socket.AF_INET6`` when ``util.is_ipv6`` accepts its first element
    and to ``socket.AF_INET`` otherwise.

    Raises:
        TypeError: if ``addr`` is neither a tuple nor a ``str``/``bytes``.
    """
    if isinstance(addr, (str, bytes)):
        return socket.AF_UNIX

    if not isinstance(addr, tuple):
        raise TypeError("Unable to determine socket family for: %r" % addr)

    # Only the first tuple element (the host) decides between IPv4 and IPv6.
    return socket.AF_INET6 if util.is_ipv6(addr[0]) else socket.AF_INET
def create_socket(conf, log, addr):
    """Create, bind and start listening on a socket for ``addr``.

    For a Unix socket address, an existing socket file at that path is
    removed first, and ownership of the new socket file is set with
    ``util.chown`` using ``conf.uid`` / ``conf.gid``.

    Binding is attempted up to five times, one second apart. If every
    attempt fails the error is logged and the process exits with status 1.

    Args:
        conf: configuration object; ``backlog``, ``uid`` and ``gid`` are read.
        log: logger used for error/debug reporting.
        addr: a tuple (TCP) or ``str``/``bytes`` path (Unix socket).

    Returns:
        The bound, listening socket.

    Raises:
        TypeError: if the address family cannot be determined from ``addr``.
        ValueError: if a Unix socket path exists but is not a socket.
        OSError: if the Unix socket path cannot be stat'ed for a reason
            other than it not existing.
    """
    family = _get_socket_family(addr)

    if family is socket.AF_UNIX:
        # remove any existing socket at the given path
        try:
            st = os.stat(addr)
        except OSError as e:
            if e.args[0] != errno.ENOENT:
                raise
        else:
            if stat.S_ISSOCK(st.st_mode):
                os.remove(addr)
            else:
                raise ValueError("%r is not a socket" % addr)

    attempts = 5
    for i in range(attempts):
        sock = None
        try:
            sock = socket.socket(family)
            sock.bind(addr)
            sock.listen(conf.backlog)
            if family is socket.AF_UNIX:
                util.chown(addr, conf.uid, conf.gid)
            return sock
        except socket.error as e:
            # Release the descriptor of the failed attempt; the next
            # iteration creates a fresh socket.
            if sock is not None:
                sock.close()

            if e.args[0] == errno.EADDRINUSE:
                log.error("Connection in use: %s", str(addr))
            if e.args[0] == errno.EADDRNOTAVAIL:
                log.error("Invalid address: %s", str(addr))

            msg = "connection to {addr} failed: {error}"
            log.debug(msg.format(addr=str(addr), error=str(e)))

            # Only announce a retry and wait when another attempt follows.
            if i < attempts - 1:
                log.error("Retrying in 1 second.")
                time.sleep(1)

    log.error("Can't connect to %s", str(addr))
    sys.exit(1)
def create_sockets(conf, log, fds=None):
    """
    Create a new socket for the configured addresses or file descriptors.

    When ``fds`` is non-empty the sockets are considered already bound:
    every descriptor in ``fds``, followed by every integer entry of
    ``conf.address``, is wrapped in a socket object.

    Otherwise a socket is created for each non-integer entry of
    ``conf.address`` while ``conf.umask`` is in effect; the previous umask
    is restored afterwards, even on error.
    If a configured address is a tuple then a TCP socket is created.
    If it is a string, a Unix socket is created. Otherwise, a TypeError is
    raised.

    Socket options are applied to every listener with
    ``set_socket_options``.

    Returns:
        The list of listening sockets.
    """
    listeners = []
    if fds:
        # sockets are already bound
        for fd in list(fds) + [a for a in conf.address if isinstance(a, int)]:
            sock = socket.socket(fileno=fd)
            set_socket_options(conf, sock)
            listeners.append(sock)
    else:
        # first initialization of gunicorn
        old_umask = os.umask(conf.umask)
        try:
            for addr in [bind for bind in conf.address if not isinstance(bind, int)]:
                # The module defines ``create_socket``; the former call to
                # ``_create_socket`` referenced a name that does not exist.
                sock = create_socket(conf, log, addr)
                set_socket_options(conf, sock)
                listeners.append(sock)
        finally:
            os.umask(old_umask)
    return listeners
def close_sockets(listeners, unlink=True):
    """Close every socket in ``listeners``.

    When ``unlink`` is true, the filesystem entry of each Unix socket
    (as reported by ``getsockname()``) is removed before the socket is
    closed. The socket is closed even if the removal raises.
    """
    for listener in listeners:
        try:
            if not unlink:
                continue
            if listener.family is socket.AF_UNIX:
                os.unlink(listener.getsockname())
        finally:
            listener.close()
def get_uri(listener, is_ssl=False):
    """Build a URI string describing where ``listener`` is bound.

    TCP listeners yield ``http://host:port`` (``https`` when ``is_ssl`` is
    true), with the host wrapped in brackets for IPv6. Unix socket
    listeners yield ``unix://<path>``.
    """
    family = _get_socket_family(listener.getsockname())
    scheme = "https" if is_ssl else "http"

    if family is socket.AF_UNIX:
        return f"unix://{listener.getsockname()}"

    if family is socket.AF_INET6:
        # IPv6 socket names are 4-tuples; only host and port are used.
        host, port, _flowinfo, _scope_id = listener.getsockname()
        return f"{scheme}://[{host}]:{port}"

    if family is socket.AF_INET:
        host, port = listener.getsockname()
        return f"{scheme}://{host}:{port}"

    return None
def set_socket_options(conf, sock):
    """Apply the listener options to ``sock``.

    The socket is made non-blocking and inheritable. TCP sockets get
    ``TCP_NODELAY``. ``SO_REUSEADDR`` is set, and ``SO_REUSEPORT`` as well
    when ``conf.reuse_port`` is true and the platform defines it.

    Raises:
        OSError: if setting ``SO_REUSEPORT`` fails with an errno other than
            ``ENOPROTOOPT`` or ``EINVAL``, or if any other option fails.
    """
    sock.setblocking(False)

    # make sure that the socket can be inherited
    if hasattr(sock, "set_inheritable"):
        sock.set_inheritable(True)

    tcp_families = (socket.AF_INET, socket.AF_INET6)
    if sock.family in tcp_families:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    if (conf.reuse_port and hasattr(socket, 'SO_REUSEPORT')):  # pragma: no cover
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        except socket.error as exc:
            # These two errnos are tolerated and the option is skipped;
            # anything else is re-raised.
            if exc.errno not in (errno.ENOPROTOOPT, errno.EINVAL):
                raise
def ssl_context(conf):
    """Return the SSL context for ``conf``.

    The result is whatever the ``conf.ssl_context`` hook returns. The hook
    is called with ``conf`` and a zero-argument factory that builds the
    default context from ``conf.ca_certs``, ``conf.certfile``,
    ``conf.keyfile``, ``conf.cert_reqs`` and ``conf.ciphers``.
    """
    def default_ssl_context_factory():
        # Default context: CA file, certificate chain and verification
        # mode come from the configuration.
        ctx = ssl.create_default_context(
            ssl.Purpose.CLIENT_AUTH,
            cafile=conf.ca_certs,
        )
        ctx.load_cert_chain(certfile=conf.certfile, keyfile=conf.keyfile)
        ctx.verify_mode = conf.cert_reqs
        # Ciphers are only overridden when explicitly configured.
        if conf.ciphers:
            ctx.set_ciphers(conf.ciphers)
        return ctx

    hook = conf.ssl_context
    return hook(conf, default_ssl_context_factory)
def ssl_wrap_socket(sock, conf):
    """Wrap ``sock`` as a server-side SSL socket.

    The context comes from ``ssl_context(conf)``; EOF handling and
    handshake timing follow ``conf.suppress_ragged_eofs`` and
    ``conf.do_handshake_on_connect``.
    """
    context = ssl_context(conf)
    return context.wrap_socket(
        sock,
        server_side=True,
        suppress_ragged_eofs=conf.suppress_ragged_eofs,
        do_handshake_on_connect=conf.do_handshake_on_connect,
    )