Skip to content

Commit

Permalink
refactor: extract load_path function from common code using LOADERS (#…
Browse files Browse the repository at this point in the history
…7983)

Following up on #7965 (comment)

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>
  • Loading branch information
alexmojaki and daavoo committed Aug 9, 2022
1 parent f7fbac1 commit 2ca1d6c
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 19 deletions.
7 changes: 2 additions & 5 deletions dvc/dependency/param.py
Expand Up @@ -8,7 +8,7 @@
from voluptuous import Any

from dvc.exceptions import DvcException
from dvc.utils.serialize import LOADERS, ParseError
from dvc.utils.serialize import ParseError, load_path
from dvc_data.hashfile.hash_info import HashInfo

from .base import Dependency
Expand Down Expand Up @@ -139,12 +139,9 @@ def validate_filepath(self):
)

def read_file(self):
_, ext = os.path.splitext(self.fs_path)
loader = LOADERS[ext]

self.validate_filepath()
try:
return loader(self.fs_path, fs=self.repo.fs)
return load_path(self.fs_path, self.repo.fs)
except ParseError as exc:
raise BadParamFileError(
f"Unable to read parameters from '{self}'"
Expand Down
8 changes: 2 additions & 6 deletions dvc/parsing/context.py
@@ -1,5 +1,4 @@
import logging
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, MutableMapping, MutableSequence, Sequence
Expand Down Expand Up @@ -357,17 +356,14 @@ def select(
def load_from(
cls, fs, path: str, select_keys: List[str] = None
) -> "Context":
from dvc.utils.serialize import LOADERS
from dvc.utils.serialize import load_path

if not fs.exists(path):
raise ParamsLoadError(f"'{path}' does not exist")
if fs.isdir(path):
raise ParamsLoadError(f"'{path}' is a directory")

_, ext = os.path.splitext(path)
loader = LOADERS[ext]

data = loader(path, fs=fs)
data = load_path(path, fs)
if not isinstance(data, Mapping):
typ = type(data).__name__
raise ParamsLoadError(
Expand Down
6 changes: 2 additions & 4 deletions dvc/repo/metrics/show.py
Expand Up @@ -12,7 +12,7 @@
from dvc.scm import NoSCMError
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils.collections import ensure_list
from dvc.utils.serialize import LOADERS
from dvc.utils.serialize import load_path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,9 +71,7 @@ def _extract_metrics(metrics, path, rev):

@error_handler
def _read_metric(path, fs, rev, **kwargs):
suffix = fs.path.suffix(path).lower()
loader = LOADERS[suffix]
val = loader(path, fs=fs)
val = load_path(path, fs)
val = _extract_metrics(val, path, rev)
return val or {}

Expand Down
6 changes: 2 additions & 4 deletions dvc/repo/params/show.py
Expand Up @@ -22,7 +22,7 @@
from dvc.ui import ui
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils.collections import ensure_list
from dvc.utils.serialize import LOADERS
from dvc.utils.serialize import load_path

if TYPE_CHECKING:
from dvc.output import Output
Expand Down Expand Up @@ -62,9 +62,7 @@ def _collect_configs(

@error_handler
def _read_fs_path(fs, fs_path, **kwargs):
suffix = fs.path.suffix(fs_path).lower()
loader = LOADERS[suffix]
return loader(fs_path, fs=fs)
return load_path(fs_path, fs)


def _read_params(
Expand Down
7 changes: 7 additions & 0 deletions dvc/utils/serialize/__init__.py
Expand Up @@ -14,6 +14,13 @@
{".toml": load_toml, ".json": load_json, ".py": load_py} # noqa: F405
)


def load_path(fs_path, fs):
suffix = fs.path.suffix(fs_path).lower()
loader = LOADERS[suffix]
return loader(fs_path, fs=fs)


DUMPERS: DefaultDict[str, DumperFn] = defaultdict( # noqa: F405
lambda: dump_yaml # noqa: F405
)
Expand Down

0 comments on commit 2ca1d6c

Please sign in to comment.