Skip to content

Commit

Permalink
Merge pull request #120 from dlax/waiting
Browse files Browse the repository at this point in the history
support RW ready in waiting functions
  • Loading branch information
dvarrazzo committed Nov 16, 2021
2 parents 3cf7d2b + 4cf0450 commit 2205490
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 34 deletions.
66 changes: 35 additions & 31 deletions psycopg/psycopg/waiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Wait(IntEnum):
class Ready(IntEnum):
R = EVENT_READ
W = EVENT_WRITE
RW = EVENT_READ | EVENT_WRITE


def wait_selector(
Expand Down Expand Up @@ -59,6 +60,7 @@ def wait_selector(
sel.unregister(fileno)
# note: this line should require a cast, but mypy doesn't complain
ready: Ready = rlist[0][1]
assert s & ready
s = gen.send(ready)

except StopIteration as ex:
Expand Down Expand Up @@ -118,29 +120,29 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV:

def wakeup(state: Ready) -> None:
nonlocal ready
ready = state
ready |= state # type: ignore[assignment]
ev.set()

try:
s = next(gen)
while 1:
reader = s & Wait.R
writer = s & Wait.W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev.clear()
if s == Wait.R:
loop.add_reader(fileno, wakeup, Ready.R)
await ev.wait()
loop.remove_reader(fileno)
elif s == Wait.W:
loop.add_writer(fileno, wakeup, Ready.W)
await ev.wait()
loop.remove_writer(fileno)
elif s == Wait.RW:
ready = 0 # type: ignore[assignment]
if reader:
loop.add_reader(fileno, wakeup, Ready.R)
if writer:
loop.add_writer(fileno, wakeup, Ready.W)
try:
await ev.wait()
loop.remove_reader(fileno)
loop.remove_writer(fileno)
else:
raise e.InternalError("bad poll status: %s")
finally:
if reader:
loop.remove_reader(fileno)
if writer:
loop.remove_writer(fileno)
s = gen.send(ready)

except StopIteration as ex:
Expand Down Expand Up @@ -179,23 +181,23 @@ def wakeup(state: Ready) -> None:
try:
fileno, s = next(gen)
while 1:
reader = s & Wait.R
writer = s & Wait.W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev.clear()
if s == Wait.R:
ready = 0 # type: ignore[assignment]
if reader:
loop.add_reader(fileno, wakeup, Ready.R)
await wait_for(ev.wait(), timeout)
loop.remove_reader(fileno)
elif s == Wait.W:
if writer:
loop.add_writer(fileno, wakeup, Ready.W)
try:
await wait_for(ev.wait(), timeout)
loop.remove_writer(fileno)
elif s == Wait.RW:
loop.add_reader(fileno, wakeup, Ready.R)
loop.add_writer(fileno, wakeup, Ready.W)
await wait_for(ev.wait(), timeout)
loop.remove_reader(fileno)
loop.remove_writer(fileno)
else:
raise e.InternalError("bad poll status: %s")
finally:
if reader:
loop.remove_reader(fileno)
if writer:
loop.remove_writer(fileno)
fileno, s = gen.send(ready)

except TimeoutError:
Expand Down Expand Up @@ -232,11 +234,13 @@ def wait_epoll(
while not fileevs:
fileevs = epoll.poll(timeout)
ev = fileevs[0][1]
ready = 0
if ev & ~select.EPOLLOUT:
s = Ready.R
else:
s = Ready.W
s = gen.send(s)
ready = Ready.R
if ev & ~select.EPOLLIN:
ready |= Ready.W
assert s & ready
s = gen.send(ready)
evmask = poll_evmasks[s]
epoll.modify(fileno, evmask)

Expand Down
48 changes: 45 additions & 3 deletions tests/test_waiting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import select
import socket
import sys

import pytest

Expand All @@ -8,9 +10,8 @@
from psycopg.pq import ConnStatus, ExecStatus


skip_no_epoll = pytest.mark.skipif(
not hasattr(select, "epoll"), reason="epoll not available"
)
hasepoll = hasattr(select, "epoll")
skip_no_epoll = pytest.mark.skipif(not hasepoll, reason="epoll not available")

timeouts = [
{},
Expand All @@ -21,6 +22,11 @@
]


skip_if_not_linux = pytest.mark.skipif(
not sys.platform.startswith("linux"), reason="non-Linux platform"
)


@pytest.mark.parametrize("timeout", timeouts)
def test_wait_conn(dsn, timeout, retries):
for retry in retries:
Expand All @@ -44,6 +50,29 @@ def test_wait(pgconn, timeout):
assert res.status == ExecStatus.TUPLES_OK


waits_and_ids = [
(waiting.wait, "wait"),
(waiting.wait_selector, "wait_selector"),
]
if hasepoll:
waits_and_ids.append((waiting.wait_epoll, "wait_epoll"))

waits, wids = list(zip(*waits_and_ids))


@pytest.mark.parametrize("waitfn", waits, ids=wids)
@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
@skip_if_not_linux
def test_wait_ready(waitfn, wait, ready):
def gen():
r = yield wait
return r

with socket.socket() as s:
r = waitfn(gen(), s.fileno())
assert r & ready


@pytest.mark.parametrize("timeout", timeouts)
def test_wait_selector(pgconn, timeout):
pgconn.send_query(b"select 1")
Expand Down Expand Up @@ -100,6 +129,19 @@ async def test_wait_async(pgconn):
assert res.status == ExecStatus.TUPLES_OK


@pytest.mark.asyncio
@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
@skip_if_not_linux
async def test_wait_ready_async(wait, ready):
def gen():
r = yield wait
return r

with socket.socket() as s:
r = await waiting.wait_async(gen(), s.fileno())
assert r & ready


@pytest.mark.asyncio
async def test_wait_async_bad(pgconn):
pgconn.send_query(b"select 1")
Expand Down

0 comments on commit 2205490

Please sign in to comment.