Skip to content

Commit

Permalink
Add support for password functions (useful for RDS IAM auth) (#554)
Browse files Browse the repository at this point in the history
Closes: #554
Closes: #553
  • Loading branch information
Harvey Frye authored and elprans committed Apr 23, 2020
1 parent 1d4325c commit 1d9457f
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
2 changes: 1 addition & 1 deletion asyncpg/__init__.py
Expand Up @@ -31,4 +31,4 @@
# snapshots will automatically include the git revision
# in __version__, for example: '0.16.0.dev0+ge06ad03'

__version__ = '0.20.1'
__version__ = '0.21.0.dev0'
13 changes: 12 additions & 1 deletion asyncpg/connect_utils.py
Expand Up @@ -21,6 +21,7 @@
import typing
import urllib.parse
import warnings
import inspect

from . import compat
from . import exceptions
Expand Down Expand Up @@ -601,6 +602,16 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
raise asyncio.TimeoutError

connected = _create_future(loop)

params_input = params
if callable(params.password):
if inspect.iscoroutinefunction(params.password):
password = await params.password()
else:
password = params.password()

params = params._replace(password=password)

proto_factory = lambda: protocol.Protocol(
addr, connected, params, loop)

Expand Down Expand Up @@ -633,7 +644,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
tr.close()
raise

con = connection_class(pr, tr, loop, addr, config, params)
con = connection_class(pr, tr, loop, addr, config, params_input)
pr.set_connection(con)
return con

Expand Down
7 changes: 7 additions & 0 deletions asyncpg/connection.py
Expand Up @@ -1566,6 +1566,10 @@ async def connect(dsn=None, *,
other users and applications may be able to read it without needing
specific privileges. It is recommended to use *passfile* instead.
Password may be either a string, or a callable that returns a string.
If a callable is provided, it will be called each time a new connection
is established.
:param passfile:
The name of the file used to store passwords
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
Expand Down Expand Up @@ -1646,6 +1650,9 @@ async def connect(dsn=None, *,
Added ability to specify multiple hosts in the *dsn*
and *host* arguments.
.. versionchanged:: 0.21.0
The *password* argument now accepts a callable or an async function.
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
Expand Down
38 changes: 38 additions & 0 deletions tests/test_connect.py
Expand Up @@ -204,6 +204,44 @@ async def test_auth_password_cleartext(self):
user='password_user',
password='wrongpassword')

async def test_auth_password_cleartext_callable(self):
def get_correctpassword():
return 'correctpassword'

def get_wrongpassword():
return 'wrongpassword'

conn = await self.connect(
user='password_user',
password=get_correctpassword)
await conn.close()

with self.assertRaisesRegex(
asyncpg.InvalidPasswordError,
'password authentication failed for user "password_user"'):
await self._try_connect(
user='password_user',
password=get_wrongpassword)

async def test_auth_password_cleartext_callable_coroutine(self):
async def get_correctpassword():
return 'correctpassword'

async def get_wrongpassword():
return 'wrongpassword'

conn = await self.connect(
user='password_user',
password=get_correctpassword)
await conn.close()

with self.assertRaisesRegex(
asyncpg.InvalidPasswordError,
'password authentication failed for user "password_user"'):
await self._try_connect(
user='password_user',
password=get_wrongpassword)

async def test_auth_password_md5(self):
conn = await self.connect(
user='md5_user', password='correctpassword')
Expand Down

0 comments on commit 1d9457f

Please sign in to comment.