Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a test for WANT_READ during sendall() #955

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ ignore_missing_imports = true

[tool.pytest.ini_options]
addopts = "-r s --strict-markers"
filterwarnings = ["ignore"]
testpaths = ["tests"]

[tool.ruff]
Expand Down
129 changes: 126 additions & 3 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
return server


def loopback(server_factory=None, client_factory=None):
def loopback(server_factory=None, client_factory=None, blocking=True):
"""
Create a connected socket pair and force two connected SSL sockets
to talk to each other via memory BIOs.
Expand All @@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None):

handshake(client, server)

server.setblocking(True)
client.setblocking(True)
server.setblocking(blocking)
client.setblocking(blocking)
return server, client


Expand Down Expand Up @@ -3297,11 +3297,134 @@ def test_memoryview_really_doesnt_overfill(self):
self._doesnt_overfill_test(_make_memoryview)


@pytest.fixture
def nonblocking_tls_connections_pair():
"""Return a non-blocking TLS loopback connections pair."""
return loopback(blocking=False)


@pytest.fixture
def nonblocking_tls_server_connection(nonblocking_tls_connections_pair):
"""Return a non-blocking TLS server socket connected to loopback."""
return nonblocking_tls_connections_pair[0]


@pytest.fixture
def nonblocking_tls_client_connection(nonblocking_tls_connections_pair):
"""Return a non-blocking TLS client socket connected to loopback."""
return nonblocking_tls_connections_pair[1]


class TestConnectionSendall:
"""
Tests for `Connection.sendall`.
"""

def test_want_write(
self,
monkeypatch,
nonblocking_tls_server_connection,
nonblocking_tls_client_connection,
):
msg = b"x"
garbage_size = 1024 * 1024 * 64
large_payload = b"p" * garbage_size * 2
payload_size = len(large_payload)

sent_garbage_size = 0
try:
sent_garbage_size += nonblocking_tls_client_connection.send(
msg * garbage_size,
)
except WantWriteError:
pass
for i in range(garbage_size):
try:
sent_garbage_size += nonblocking_tls_client_connection.send(
msg,
)
except WantWriteError:
break
else:
pytest.fail(
"Failed to fill socket buffer, cannot test "
"'want write' in `sendall()`"
)
garbage_payload = sent_garbage_size * msg

def consume_garbage(conn):
assert patched_ssl_write.want_write_counter >= 1
assert not consume_garbage.garbage_consumed

while len(consume_garbage.consumed) < sent_garbage_size:
try:
consume_garbage.consumed += conn.recv(
sent_garbage_size - len(consume_garbage.consumed),
)
except WantReadError:
pass

assert consume_garbage.consumed == garbage_payload

consume_garbage.garbage_consumed = True

consume_garbage.garbage_consumed = False
consume_garbage.consumed = b""

def consume_payload(conn):
try:
consume_payload.consumed += conn.recv(payload_size)
except WantReadError:
pass

consume_payload.consumed = b""

original_ssl_write = _lib.SSL_write

def patched_ssl_write(ctx, data, size):
write_result = original_ssl_write(ctx, data, size)
try:
nonblocking_tls_client_connection._raise_ssl_error(
ctx,
write_result,
)
except WantWriteError:
patched_ssl_write.want_write_counter += 1
consume_data_on_server = (
consume_payload
if consume_garbage.garbage_consumed
else consume_garbage
)

consume_data_on_server(nonblocking_tls_server_connection)
# NOTE: We don't re-raise it as the calling code will do
# NOTE: the same after the call.
return write_result

patched_ssl_write.want_write_counter = 0

# NOTE: Make the client think it needs a handshake so that it'll
# NOTE: attempt to `do_handshake()` on the next `SSL_write()`
# NOTE: that originates from `sendall()`:
nonblocking_tls_client_connection.set_connect_state()
try:
nonblocking_tls_client_connection.do_handshake()
except WantWriteError:
assert True # Sanity check
except:
assert False # This should never happen (see the note above)

with monkeypatch.context() as mp_ctx:
mp_ctx.setattr(_lib, "SSL_write", patched_ssl_write)
nonblocking_tls_client_connection.sendall(large_payload)

assert consume_garbage.garbage_consumed

# NOTE: Read the leftover data from the very last `SSL_write()`
consume_payload(nonblocking_tls_server_connection)

assert consume_payload.consumed == large_payload

def test_wrong_args(self):
"""
When called with arguments other than a string argument for its first
Expand Down