Skip to content

Commit

Permalink
Implement GSSAPI authentication
Browse files Browse the repository at this point in the history
Most commonly used with Kerberos.

Closes: #769
  • Loading branch information
eltoder committed Feb 25, 2024
1 parent c2c8d20 commit f7bd646
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 23 deletions.
16 changes: 12 additions & 4 deletions asyncpg/connect_utils.py
Expand Up @@ -56,6 +56,7 @@ def parse(cls, sslmode):
'direct_tls',
'server_settings',
'target_session_attrs',
'krbsrvname',
])


Expand Down Expand Up @@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, server_settings,
target_session_attrs):
target_session_attrs, krbsrvname):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -383,6 +384,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs

if 'krbsrvname' in query:
val = query.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val

if query:
if server_settings is None:
server_settings = query
Expand Down Expand Up @@ -654,7 +660,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs)
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname)

return addrs, params

Expand All @@ -665,7 +672,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs):
target_session_attrs, krbsrvname):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -694,7 +701,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs)
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down
9 changes: 7 additions & 2 deletions asyncpg/connection.py
Expand Up @@ -2007,7 +2007,8 @@ async def connect(dsn=None, *,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None,
target_session_attrs=None):
target_session_attrs=None,
krbsrvname=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
The connection parameters may be specified either as a connection
Expand Down Expand Up @@ -2235,6 +2236,9 @@ async def connect(dsn=None, *,
or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
or ``"any"`` if neither is specified.
:param str krbsrvname:
Kerberos service name to use when authenticating with GSSAPI.
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
Expand Down Expand Up @@ -2344,7 +2348,8 @@ async def connect(dsn=None, *,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname,
)


Expand Down
15 changes: 5 additions & 10 deletions asyncpg/protocol/coreproto.pxd
Expand Up @@ -51,16 +51,6 @@ cdef enum AuthenticationMessage:
AUTH_SASL_FINAL = 12


AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}


cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
Expand Down Expand Up @@ -96,10 +86,13 @@ cdef class CoreProtocol:

object transport

object address
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
# Instance of gssapi.SecurityContext
object gss_ctx

readonly int32_t backend_pid
readonly int32_t backend_secret
Expand Down Expand Up @@ -145,6 +138,8 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
cdef _auth_gss_init(self)
cdef _auth_gss_step(self, bytes server_response)

cdef _write(self, buf)
cdef _writelines(self, list buffers)
Expand Down
62 changes: 59 additions & 3 deletions asyncpg/protocol/coreproto.pyx
Expand Up @@ -6,14 +6,26 @@


import hashlib
import socket


include "scram.pyx"


cdef dict AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}


cdef class CoreProtocol:

def __init__(self, con_params):
def __init__(self, addr, con_params):
self.address = addr
# type of `con_params` is `_ConnectionParameters`
self.buffer = ReadBuffer()
self.user = con_params.user
Expand All @@ -26,6 +38,8 @@ cdef class CoreProtocol:
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
self.scram = None
# type of `gss_ctx` is `gssapi.SecurityContext`
self.gss_ctx = None

self._reset_result()

Expand Down Expand Up @@ -619,9 +633,17 @@ cdef class CoreProtocol:
'could not verify server signature for '
'SCRAM authentciation: scram-sha-256',
)
self.scram = None

elif status == AUTH_REQUIRED_GSS:
self._auth_gss_init()
self.auth_msg = self._auth_gss_step(None)

elif status == AUTH_REQUIRED_GSS_CONTINUE:
server_response = self.buffer.consume_message()
self.auth_msg = self._auth_gss_step(server_response)

elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
AUTH_REQUIRED_SSPI):
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
Expand All @@ -634,7 +656,8 @@ cdef class CoreProtocol:
'unsupported authentication method requested by the '
'server: {}'.format(status))

if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
AUTH_REQUIRED_GSS_CONTINUE]:
self.buffer.discard_message()

cdef _auth_password_message_cleartext(self):
Expand Down Expand Up @@ -691,6 +714,39 @@ cdef class CoreProtocol:

return msg

cdef _auth_gss_init(self):
try:
import gssapi
except ModuleNotFoundError:
raise RuntimeError(
'gssapi module not found; please install asyncpg[gss] to use '
'asyncpg with Kerberos or GSSAPI authentication'
) from None

service_name = self.con_params.krbsrvname or 'postgres'
# find the canonical name of the server host
if isinstance(self.address, str):
host = socket.gethostname()
else:
host = self.address[0]
host_cname = socket.gethostbyname_ex(host)[0].rstrip('.') or host
gss_name = gssapi.Name(f'{service_name}/{host_cname}')
self.gss_ctx = gssapi.SecurityContext(name=gss_name, usage='initiate')

cdef _auth_gss_step(self, bytes server_response):
cdef:
WriteBuffer msg

token = self.gss_ctx.step(server_response)
if not token:
self.gss_ctx = None
return None
msg = WriteBuffer.new_message(b'p')
msg.write_bytes(token)
msg.end_message()

return msg

cdef _parse_msg_ready_for_query(self):
cdef char status = self.buffer.read_byte()

Expand Down
1 change: 0 additions & 1 deletion asyncpg/protocol/protocol.pxd
Expand Up @@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol):

cdef:
object loop
object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter
Expand Down
5 changes: 2 additions & 3 deletions asyncpg/protocol/protocol.pyx
Expand Up @@ -75,16 +75,15 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
# type of `con_params` is `_ConnectionParameters`
CoreProtocol.__init__(self, con_params)
CoreProtocol.__init__(self, addr, con_params)

self.loop = loop
self.transport = None
self.waiter = connected_fut
self.cancel_waiter = None
self.cancel_sent_waiter = None

self.address = addr
self.settings = ConnectionSettings((self.address, con_params.database))
self.settings = ConnectionSettings((addr, con_params.database))
self.record_class = record_class

self.statement = None
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Expand Up @@ -35,6 +35,9 @@ dependencies = [
github = "https://github.com/MagicStack/asyncpg"

[project.optional-dependencies]
gss = [
'gssapi',
]
test = [
'flake8~=6.1',
'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',
Expand Down

0 comments on commit f7bd646

Please sign in to comment.