diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index bd064bd1bf..eb4faf176f 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -8,7 +8,7 @@ from voluptuous import Any from dvc.exceptions import DvcException -from dvc.utils.serialize import LOADERS, ParseError +from dvc.utils.serialize import ParseError, load_path from dvc_data.hashfile.hash_info import HashInfo from .base import Dependency @@ -139,12 +139,9 @@ def validate_filepath(self): ) def read_file(self): - _, ext = os.path.splitext(self.fs_path) - loader = LOADERS[ext] - self.validate_filepath() try: - return loader(self.fs_path, fs=self.repo.fs) + return load_path(self.fs_path, self.repo.fs) except ParseError as exc: raise BadParamFileError( f"Unable to read parameters from '{self}'" diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index 8c6510c9d9..b4315d7365 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -1,5 +1,4 @@ import logging -import os from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Mapping, MutableMapping, MutableSequence, Sequence @@ -357,17 +356,14 @@ def select( def load_from( cls, fs, path: str, select_keys: List[str] = None ) -> "Context": - from dvc.utils.serialize import LOADERS + from dvc.utils.serialize import load_path if not fs.exists(path): raise ParamsLoadError(f"'{path}' does not exist") if fs.isdir(path): raise ParamsLoadError(f"'{path}' is a directory") - _, ext = os.path.splitext(path) - loader = LOADERS[ext] - - data = loader(path, fs=fs) + data = load_path(path, fs) if not isinstance(data, Mapping): typ = type(data).__name__ raise ParamsLoadError( diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index a829413e53..a6eb07933f 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -12,7 +12,7 @@ from dvc.scm import NoSCMError from dvc.utils import error_handler, errored_revisions, onerror_collect from dvc.utils.collections import ensure_list -from dvc.utils.serialize import LOADERS +from dvc.utils.serialize import load_path logger = logging.getLogger(__name__) @@ -71,9 +71,7 @@ def _extract_metrics(metrics, path, rev): @error_handler def _read_metric(path, fs, rev, **kwargs): - suffix = fs.path.suffix(path).lower() - loader = LOADERS[suffix] - val = loader(path, fs=fs) + val = load_path(path, fs) val = _extract_metrics(val, path, rev) return val or {} diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 502867a20a..5a79f54a44 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -22,7 +22,7 @@ from dvc.ui import ui from dvc.utils import error_handler, errored_revisions, onerror_collect from dvc.utils.collections import ensure_list -from dvc.utils.serialize import LOADERS +from dvc.utils.serialize import load_path if TYPE_CHECKING: from dvc.output import Output @@ -62,9 +62,7 @@ def _collect_configs( @error_handler def _read_fs_path(fs, fs_path, **kwargs): - suffix = fs.path.suffix(fs_path).lower() - loader = LOADERS[suffix] - return loader(fs_path, fs=fs) + return load_path(fs_path, fs) def _read_params( diff --git a/dvc/utils/serialize/__init__.py b/dvc/utils/serialize/__init__.py index 67eae15aa3..82fc51d213 100644 --- a/dvc/utils/serialize/__init__.py +++ b/dvc/utils/serialize/__init__.py @@ -14,6 +14,13 @@ {".toml": load_toml, ".json": load_json, ".py": load_py} # noqa: F405 ) + +def load_path(fs_path, fs): + suffix = fs.path.suffix(fs_path).lower() + loader = LOADERS[suffix] + return loader(fs_path, fs=fs) + + DUMPERS: DefaultDict[str, DumperFn] = defaultdict( # noqa: F405 lambda: dump_yaml # noqa: F405 )