Skip to content

Commit

Permalink
api: Add get_params.
Browse files Browse the repository at this point in the history
Closes #6507

Uses `repo.params.show` with a custom error_handler and post-processes the outputs into a more user-friendly structure.

Extends `repo.params.show` to accept a `stages` argument to cover the "params of current stage" use case.
  • Loading branch information
daavoo committed Jun 4, 2022
1 parent fbe95c2 commit 0da8293
Show file tree
Hide file tree
Showing 3 changed files with 345 additions and 8 deletions.
173 changes: 172 additions & 1 deletion dvc/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from collections import Counter
from contextlib import _GeneratorContextManager as GCM
from typing import Dict, Iterable, Optional, Union

from funcy import reraise
from funcy import first, reraise

from dvc.exceptions import OutputNotFoundError, PathMissingError
from dvc.repo import Repo
Expand Down Expand Up @@ -92,6 +94,175 @@ def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None):
return fd.read()


def get_params(
    *targets: str,
    repo: Optional[str] = None,
    revs: Optional[Union[str, Iterable[str]]] = None,
    deps: bool = False,
    stages: Optional[Union[str, Iterable[str]]] = None,
) -> Dict:
    """Get parameters tracked in `repo`.

    Args:
        *targets: Names of the parameter files to retrieve parameters from.
            If no `targets` are provided, all parameter files will be used.
        repo (str, optional): location of the DVC project.
            Defaults to the current project (found by walking up from the
            current working directory tree).
            It can be a URL or a file system path.
            Both HTTP and SSH protocols are supported for online Git repos
            (e.g. [user@]server:project.git).
        revs (Union[str, Iterable[str]], optional): Any `Git revision`_ such
            as a branch or tag name, a commit hash or a dvc experiment name.
            A single string is treated as a one-element list.
            Defaults to HEAD.
            If `repo` is not a Git repo, this option is ignored.
        deps (bool, optional): Whether to retrieve only parameters that are
            stage dependencies or not.
            Defaults to `False`.
        stages (Union[str, Iterable[str]], optional): Name of the stages to
            retrieve parameters from.
            A single string is treated as a one-element list; any other
            iterable (including one-shot iterators) is copied into a list.
            Defaults to `None`.
            If no stages are provided, all parameters from all stages will be
            retrieved.

    Returns:
        Dict: Processed parameters.
            If no `revs` were provided, parameters will be available as
            top-level keys.
            If `revs` were provided, the top-level keys will be those revs.
            Within a revision, a key that appears in more than one parameter
            file is reported as ``"<file name>:<key>"`` so entries from
            different files do not overwrite each other.
            See Examples below.

    Raises:
        Exception: errors handed to the `onerror` hook of
            `repo.params.show` are re-raised instead of being collected.

    Examples:

        - No args.

        Will use the current project as `repo` and retrieve all parameters,
        for all stages, for the current revision.

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.get_params()
        >>> print(json.dumps(params, indent=4))
        {
            "prepare": {
                "split": 0.2,
                "seed": 20170428
            },
            "featurize": {
                "max_features": 10000,
                "ngrams": 2
            },
            "train": {
                "seed": 20170428,
                "n_est": 50,
                "min_split": 0.01
            }
        }

        - Filter with `stages`.

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.get_params(stages="prepare")
        >>> print(json.dumps(params, indent=4))
        {
            "prepare": {
                "split": 0.2,
                "seed": 20170428
            }
        }

        - Git URL as `repo`.

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.get_params(
        ...     repo="https://github.com/iterative/demo-fashion-mnist")
        >>> print(json.dumps(params, indent=4))
        {
            "train": {
                "batch_size": 128,
                "hidden_units": 64,
                "dropout": 0.4,
                "num_epochs": 10,
                "lr": 0.001,
                "conv_activation": "relu"
            }
        }

        - Multiple `revs`.

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.get_params(
        ...     repo="https://github.com/iterative/demo-fashion-mnist",
        ...     revs=["main", "low-lr-experiment"])
        >>> print(json.dumps(params, indent=4))
        {
            "main": {
                "train": {
                    "batch_size": 128,
                    "hidden_units": 64,
                    "dropout": 0.4,
                    "num_epochs": 10,
                    "lr": 0.001,
                    "conv_activation": "relu"
                }
            },
            "low-lr-experiment": {
                "train": {
                    "batch_size": 128,
                    "hidden_units": 64,
                    "dropout": 0.4,
                    "num_epochs": 10,
                    "lr": 0.001,
                    "conv_activation": "relu"
                }
            }
        }
    """
    if isinstance(revs, str):
        revs = [revs]

    if isinstance(stages, str):
        stages = [stages]
    elif stages is not None:
        # `stages` is used downstream for repeated `name in stages` checks.
        # A one-shot iterable (e.g. a generator) would be consumed by the
        # first check and is truthy even when empty, so copy it into a list.
        stages = list(stages)

    def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs):
        # Surface errors to the API caller instead of collecting them.
        raise exception

    def _postprocess(params):
        # Flatten the `{rev: {"data": {file: {"data": {...}}}}}` structure
        # into `{rev: {key: value}}`.
        processed = {}
        for rev, rev_data in params.items():
            files = rev_data["data"]

            # Count in how many files each top-level key shows up.
            counts = Counter(
                key for file_data in files.values() for key in file_data["data"]
            )

            merged = {}
            for file_name, file_data in files.items():
                for key, value in file_data["data"].items():
                    # Keys shared by several files get a `file:` prefix.
                    unique = key if counts[key] == 1 else f"{file_name}:{key}"
                    merged[unique] = value
            processed[rev] = merged

        # A leading "" key is the no-`revs` case: unwrap it so parameters
        # become the top-level keys.
        if first(processed) == "":
            return processed[""]

        return processed

    with Repo.open(repo) as _repo:
        params = _repo.params.show(
            revs=revs,
            targets=targets,
            deps=deps,
            onerror=_onerror_raise,
            stages=stages,
        )

    return _postprocess(params)


