Skip to content

Commit

Permalink
Add support for target_session_attrs (#987)
Browse files Browse the repository at this point in the history
This adds support for the `target_session_attrs` connection option.

Co-authored-by: rony batista <rony.batista@revolut.com>
Co-authored-by: Jesse De Loore <jesse@sennac.be>
  • Loading branch information
3 people committed May 8, 2023
1 parent 7443a9e commit bf74e88
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 75 deletions.
90 changes: 90 additions & 0 deletions asyncpg/_testbase/__init__.py
Expand Up @@ -435,3 +435,93 @@ def tearDown(self):
self.con = None
finally:
super().tearDown()


class HotStandbyTestCase(ClusterTestCase):
    """Test case that provisions a primary cluster and a hot-standby
    replica streaming from it (used by target_session_attrs tests)."""

    @classmethod
    def setup_cluster(cls):
        """Start the primary, create a replication role on it, then start
        a hot-standby cluster that replicates from the primary."""
        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(
            cls.master_cluster,
            server_settings={
                'max_wal_senders': 10,
                'wal_level': 'hot_standby'
            }
        )

        con = None

        try:
            con = cls.loop.run_until_complete(
                cls.master_cluster.connect(
                    database='postgres', user='postgres', loop=cls.loop))

            cls.loop.run_until_complete(
                con.execute('''
                    CREATE ROLE replication WITH LOGIN REPLICATION
                '''))

            cls.master_cluster.trust_local_replication_by('replication')

            conn_spec = cls.master_cluster.get_connection_spec()

            cls.standby_cluster = cls.new_cluster(
                pg_cluster.HotStandbyCluster,
                cluster_kwargs={
                    'master': conn_spec,
                    'replication_user': 'replication'
                }
            )
            cls.start_cluster(
                cls.standby_cluster,
                server_settings={
                    'hot_standby': True
                }
            )

        finally:
            # Always release the bootstrap connection, even when the
            # standby setup fails part-way through.
            if con is not None:
                cls.loop.run_until_complete(con.close())

    @classmethod
    def get_cluster_connection_spec(cls, cluster, kwargs=None):
        """Return *cluster*'s connection spec merged with *kwargs*.

        If *kwargs* carries a ``dsn``, the spec's ``host`` entry is dropped
        so the DSN's host information takes precedence.
        """
        # NOTE: the default used to be a mutable ``{}``; use None to avoid
        # sharing one dict instance across calls.
        kwargs = {} if kwargs is None else kwargs
        conn_spec = cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Return a multi-host spec listing the primary first, then the
        standby, suitable for target_session_attrs testing."""
        kwargs = {} if kwargs is None else kwargs
        primary_spec = cls.get_cluster_connection_spec(
            cls.master_cluster, kwargs
        )
        standby_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster, kwargs
        )
        return {
            'host': [primary_spec['host'], standby_spec['host']],
            'port': [primary_spec['port'], standby_spec['port']],
            'database': primary_spec['database'],
            'user': primary_spec['user'],
            **kwargs
        }

    @classmethod
    def connect_primary(cls, **kwargs):
        """Open a connection directly to the primary cluster."""
        conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    @classmethod
    def connect_standby(cls, **kwargs):
        """Open a connection directly to the standby cluster."""
        conn_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster,
            kwargs
        )
        return pg_connection.connect(**conn_spec, loop=cls.loop)
2 changes: 1 addition & 1 deletion asyncpg/cluster.py
Expand Up @@ -626,7 +626,7 @@ def init(self, **settings):
'pg_basebackup init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))

if self._pg_version <= (11, 0):
if self._pg_version < (12, 0):
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
f.write(textwrap.dedent("""\
standby_mode = 'on'
Expand Down
122 changes: 114 additions & 8 deletions asyncpg/connect_utils.py
Expand Up @@ -13,6 +13,7 @@
import os
import pathlib
import platform
import random
import re
import socket
import ssl as ssl_module
Expand Down Expand Up @@ -56,6 +57,7 @@ def parse(cls, sslmode):
'direct_tls',
'connect_timeout',
'server_settings',
'target_session_attrs',
])


Expand Down Expand Up @@ -260,7 +262,8 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:

def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, connect_timeout, server_settings):
direct_tls, connect_timeout, server_settings,
target_session_attrs):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -607,10 +610,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
'server_settings is expected to be None or '
'a Dict[str, str]')

if target_session_attrs is None:

target_session_attrs = os.getenv(
"PGTARGETSESSIONATTRS", SessionAttribute.any
)
try:

target_session_attrs = SessionAttribute(target_session_attrs)
except ValueError as exc:
raise exceptions.InterfaceError(
"target_session_attrs is expected to be one of "
"{!r}"
", got {!r}".format(
SessionAttribute.__members__.values, target_session_attrs
)
) from exc

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
connect_timeout=connect_timeout, server_settings=server_settings)
connect_timeout=connect_timeout, server_settings=server_settings,
target_session_attrs=target_session_attrs)

return addrs, params

Expand All @@ -620,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings):

