From f48fe2a5e640716d59e211dcac26899d6f8dd163 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 3 Aug 2022 21:01:43 +0800 Subject: [PATCH] Only load distribution of a name once (#25296) (cherry picked from commit c30dc5e64d7229cbf8e9fbe84cfa790dfef5fb8c) --- airflow/plugins_manager.py | 2 +- airflow/utils/entry_points.py | 22 ++++++++---- tests/plugins/test_plugins_manager.py | 3 +- tests/utils/test_entry_points.py | 49 +++++++++++++++++++++++++++ tests/www/views/test_views.py | 2 +- 5 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 tests/utils/test_entry_points.py diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 82e295fa1906c..431d5fe55afe1 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -113,7 +113,7 @@ class EntryPointSource(AirflowPluginSource): """Class used to define Plugins loaded from entrypoint.""" def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution): - self.dist = dist.metadata['name'] + self.dist = dist.metadata['Name'] self.version = dist.version self.entrypoint = str(entrypoint) diff --git a/airflow/utils/entry_points.py b/airflow/utils/entry_points.py index 668ed9b9941fb..483f9efe776cb 100644 --- a/airflow/utils/entry_points.py +++ b/airflow/utils/entry_points.py @@ -15,15 +15,20 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from typing import Iterator + +from packaging.utils import canonicalize_name + try: - import importlib_metadata + import importlib_metadata as metadata except ImportError: - from importlib import metadata as importlib_metadata # type: ignore + from importlib import metadata # type: ignore[no-redef] -def entry_points_with_dist(group: str): - """ - Return EntryPoint objects of the given group, along with the distribution information. +def entry_points_with_dist(group: str) -> Iterator[tuple[metadata.EntryPoint, metadata.Distribution]]: + """Retrieve entry points of the given group. This is like the ``entry_points()`` function from importlib.metadata, except it also returns the distribution the entry_point was loaded from. @@ -31,7 +36,12 @@ def entry_points_with_dist(group: str): :param group: Filter results to only this entrypoint group :return: Generator of (EntryPoint, Distribution) objects for the specified groups """ - for dist in importlib_metadata.distributions(): + loaded: set[str] = set() + for dist in metadata.distributions(): + key = canonicalize_name(dist.metadata["Name"]) + if key in loaded: + continue + loaded.add(key) for e in dist.entry_points: if e.group != group: continue diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index f97a811c911ec..c46b6e83f2f49 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -298,6 +298,7 @@ def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog): from airflow.plugins_manager import import_errors, load_entrypoint_plugins mock_dist = mock.Mock() + mock_dist.metadata = {"Name": "test-dist"} mock_entrypoint = mock.Mock() mock_entrypoint.name = 'test-entrypoint' @@ -387,7 +388,7 @@ def test_should_return_correct_source_details(self): mock_entrypoint.module = 'module_name_plugin' mock_dist = mock.Mock() - mock_dist.metadata = {'name': 'test-entrypoint-plugin'} + mock_dist.metadata = {'Name': 'test-entrypoint-plugin'} mock_dist.version = '1.0.0' mock_dist.entry_points = [mock_entrypoint] diff --git a/tests/utils/test_entry_points.py b/tests/utils/test_entry_points.py new file mode 100644 index 0000000000000..65f688647e123 --- /dev/null +++ b/tests/utils/test_entry_points.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Iterable +from unittest import mock + +from airflow.utils.entry_points import entry_points_with_dist, metadata + + +class MockDistribution: + def __init__(self, name: str, entry_points: Iterable[metadata.EntryPoint]) -> None: + self.metadata = {"Name": name} + self.entry_points = entry_points + + +class MockMetadata: + def distributions(self): + return [ + MockDistribution( + "dist1", + [metadata.EntryPoint("a", "b", "group_x"), metadata.EntryPoint("c", "d", "group_y")], + ), + MockDistribution("Dist2", [metadata.EntryPoint("e", "f", "group_x")]), + MockDistribution("dist2", [metadata.EntryPoint("g", "h", "group_x")]), # Duplicated name. + ] + + +@mock.patch("airflow.utils.entry_points.metadata", MockMetadata()) +def test_entry_points_with_dist(): + entries = list(entry_points_with_dist("group_x")) + + # The second "dist2" is ignored. Only "group_x" entries are loaded. + assert [dist.metadata["Name"] for _, dist in entries] == ["dist1", "Dist2"] + assert [ep.name for ep, _ in entries] == ["a", "e"] diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py index fa79e145cba6c..b6d21a3f26357 100644 --- a/tests/www/views/test_views.py +++ b/tests/www/views/test_views.py @@ -87,7 +87,7 @@ def test_plugin_should_list_entrypoint_on_page_with_details(admin_client): mock_plugin = AirflowPlugin() mock_plugin.name = "test_plugin" mock_plugin.source = EntryPointSource( - mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'}) + mock.Mock(), mock.Mock(version='1.0.0', metadata={'Name': 'test-entrypoint-testpluginview'}) ) with mock_plugin_manager(plugins=[mock_plugin]): resp = admin_client.get('/plugin')