diff --git a/.github/workflows/install-krb5.sh b/.github/workflows/install-krb5.sh new file mode 100755 index 00000000..093b8519 --- /dev/null +++ b/.github/workflows/install-krb5.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -Eexuo pipefail + +if [ "$RUNNER_OS" == "Linux" ]; then + # Assume Ubuntu since this is the only Linux used in CI. + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + libkrb5-dev krb5-user krb5-kdc krb5-admin-server +fi diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7fc77b38..b7229e18 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,6 +62,7 @@ jobs: - name: Install Python Deps if: steps.release.outputs.version == 0 run: | + .github/workflows/install-krb5.sh python -m pip install -U pip setuptools wheel python -m pip install -e .[test] @@ -122,6 +123,7 @@ jobs: - name: Install Python Deps if: steps.release.outputs.version == 0 run: | + .github/workflows/install-krb5.sh python -m pip install -U pip setuptools wheel python -m pip install -e .[test] diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 414231fd..8039d1b4 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -56,6 +56,7 @@ def parse(cls, sslmode): 'direct_tls', 'server_settings', 'target_session_attrs', + 'krbsrvname', ]) @@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, direct_tls, server_settings, - target_session_attrs): + target_session_attrs, krbsrvname): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -383,6 +384,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if target_session_attrs is None: target_session_attrs = dsn_target_session_attrs + if 'krbsrvname' in query: + val = query.pop('krbsrvname') + if krbsrvname is None: + krbsrvname = val + if query: if server_settings is None: server_settings = query @@ -650,11 +656,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, ) ) from None + if krbsrvname is None: + krbsrvname = os.getenv('PGKRBSRVNAME') + params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, sslmode=sslmode, direct_tls=direct_tls, server_settings=server_settings, - target_session_attrs=target_session_attrs) + target_session_attrs=target_session_attrs, + krbsrvname=krbsrvname) return addrs, params @@ -665,7 +675,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs): + target_session_attrs, krbsrvname): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -694,7 +704,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, server_settings=server_settings, - target_session_attrs=target_session_attrs) + target_session_attrs=target_session_attrs, + krbsrvname=krbsrvname) config = _ClientConfiguration( command_timeout=command_timeout, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0367e365..bf5f6db6 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2007,7 +2007,8 @@ async def connect(dsn=None, *, connection_class=Connection, record_class=protocol.Record, server_settings=None, - target_session_attrs=None): + target_session_attrs=None, + krbsrvname=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2235,6 +2236,10 @@ async def connect(dsn=None, *, or the value of the ``PGTARGETSESSIONATTRS`` environment variable, or ``"any"`` if neither is specified. + :param str krbsrvname: + Kerberos service name to use when authenticating with GSSAPI. This + must match the server configuration. Defaults to 'postgres'. + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2303,6 +2308,9 @@ async def connect(dsn=None, *, .. versionchanged:: 0.28.0 Added the *target_session_attrs* parameter. + .. versionchanged:: 0.30.0 + Added the *krbsrvname* parameter. + .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext .. _create_default_context: https://docs.python.org/3/library/ssl.html#ssl.create_default_context @@ -2344,7 +2352,8 @@ async def connect(dsn=None, *, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size, - target_session_attrs=target_session_attrs + target_session_attrs=target_session_attrs, + krbsrvname=krbsrvname, ) diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 7ce4f574..612d8cae 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -51,16 +51,6 @@ cdef enum AuthenticationMessage: AUTH_SASL_FINAL = 12 -AUTH_METHOD_NAME = { - AUTH_REQUIRED_KERBEROS: 'kerberosv5', - AUTH_REQUIRED_PASSWORD: 'password', - AUTH_REQUIRED_PASSWORDMD5: 'md5', - AUTH_REQUIRED_GSS: 'gss', - AUTH_REQUIRED_SASL: 'scram-sha-256', - AUTH_REQUIRED_SSPI: 'sspi', -} - - cdef enum ResultType: RESULT_OK = 1 RESULT_FAILED = 2 @@ -96,10 +86,13 @@ cdef class CoreProtocol: object transport + object address # Instance of _ConnectionParameters object con_params # Instance of SCRAMAuthentication SCRAMAuthentication scram + # Instance of gssapi.SecurityContext + object gss_ctx readonly int32_t backend_pid readonly int32_t backend_secret @@ -145,6 +138,8 @@ cdef class CoreProtocol: cdef _auth_password_message_md5(self, bytes salt) cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods) cdef _auth_password_message_sasl_continue(self, bytes server_response) + cdef _auth_gss_init(self) + cdef _auth_gss_step(self, bytes server_response) cdef _write(self, buf) cdef _writelines(self, list buffers) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 64afe934..7a2b257e 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -6,14 +6,26 @@ import hashlib +import socket include "scram.pyx" +cdef dict AUTH_METHOD_NAME = { + AUTH_REQUIRED_KERBEROS: 'kerberosv5', + AUTH_REQUIRED_PASSWORD: 'password', + AUTH_REQUIRED_PASSWORDMD5: 'md5', + AUTH_REQUIRED_GSS: 'gss', + AUTH_REQUIRED_SASL: 'scram-sha-256', + AUTH_REQUIRED_SSPI: 'sspi', +} + + cdef class CoreProtocol: - def __init__(self, con_params): + def __init__(self, addr, con_params): + self.address = addr # type of `con_params` is `_ConnectionParameters` self.buffer = ReadBuffer() self.user = con_params.user @@ -26,6 +38,8 @@ cdef class CoreProtocol: self.encoding = 'utf-8' # type of `scram` is `SCRAMAuthentcation` self.scram = None + # type of `gss_ctx` is `gssapi.SecurityContext` + self.gss_ctx = None self._reset_result() @@ -619,9 +633,17 @@ cdef class CoreProtocol: 'could not verify server signature for ' 'SCRAM authentciation: scram-sha-256', ) + self.scram = None + + elif status == AUTH_REQUIRED_GSS: + self._auth_gss_init() + self.auth_msg = self._auth_gss_step(None) + + elif status == AUTH_REQUIRED_GSS_CONTINUE: + server_response = self.buffer.consume_message() + self.auth_msg = self._auth_gss_step(server_response) elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED, - AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE, AUTH_REQUIRED_SSPI): self.result_type = RESULT_FAILED self.result = apg_exc.InterfaceError( @@ -634,7 +656,8 @@ cdef class CoreProtocol: 'unsupported authentication method requested by the ' 'server: {}'.format(status)) - if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]: + if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL, + AUTH_REQUIRED_GSS_CONTINUE]: self.buffer.discard_message() cdef _auth_password_message_cleartext(self): @@ -691,6 +714,40 @@ cdef class CoreProtocol: return msg + cdef _auth_gss_init(self): + try: + import gssapi + except ModuleNotFoundError: + raise RuntimeError( + 'gssapi module not found; please install asyncpg[gssapi] to ' + 'use asyncpg with Kerberos or GSSAPI authentication' + ) from None + + service_name = self.con_params.krbsrvname or 'postgres' + # find the canonical name of the server host + if isinstance(self.address, str): + raise RuntimeError('GSSAPI authentication is only supported for ' + 'TCP/IP connections') + + host = self.address[0] + host_cname = socket.gethostbyname_ex(host)[0] + gss_name = gssapi.Name(f'{service_name}/{host_cname}') + self.gss_ctx = gssapi.SecurityContext(name=gss_name, usage='initiate') + + cdef _auth_gss_step(self, bytes server_response): + cdef: + WriteBuffer msg + + token = self.gss_ctx.step(server_response) + if not token: + self.gss_ctx = None + return None + msg = WriteBuffer.new_message(b'p') + msg.write_bytes(token) + msg.end_message() + + return msg + cdef _parse_msg_ready_for_query(self): cdef char status = self.buffer.read_byte() diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd index a9ac8d5f..cd221fbb 100644 --- a/asyncpg/protocol/protocol.pxd +++ b/asyncpg/protocol/protocol.pxd @@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol): cdef: object loop - object address ConnectionSettings settings object cancel_sent_waiter object cancel_waiter diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index b43b0e9c..1459d908 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -75,7 +75,7 @@ NO_TIMEOUT = object() cdef class BaseProtocol(CoreProtocol): def __init__(self, addr, connected_fut, con_params, record_class: type, loop): # type of `con_params` is `_ConnectionParameters` - CoreProtocol.__init__(self, con_params) + CoreProtocol.__init__(self, addr, con_params) self.loop = loop self.transport = None @@ -83,8 +83,7 @@ cdef class BaseProtocol(CoreProtocol): self.cancel_waiter = None self.cancel_sent_waiter = None - self.address = addr - self.settings = ConnectionSettings((self.address, con_params.database)) + self.settings = ConnectionSettings((addr, con_params.database)) self.record_class = record_class self.statement = None diff --git a/pyproject.toml b/pyproject.toml index ed2340a7..8209d838 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,9 +35,14 @@ dependencies = [ github = "https://github.com/MagicStack/asyncpg" [project.optional-dependencies] +gssapi = [ + 'gssapi', +] test = [ 'flake8~=6.1', 'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"', + 'gssapi; platform_system == "Linux"', + 'k5test; platform_system == "Linux"', ] docs = [ 'Sphinx~=5.3.0', diff --git a/tests/test_connect.py b/tests/test_connect.py index 5333e2c5..ebf0e462 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -130,30 +130,22 @@ def test_server_version_02(self): CORRECT_PASSWORD = 'correct\u1680password' -class TestAuthentication(tb.ConnectedTestCase): +class BaseTestAuthentication(tb.ConnectedTestCase): + USERS = [] + def setUp(self): super().setUp() if not self.cluster.is_managed(): self.skipTest('unmanaged cluster') - methods = [ - ('trust', None), - ('reject', None), - ('scram-sha-256', CORRECT_PASSWORD), - ('md5', CORRECT_PASSWORD), - ('password', CORRECT_PASSWORD), - ] - self.cluster.reset_hba() create_script = [] - for method, password in methods: + for username, method, password in self.USERS: if method == 'scram-sha-256' and self.server_version.major < 10: continue - username = method.replace('-', '_') - # if this is a SCRAM password, we need to set the encryption method # to "scram-sha-256" in order to properly hash the password if method == 'scram-sha-256': @@ -162,7 +154,7 @@ def setUp(self): ) create_script.append( - 'CREATE ROLE {}_user WITH LOGIN{};'.format( + 'CREATE ROLE "{}" WITH LOGIN{};'.format( username, f' PASSWORD E{(password or "")!r}' ) @@ -175,20 +167,20 @@ def setUp(self): "SET password_encryption = 'md5';" ) - if _system != 'Windows': + if _system != 'Windows' and method != 'gss': self.cluster.add_hba_entry( type='local', - database='postgres', user='{}_user'.format(username), + database='postgres', user=username, auth_method=method) self.cluster.add_hba_entry( type='host', address=ipaddress.ip_network('127.0.0.0/24'), - database='postgres', user='{}_user'.format(username), + database='postgres', user=username, auth_method=method) self.cluster.add_hba_entry( type='host', address=ipaddress.ip_network('::1/128'), - database='postgres', user='{}_user'.format(username), + database='postgres', user=username, auth_method=method) # Put hba changes into effect @@ -201,28 +193,28 @@ def tearDown(self): # Reset cluster's pg_hba.conf since we've meddled with it self.cluster.trust_local_connections() - methods = [ - 'trust', - 'reject', - 'scram-sha-256', - 'md5', - 'password', - ] - drop_script = [] - for method in methods: + for username, method, _ in self.USERS: if method == 'scram-sha-256' and self.server_version.major < 10: continue - username = method.replace('-', '_') - - drop_script.append('DROP ROLE {}_user;'.format(username)) + drop_script.append('DROP ROLE "{}";'.format(username)) drop_script = '\n'.join(drop_script) self.loop.run_until_complete(self.con.execute(drop_script)) super().tearDown() + +class TestAuthentication(BaseTestAuthentication): + USERS = [ + ('trust_user', 'trust', None), + ('reject_user', 'reject', None), + ('scram_sha_256_user', 'scram-sha-256', CORRECT_PASSWORD), + ('md5_user', 'md5', CORRECT_PASSWORD), + ('password_user', 'password', CORRECT_PASSWORD), + ] + async def _try_connect(self, **kwargs): # On Windows the server sometimes just closes # the connection sooner than we receive the @@ -388,6 +380,62 @@ async def test_auth_md5_unsupported(self, _): await self.connect(user='md5_user', password=CORRECT_PASSWORD) +class TestGssAuthentication(BaseTestAuthentication): + @classmethod + def setUpClass(cls): + try: + from k5test.realm import K5Realm + except ModuleNotFoundError: + raise unittest.SkipTest('k5test not installed') + + cls.realm = K5Realm() + cls.addClassCleanup(cls.realm.stop) + # Setup environment before starting the cluster. + patch = unittest.mock.patch.dict(os.environ, cls.realm.env) + patch.start() + cls.addClassCleanup(patch.stop) + # Add credentials. + cls.realm.addprinc('postgres/localhost') + cls.realm.extract_keytab('postgres/localhost', cls.realm.keytab) + + cls.USERS = [(cls.realm.user_princ, 'gss', None)] + super().setUpClass() + + cls.cluster.override_connection_spec(host='localhost') + + @classmethod + def get_server_settings(cls): + settings = super().get_server_settings() + settings['krb_server_keyfile'] = f'FILE:{cls.realm.keytab}' + return settings + + @classmethod + def setup_cluster(cls): + cls.cluster = cls.new_cluster(pg_cluster.TempCluster) + cls.start_cluster( + cls.cluster, server_settings=cls.get_server_settings()) + + async def test_auth_gssapi(self): + conn = await self.connect(user=self.realm.user_princ) + await conn.close() + + # Service name mismatch. + with self.assertRaisesRegex( + exceptions.InternalClientError, + 'Server .* not found' + ): + await self.connect(user=self.realm.user_princ, krbsrvname='wrong') + + # Credentials mismatch. + self.realm.addprinc('wrong_user', 'password') + self.realm.kinit('wrong_user', 'password') + with self.assertRaisesRegex( + exceptions.InvalidAuthorizationSpecificationError, + 'GSSAPI authentication failed for user' + ): + await self.connect(user=self.realm.user_princ) + + class TestConnectParams(tb.TestCase): TESTS = [ @@ -600,6 +648,46 @@ class TestConnectParams(tb.TestCase): }) }, + { + 'name': 'krbsrvname', + 'dsn': 'postgresql://user@host/db?krbsrvname=srv_qs', + 'env': { + 'PGKRBSRVNAME': 'srv_env', + }, + 'result': ([('host', 5432)], { + 'database': 'db', + 'user': 'user', + 'target_session_attrs': 'any', + 'krbsrvname': 'srv_qs', + }) + }, + + { + 'name': 'krbsrvname_2', + 'dsn': 'postgresql://user@host/db?krbsrvname=srv_qs', + 'krbsrvname': 'srv_kws', + 'result': ([('host', 5432)], { + 'database': 'db', + 'user': 'user', + 'target_session_attrs': 'any', + 'krbsrvname': 'srv_kws', + }) + }, + + { + 'name': 'krbsrvname_3', + 'dsn': 'postgresql://user@host/db', + 'env': { + 'PGKRBSRVNAME': 'srv_env', + }, + 'result': ([('host', 5432)], { + 'database': 'db', + 'user': 'user', + 'target_session_attrs': 'any', + 'krbsrvname': 'srv_env', + }) + }, + { 'name': 'dsn_ipv6_multi_host', 'dsn': 'postgresql://user@[2001:db8::1234%25eth0],[::1]/db', @@ -883,6 +971,7 @@ def run_testcase(self, testcase): sslmode = testcase.get('ssl') server_settings = testcase.get('server_settings') target_session_attrs = testcase.get('target_session_attrs') + krbsrvname = testcase.get('krbsrvname') expected = testcase.get('result') expected_error = testcase.get('error') @@ -907,7 +996,8 @@ def run_testcase(self, testcase): passfile=passfile, database=database, ssl=sslmode, direct_tls=False, server_settings=server_settings, - target_session_attrs=target_session_attrs) + target_session_attrs=target_session_attrs, + krbsrvname=krbsrvname) params = { k: v for k, v in params._asdict().items()