ssl, direct_tls, server_settings,
target_session_attrs):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -649,7 +670,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
connect_timeout=timeout, server_settings=server_settings)
connect_timeout=timeout, server_settings=server_settings,
target_session_attrs=target_session_attrs)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down Expand Up @@ -882,18 +904,84 @@ async def __connect_addr(
return con


class SessionAttribute(str, enum.Enum):
    """Allowed values of the ``target_session_attrs`` connect option."""

    any = 'any'
    primary = 'primary'
    standby = 'standby'
    prefer_standby = 'prefer-standby'
    read_write = 'read-write'
    read_only = 'read-only'


def _accept_in_hot_standby(should_be_in_hot_standby: bool):
    """Build a predicate accepting connections whose hot-standby state
    equals *should_be_in_hot_standby*.

    When the server did not report "in_hot_standby" at startup, the state
    is resolved by running "SELECT pg_catalog.pg_is_in_recovery()": a
    server that accepts the connection while in recovery must be a
    replica/standby.
    """
    async def check(connection):
        reported = getattr(
            connection.get_settings(), 'in_hot_standby', None)
        if reported is None:
            # Older servers: fall back to an explicit recovery query.
            standby = await connection.fetchval(
                "SELECT pg_catalog.pg_is_in_recovery()"
            )
        else:
            standby = reported == 'on'
        return standby == should_be_in_hot_standby

    return check


def _accept_read_only(should_be_read_only: bool):
    """Build a predicate accepting connections whose read-only state
    matches *should_be_read_only*.

    A server reporting default_transaction_read_only=on is read-only
    outright; otherwise the hot-standby state decides (a standby is
    read-only, a primary is not).
    """
    async def check(connection):
        ro_setting = getattr(
            connection.get_settings(),
            'default_transaction_read_only',
            'off',
        )
        if ro_setting == "on":
            return should_be_read_only
        # Not explicitly read-only: defer to the hot-standby predicate.
        return await _accept_in_hot_standby(should_be_read_only)(connection)

    return check


async def _accept_any(_):
    """Predicate for the 'any' policy: every live connection qualifies."""
    return True


# Dispatch table mapping each target_session_attrs policy to the async
# predicate used to vet a freshly established connection.  Note that
# 'prefer-standby' uses the strict standby check here; the fallback to
# any connected host is handled separately by the connect loop.
target_attrs_check = {
    SessionAttribute.any: _accept_any,
    SessionAttribute.primary: _accept_in_hot_standby(False),
    SessionAttribute.standby: _accept_in_hot_standby(True),
    SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
    SessionAttribute.read_write: _accept_read_only(False),
    SessionAttribute.read_only: _accept_read_only(True),
}


async def _can_use_connection(connection, attr: SessionAttribute):
    """Return whether *connection* satisfies the *attr* session policy."""
    return await target_attrs_check[attr](connection)


async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
    """Try each candidate address until a connection satisfying
    ``target_session_attrs`` is established.

    Every connection that was opened but not chosen is closed.  Raises
    the last network error encountered, or
    ``TargetServerAttributeNotMatched`` when every host connected but
    none matched the requested session attributes.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
    target_attr = params.target_session_attrs

    candidates = []
    chosen_connection = None
    last_error = None
    for addr in addrs:
        before = time.monotonic()
        try:
            conn = await _connect_addr(
                addr=addr,
                loop=loop,
                timeout=timeout,
                params=params,
                config=config,
                connection_class=connection_class,
                record_class=record_class,
            )
            candidates.append(conn)
            if await _can_use_connection(conn, target_attr):
                chosen_connection = conn
                break
        except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
            last_error = ex
        finally:
            # Charge the time spent on this host against the overall
            # budget so the remaining hosts share what is left.
            timeout -= time.monotonic() - before
    else:
        # No host matched outright; 'prefer-standby' falls back to a
        # random host among those that did connect.
        if target_attr == SessionAttribute.prefer_standby and candidates:
            chosen_connection = random.choice(candidates)

    # BUG FIX: asyncio.gather() takes awaitables as positional arguments;
    # passing a single generator object raises a TypeError at runtime.
    # Unpack it so each close() coroutine is gathered individually.
    await asyncio.gather(
        *(c.close() for c in candidates if c is not chosen_connection),
        return_exceptions=True
    )

    if chosen_connection is not None:
        return chosen_connection

    raise last_error or exceptions.TargetServerAttributeNotMatched(
        'None of the hosts match the target attribute requirement '
        '{!r}'.format(target_attr)
    )


async def _cancel(*, loop, addr, params: _ConnectionParameters,
Expand Down
20 changes: 19 additions & 1 deletion asyncpg/connection.py
Expand Up @@ -1792,7 +1792,8 @@ async def connect(dsn=None, *,
direct_tls=False,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None):
server_settings=None,
target_session_attrs=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
The connection parameters may be specified either as a connection
Expand Down Expand Up @@ -2003,6 +2004,22 @@ async def connect(dsn=None, *,
this connection object. Must be a subclass of
:class:`~asyncpg.Record`.
:param SessionAttribute target_session_attrs:
If specified, check that the host has the correct attribute.
Can be one of:
"any": the first successfully connected host
"primary": the host must NOT be in hot standby mode
"standby": the host must be in hot standby mode
"read-write": the host must allow writes
"read-only": the host most NOT allow writes
"prefer-standby": first try to find a standby host, but if
none of the listed hosts is a standby server,
return any of them.
If not specified will try to use PGTARGETSESSIONATTRS
from the environment.
Defaults to "any" if no value is set.
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
Expand Down Expand Up @@ -2109,6 +2126,7 @@ async def connect(dsn=None, *,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
)


Expand Down
6 changes: 5 additions & 1 deletion asyncpg/exceptions/_base.py
Expand Up @@ -13,7 +13,7 @@
# Public API of this module; TargetServerAttributeNotMatched is exported
# so callers can catch it when target_session_attrs cannot be satisfied.
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
           'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
           'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
           'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -244,6 +244,10 @@ class ProtocolError(InternalClientError):
"""Unexpected condition in the handling of PostgreSQL protocol input."""


class TargetServerAttributeNotMatched(InternalClientError):
    """Could not find a host that satisfies the target attribute requirement."""


class OutdatedSchemaCacheError(InternalClientError):
"""A value decoding error caused by a schema change before row fetching."""

Expand Down

0 comments on commit bf74e88

Please sign in to comment.