Skip to content

Commit

Permalink
Handle environments with HOME set to a not-a-directory (#1063)
Browse files Browse the repository at this point in the history
If `HOME` points to a regular file (or `/dev/null`), make sure we don't
crash unnecessarily, and if we do need to crash, so so informatively.

Fixes: #1014
  • Loading branch information
elprans committed Aug 17, 2023
1 parent cbf64e1 commit af922bc
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 25 deletions.
64 changes: 42 additions & 22 deletions asyncpg/connect_utils.py
Expand Up @@ -165,7 +165,7 @@ def _validate_port_spec(hosts, port):
# If there is a list of ports, its length must
# match that of the host list.
if len(port) != len(hosts):
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'could not match {} port numbers to {} hosts'.format(
len(port), len(hosts)))
else:
Expand Down Expand Up @@ -211,7 +211,7 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
addr = m.group(1)
hostspec_port = m.group(2)
else:
raise ValueError(
raise exceptions.ClientConfigurationError(
'invalid IPv6 address in the connection URI: {!r}'.format(
hostspec
)
Expand Down Expand Up @@ -240,13 +240,13 @@ def _parse_hostlist(hostlist, port, *, unquote=False):

def _parse_tls_version(tls_version):
if tls_version.startswith('SSL'):
raise ValueError(
raise exceptions.ClientConfigurationError(
f"Unsupported TLS version: {tls_version}"
)
try:
return ssl_module.TLSVersion[tls_version.replace('.', '_')]
except KeyError:
raise ValueError(
raise exceptions.ClientConfigurationError(
f"No such TLS version: {tls_version}"
)

Expand Down Expand Up @@ -274,7 +274,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
parsed = urllib.parse.urlparse(dsn)

if parsed.scheme not in {'postgresql', 'postgres'}:
raise ValueError(
raise exceptions.ClientConfigurationError(
'invalid DSN: scheme is expected to be either '
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

Expand Down Expand Up @@ -437,11 +437,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
database = user

if user is None:
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'could not determine user name to connect with')

if database is None:
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'could not determine database name to connect to')

if password is None:
Expand Down Expand Up @@ -477,7 +477,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
have_tcp_addrs = True

if not addrs:
raise ValueError(
raise exceptions.InternalClientError(
'could not determine the database address to connect to')

if ssl is None:
Expand All @@ -491,7 +491,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
sslmode = SSLMode.parse(ssl)
except AttributeError:
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
'`sslmode` parameter must be one of: {}'.format(modes))

# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
Expand All @@ -511,19 +511,36 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
else:
try:
sslrootcert = _dot_postgresql_path('root.crt')
assert sslrootcert is not None
ssl.load_verify_locations(cafile=sslrootcert)
except (AssertionError, FileNotFoundError):
if sslrootcert is not None:
ssl.load_verify_locations(cafile=sslrootcert)
else:
raise exceptions.ClientConfigurationError(
'cannot determine location of user '
'PostgreSQL configuration directory'
)
except (
exceptions.ClientConfigurationError,
FileNotFoundError,
NotADirectoryError,
):
if sslmode > SSLMode.require:
if sslrootcert is None:
raise RuntimeError(
'Cannot determine home directory'
sslrootcert = '~/.postgresql/root.crt'
detail = (
'Could not determine location of user '
'home directory (HOME is either unset, '
'inaccessible, or does not point to a '
'valid directory)'
)
raise ValueError(
else:
detail = None
raise exceptions.ClientConfigurationError(
f'root certificate file "{sslrootcert}" does '
f'not exist\nEither provide the file or '
f'change sslmode to disable server '
f'certificate verification.'
f'not exist or cannot be accessed',
hint='Provide the certificate file directly '
f'or make sure "{sslrootcert}" '
'exists and is readable.',
detail=detail,
)
elif sslmode == SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
Expand All @@ -542,7 +559,10 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if sslcrl is not None:
try:
ssl.load_verify_locations(cafile=sslcrl)
except FileNotFoundError:
except (
FileNotFoundError,
NotADirectoryError,
):
pass
else:
ssl.verify_flags |= \
Expand Down Expand Up @@ -571,7 +591,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
keyfile=sslkey,
password=lambda: sslpassword
)
except FileNotFoundError:
except (FileNotFoundError, NotADirectoryError):
pass

# OpenSSL 1.1.1 keylog file, copied from create_default_context()
Expand Down Expand Up @@ -606,7 +626,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
not isinstance(server_settings, dict) or
not all(isinstance(k, str) for k in server_settings) or
not all(isinstance(v, str) for v in server_settings.values())):
raise ValueError(
raise exceptions.ClientConfigurationError(
'server_settings is expected to be None or '
'a Dict[str, str]')

Expand All @@ -617,7 +637,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
try:
target_session_attrs = SessionAttribute(target_session_attrs)
except ValueError:
raise exceptions.InterfaceError(
raise exceptions.ClientConfigurationError(
"target_session_attrs is expected to be one of "
"{!r}"
", got {!r}".format(
Expand Down
7 changes: 6 additions & 1 deletion asyncpg/exceptions/_base.py
Expand Up @@ -13,7 +13,8 @@
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
'ClientConfigurationError')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -220,6 +221,10 @@ def with_msg(self, msg):
)


class ClientConfigurationError(InterfaceError, ValueError):
"""An error caused by improper client configuration."""


class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""

Expand Down
32 changes: 30 additions & 2 deletions tests/test_connect.py
Expand Up @@ -79,6 +79,15 @@ def mock_no_home_dir():
yield


@contextlib.contextmanager
def mock_dev_null_home_dir():
with unittest.mock.patch(
'pathlib.Path.home',
unittest.mock.Mock(return_value=pathlib.Path('/dev/null')),
):
yield


class TestSettings(tb.ConnectedTestCase):

async def test_get_settings_01(self):
Expand Down Expand Up @@ -1318,16 +1327,35 @@ async def test_connection_no_home_dir(self):
await con.fetchval('SELECT 42')
await con.close()

with mock_dev_null_home_dir():
con = await self.connect(
dsn='postgresql://foo/',
user='postgres',
database='postgres',
host='localhost')
await con.fetchval('SELECT 42')
await con.close()

with self.assertRaisesRegex(
RuntimeError,
'Cannot determine home directory'
exceptions.ClientConfigurationError,
r'root certificate file "~/\.postgresql/root\.crt" does not exist'
):
with mock_no_home_dir():
await self.connect(
host='localhost',
user='ssl_user',
ssl='verify-full')

with self.assertRaisesRegex(
exceptions.ClientConfigurationError,
r'root certificate file ".*" does not exist'
):
with mock_dev_null_home_dir():
await self.connect(
host='localhost',
user='ssl_user',
ssl='verify-full')


class BaseTestSSLConnection(tb.ConnectedTestCase):
@classmethod
Expand Down

0 comments on commit af922bc

Please sign in to comment.