Skip to content

Commit

Permalink
🐛 FIX: Checking for entry points
Browse files Browse the repository at this point in the history
Since pypa/setuptools#3086,
the `EntryPoint` class can be overriden by setuptools.
Here we also check for that class,
plus providining more feedback, if the check fails.
  • Loading branch information
chrisjsewell committed Feb 16, 2022
1 parent c5673be commit d4e80e2
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 30 deletions.
5 changes: 2 additions & 3 deletions aiida/cmdline/params/types/plugin.py
Expand Up @@ -11,10 +11,9 @@
import functools

import click
from importlib_metadata import EntryPoint

from aiida.common import exceptions
from aiida.plugins import factories
from aiida.plugins import factories, utils
from aiida.plugins.entry_point import (
ENTRY_POINT_GROUP_PREFIX,
ENTRY_POINT_STRING_SEPARATOR,
Expand Down Expand Up @@ -239,7 +238,7 @@ def convert(self, value, param, ctx):
"""
# If the value is already of the expected return type, simply return it. This behavior is new in `click==8.0`:
# https://click.palletsprojects.com/en/8.0.x/parameters/#implementing-custom-types
if isinstance(value, EntryPoint):
if utils.is_entry_point(value):
try:
self.validate_entry_point_group(value.group)
except ValueError as exception:
Expand Down
56 changes: 32 additions & 24 deletions aiida/plugins/factories.py
Expand Up @@ -16,6 +16,8 @@

from aiida.common.exceptions import InvalidEntryPointTypeError

from .utils import is_entry_point

__all__ = (
'BaseFactory', 'CalculationFactory', 'CalcJobImporterFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory',
'OrbitalFactory', 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory'
Expand All @@ -31,16 +33,22 @@
from aiida.transports import Transport


def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> None:
def raise_invalid_type_error(
entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...], entry_point: Any
) -> None:
"""Raise an `InvalidEntryPointTypeError` with formatted message.
:param entry_point_name: name of the entry point
:param entry_point_group: name of the entry point group
:param valid_classes: tuple of valid classes for the given entry point group
:raises aiida.common.InvalidEntryPointTypeError: always
"""
template = 'entry point `{}` registered in group `{}` is invalid because its type is not one of: {}'
args = (entry_point_name, entry_point_group, ', '.join([e.__name__ for e in valid_classes]))
template = 'entry point `{}` registered in group `{}` is invalid because its type is not one of: {}; {}'
try:
entry_point_str = f'{entry_point.__module__}::{entry_point.__class__.__name__}'
except AttributeError:
entry_point_str = str(entry_point)
args = (entry_point_name, entry_point_group, ', '.join([e.__name__ for e in valid_classes]), entry_point_str)
raise InvalidEntryPointTypeError(template.format(*args))


Expand Down Expand Up @@ -79,12 +87,12 @@ def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Uni
valid_classes = (CalcJob, calcfunction)

if (
isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, CalcJob)) or
(is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode)
is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, CalcJob)) or
(is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode) # type: ignore
):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'CalcJobImporter']]:
Expand All @@ -103,7 +111,7 @@ def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Optional
if isclass(entry_point) and issubclass(entry_point, CalcJobImporter):
return entry_point # type: ignore[return-value]

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Data']]:
Expand All @@ -120,10 +128,10 @@ def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Entr
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Data,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Data)):
if is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, Data)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def DbImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'DbImporter']]:
Expand All @@ -140,10 +148,10 @@ def DbImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Unio
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (DbImporter,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, DbImporter)):
if is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, DbImporter)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def GroupFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Group']]:
Expand All @@ -160,10 +168,10 @@ def GroupFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Ent
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Group,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Group)):
if is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, Group)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def OrbitalFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Orbital']]:
Expand All @@ -180,10 +188,10 @@ def OrbitalFactory(entry_point_name: str, load: bool = True) -> Optional[Union[E
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Orbital,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Orbital)):
if is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, Orbital)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def ParserFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Parser']]:
Expand All @@ -200,10 +208,10 @@ def ParserFactory(entry_point_name: str, load: bool = True) -> Optional[Union[En
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Parser,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Parser)):
if is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, Parser)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def SchedulerFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Scheduler']]:
Expand All @@ -220,10 +228,10 @@ def SchedulerFactory(entry_point_name: str, load: bool = True) -> Optional[Union
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Scheduler,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Scheduler)):
if is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, Scheduler)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def TransportFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, Type['Transport']]]:
Expand All @@ -239,10 +247,10 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Optional[Union
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Transport,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Transport)):
if is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, Transport)):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)


def WorkflowFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'WorkChain', Callable]]:
Expand All @@ -261,9 +269,9 @@ def WorkflowFactory(entry_point_name: str, load: bool = True) -> Optional[Union[
valid_classes = (WorkChain, workfunction)

if (
isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, WorkChain)) or
(is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode)
is_entry_point(entry_point) or (isclass(entry_point) and issubclass(entry_point, WorkChain)) or
(is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode) # type: ignore
):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes, entry_point)
19 changes: 18 additions & 1 deletion aiida/plugins/utils.py
Expand Up @@ -8,8 +8,10 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Utilities dealing with plugins and entry points."""

from importlib import import_module
from typing import Any

from importlib_metadata import EntryPoint

from aiida.common import AIIDA_LOGGER

Expand All @@ -20,6 +22,21 @@
KEY_VERSION_PLUGIN = 'plugin' # The version of the plugin top level module, e.g. `aiida-quantumespresso`


def is_entry_point(obj: Any) -> bool:
"""Return True if the given object is an `EntryPoint` instance."""
if isinstance(obj, EntryPoint):
return True
# if setuptools is installed it can override the entry point class
# see https://github.com/pypa/setuptools/pull/3086
try:
from setuptools._vendor.importlib_metadata import EntryPoint as StEntryPoint
if isinstance(obj, StEntryPoint):
return True
except ImportError:
pass
return False


class PluginVersionProvider:
"""Utility class that determines version information about a given plugin resource."""

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -367,6 +367,7 @@ module = [
'pytz.*',
'ruamel.*',
'scipy.*',
'setuptools.*',
'tabulate.*',
'tqdm.*',
'wrapt.*',
Expand Down
5 changes: 3 additions & 2 deletions tests/plugins/test_entry_point.py
Expand Up @@ -11,7 +11,8 @@
import pytest

from aiida.common.warnings import AiidaDeprecationWarning
from aiida.plugins.entry_point import EntryPoint, get_entry_point, validate_registered_entry_points
from aiida.plugins.entry_point import get_entry_point, validate_registered_entry_points
from aiida.plugins.utils import is_entry_point


def test_validate_registered_entry_points():
Expand Down Expand Up @@ -44,4 +45,4 @@ def test_get_entry_point_deprecated(group, name):
with pytest.warns(AiidaDeprecationWarning, match=warning):
entry_point = get_entry_point(group, name)

assert isinstance(entry_point, EntryPoint)
assert is_entry_point(entry_point)

0 comments on commit d4e80e2

Please sign in to comment.