def make_checkpoint():
"""
Signal DVC to create a checkpoint experiment.
Expand Down
46 changes: 39 additions & 7 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
import os
from collections import defaultdict
from copy import copy
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
)

from scmrepo.exceptions import SCMError

Expand Down Expand Up @@ -62,20 +70,27 @@ def _read_params(
params_fs_paths,
deps=False,
onerror: Optional[Callable] = None,
stages: Optional[Iterable[str]] = None,
):
res: Dict[str, Dict] = defaultdict(lambda: defaultdict(dict))
fs_paths = copy(params_fs_paths)

if deps:
for param in params:
if stages and param.stage.name not in stages:
continue
params_dict = error_handler(param.read_params)(
onerror=onerror, flatten=False
)
if params_dict:
name = os.sep.join(repo.fs.path.relparts(param.fs_path))
res[name]["data"].update(params_dict["data"])
else:
fs_paths += [param.fs_path for param in params]
fs_paths += [
param.fs_path
for param in params
if not stages or param.stage.name in stages
]

for fs_path in fs_paths:
from_path = _read_fs_path(repo.fs, fs_path, onerror=onerror)
Expand All @@ -86,11 +101,13 @@ def _read_params(
return res


def _collect_vars(repo, params) -> Dict:
def _collect_vars(repo, params, stages=None) -> Dict:
vars_params: Dict[str, Dict] = defaultdict(dict)

for stage in repo.index.stages:
if isinstance(stage, PipelineStage) and stage.tracked_vars:
if stages and stage.name not in stages:
continue
for file, vars_ in stage.tracked_vars.items():
# `params` file are shown regardless of `tracked` or not
# to reduce noise and duplication, they are skipped
Expand All @@ -103,14 +120,26 @@ def _collect_vars(repo, params) -> Dict:


@locked
def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None):
def show(
repo,
revs=None,
targets=None,
deps=False,
onerror: Callable = None,
stages=None,
):
if onerror is None:
onerror = onerror_collect
res = {}

for branch in repo.brancher(revs=revs):
params = error_handler(_gather_params)(
repo=repo, rev=branch, targets=targets, deps=deps, onerror=onerror
repo=repo,
rev=branch,
targets=targets,
deps=deps,
onerror=onerror,
stages=stages,
)

if params:
Expand All @@ -137,16 +166,19 @@ def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None):
return res


def _gather_params(repo, rev, targets=None, deps=False, onerror=None):
def _gather_params(
repo, rev, targets=None, deps=False, onerror=None, stages=None
):
param_outs, params_fs_paths = _collect_configs(repo, rev, targets=targets)
params = _read_params(
repo,
params=param_outs,
params_fs_paths=params_fs_paths,
deps=deps,
onerror=onerror,
stages=stages,
)
vars_params = _collect_vars(repo, params)
vars_params = _collect_vars(repo, params, stages=stages)

# NOTE: only those that are not added as a ParamDependency are
# included so we don't need to recursively merge them yet.
Expand Down

0 comments on commit 0da8293

Please sign in to comment.