Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLSUpgradeProto: don't set multiple results for an event #1117

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 40 additions & 0 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import os
import re
import socket
import textwrap
import time
import traceback
Expand Down Expand Up @@ -525,3 +526,42 @@ def connect_standby(cls, **kwargs):
kwargs
)
return pg_connection.connect(**conn_spec, loop=cls.loop)


class InstrumentedServer:
"""
A socket server for testing.
It will write each item from `data`, and wait for the corresponding event
in `received_events` to notify that it was received before writing the next
item from `data`.
"""
def __init__(self, data, received_events):
assert len(data) == len(received_events)
self._data = data
self._server = None
self._received_events = received_events

async def _handle_client(self, _reader, writer):
for datum, received_event in zip(self._data, self._received_events):
writer.write(datum)
await writer.drain()
await received_event.wait()

writer.close()
await writer.wait_closed()

async def start(self):
"""Start the server."""
self._server = await asyncio.start_server(self._handle_client, 'localhost', 0)
assert self._server.sockets
sock = self._server.sockets[0]
# Account for IPv4 and IPv6
addr, port = sock.getsockname()[:2]
return {
'host': addr,
'port': port,
}

def stop(self):
"""Stop the server."""
self._server.close()
5 changes: 5 additions & 0 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,11 @@ def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
self.ssl_is_advisory = ssl_is_advisory

def data_received(self, data):
if self.on_data.done():
# Only expect to receive one byte here; ignore unsolicited further
# data.
return

if data == b'S':
self.on_data.set_result(True)
elif (self.ssl_is_advisory and
Expand Down
55 changes: 55 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import contextlib
import copy
import gc
import ipaddress
import os
Expand All @@ -17,11 +18,13 @@
import stat
import tempfile
import textwrap
import time
import unittest
import unittest.mock
import urllib.parse
import warnings
import weakref
from unittest import mock

import asyncpg
from asyncpg import _testbase as tb
Expand Down Expand Up @@ -1989,6 +1992,58 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self):
await con.close()


class TestMisbehavingServer(tb.TestCase):
"""Tests for client connection behaviour given a misbehaving server."""

async def test_tls_upgrade_extra_data_received(self):
data = [
# First, the server writes b"S" to signal it is willing to perform
# SSL
b"S",
# Then, the server writes an unsolicted arbitrary byte afterwards
b"N",
]
data_received_events = [asyncio.Event() for _ in data]

# Patch out the loop's create_connection so we can instrument the proto
# we return.
old_create_conn = self.loop.create_connection

async def _mock_create_conn(*args, **kwargs):
transport, proto = await old_create_conn(*args, **kwargs)
old_data_received = proto.data_received

num_received = 0

def _data_received(*args, **kwargs):
nonlocal num_received
# Call the original data_received method
ret = old_data_received(*args, **kwargs)
# Fire the event to signal we've received this datum now.
data_received_events[num_received].set()
num_received += 1
return ret

proto.data_received = _data_received

# To deterministically provoke the race we're interested in for
# this regression test, wait for all data to be received before
# returning from create_connection().
await data_received_events[-1].wait()
return transport, proto

server = tb.InstrumentedServer(data, data_received_events)
conn_spec = await server.start()

# The call to connect() should raise a ConnectionResetError as the
# server will close the connection after writing all the data.
with (mock.patch.object(self.loop, "create_connection", side_effect=_mock_create_conn),
self.assertRaises(ConnectionResetError)):
await pg_connection.connect(**conn_spec, ssl=True, loop=self.loop)

server.stop()


def _get_connected_host(con):
peername = con._transport.get_extra_info('peername')
if isinstance(peername, tuple):
Expand Down