Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only load distribution of a name once #25296

Merged
merged 1 commit into from Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/plugins_manager.py
Expand Up @@ -113,7 +113,7 @@ class EntryPointSource(AirflowPluginSource):
"""Class used to define Plugins loaded from entrypoint."""

def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution):
self.dist = dist.metadata['name']
self.dist = dist.metadata['Name']
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python pacakge metadata keys are case-insensitive, I use Name (the case used by the standard) everywhere so tests are easier to write.

self.version = dist.version
self.entrypoint = str(entrypoint)

Expand Down
22 changes: 16 additions & 6 deletions airflow/utils/entry_points.py
Expand Up @@ -15,23 +15,33 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Iterator

from packaging.utils import canonicalize_name

try:
import importlib_metadata
import importlib_metadata as metadata
except ImportError:
from importlib import metadata as importlib_metadata # type: ignore
from importlib import metadata # type: ignore[no-redef]


def entry_points_with_dist(group: str):
"""
Return EntryPoint objects of the given group, along with the distribution information.
def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]:
"""Retrieve entry points of the given group.

This is like the ``entry_points()`` function from importlib.metadata,
except it also returns the distribution the entry_point was loaded from.

:param group: Filter results to only this entrypoint group
:return: Generator of (EntryPoint, Distribution) objects for the specified groups
"""
for dist in importlib_metadata.distributions():
loaded: set[str] = set()
for dist in metadata.distributions():
key = canonicalize_name(dist.metadata["Name"])
if key in loaded:
continue
loaded.add(key)
for e in dist.entry_points:
if e.group != group:
continue
Expand Down
3 changes: 2 additions & 1 deletion tests/plugins/test_plugins_manager.py
Expand Up @@ -298,6 +298,7 @@ def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog):
from airflow.plugins_manager import import_errors, load_entrypoint_plugins

mock_dist = mock.Mock()
mock_dist.metadata = {"Name": "test-dist"}

mock_entrypoint = mock.Mock()
mock_entrypoint.name = 'test-entrypoint'
Expand Down Expand Up @@ -387,7 +388,7 @@ def test_should_return_correct_source_details(self):
mock_entrypoint.module = 'module_name_plugin'

mock_dist = mock.Mock()
mock_dist.metadata = {'name': 'test-entrypoint-plugin'}
mock_dist.metadata = {'Name': 'test-entrypoint-plugin'}
mock_dist.version = '1.0.0'
mock_dist.entry_points = [mock_entrypoint]

Expand Down
49 changes: 49 additions & 0 deletions tests/utils/test_entry_points.py
@@ -0,0 +1,49 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from typing import Iterable
from unittest import mock

from airflow.utils.entry_points import entry_points_with_dist, metadata


class MockDistribution:
def __init__(self, name: str, entry_points: Iterable[metadata.EntryPoint]) -> None:
self.metadata = {"Name": name}
self.entry_points = entry_points


class MockMetadata:
def distributions(self):
return [
MockDistribution(
"dist1",
[metadata.EntryPoint("a", "b", "group_x"), metadata.EntryPoint("c", "d", "group_y")],
),
MockDistribution("Dist2", [metadata.EntryPoint("e", "f", "group_x")]),
MockDistribution("dist2", [metadata.EntryPoint("g", "h", "group_x")]), # Duplicated name.
]


@mock.patch("airflow.utils.entry_points.metadata", MockMetadata())
def test_entry_points_with_dist():
entries = list(entry_points_with_dist("group_x"))

# The second "dist2" is ignored. Only "group_x" entries are loaded.
assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"]
assert [ep.name for ep, _ in entries] == ["a", "e"]
2 changes: 1 addition & 1 deletion tests/www/views/test_views.py
Expand Up @@ -88,7 +88,7 @@ def test_plugin_should_list_entrypoint_on_page_with_details(admin_client):
mock_plugin = AirflowPlugin()
mock_plugin.name = "test_plugin"
mock_plugin.source = EntryPointSource(
mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'})
mock.Mock(), mock.Mock(version='1.0.0', metadata={'Name': 'test-entrypoint-testpluginview'})
)
with mock_plugin_manager(plugins=[mock_plugin]):
resp = admin_client.get('/plugin')
Expand Down