Skip to content

Commit

Permalink
Catch only on ModuleNotFound error and simple reraise with warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Apr 30, 2024
1 parent eb41dbf commit 680e0a2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
15 changes: 11 additions & 4 deletions airflow/operators/python.py
Expand Up @@ -58,6 +58,8 @@
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from pendulum.datetime import DateTime

Expand Down Expand Up @@ -352,16 +354,21 @@ def _load_pickle():
def _load_dill():
try:
import dill
except ImportError:
raise AirflowException("Unable to import 'dill' make sure that it installed.")
except ModuleNotFoundError:
log.error("Unable to import `dill` module. Please please make sure that it installed.")
raise
return dill


def _load_cloudpickle():
try:
import cloudpickle
except ImportError:
raise AirflowException("Unable to import 'cloudpickle' make sure that it installed.")
except ModuleNotFoundError:
log.error(
"Unable to import `cloudpickle` module. "
"Please install it with: pip install 'apache-airflow[cloudpickle]'"
)
raise
return cloudpickle


Expand Down
9 changes: 5 additions & 4 deletions tests/operators/test_python.py
Expand Up @@ -921,13 +921,14 @@ def f(exit_code):
),
],
)
def test_advanced_serializer_not_installed(self, serializer):
def test_advanced_serializer_not_installed(self, serializer, caplog):
"""Test case for check raising an error if dill/cloudpickle is not installed."""

def f(a): ...

with pytest.raises(AirflowException, match=f"Unable to import '{serializer}'"):
self.run_as_task(f, op_args=[42], serializer=serializer, system_site_packages=False)
with pytest.raises(ModuleNotFoundError):
self.run_as_task(f, op_args=[42], serializer=serializer)
assert f"Unable to import `{serializer}` module." in caplog.text


venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path")
Expand Down Expand Up @@ -1319,7 +1320,7 @@ def f(
[
pytest.param("pickle", id="pickle"),
pytest.param("dill", marks=DILL_MARKER, id="dill"),
pytest.param("cloudpickle", id="cloudpickle"),
pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"),
pytest.param(None, id="default"),
],
)
Expand Down

0 comments on commit 680e0a2

Please sign in to comment.