Skip to content

Commit

Permalink
api: Add params_show.
Browse files Browse the repository at this point in the history
Closes #6507

Uses `repo.params.show` with a custom error handler and post-processes the output into a more user-friendly structure.

Extend `repo.params.show` to accept `stages` argument to cover the "params of current stage" use case.
  • Loading branch information
daavoo committed Jun 20, 2022
1 parent 7e89223 commit 057c9d2
Show file tree
Hide file tree
Showing 3 changed files with 428 additions and 10 deletions.
266 changes: 264 additions & 2 deletions dvc/api.py
@@ -1,8 +1,9 @@
import os
from collections import Counter
from contextlib import _GeneratorContextManager as GCM
from typing import Optional
from typing import Dict, Iterable, Optional, Union

from funcy import reraise
from funcy import first, reraise

from dvc.exceptions import OutputNotFoundError, PathMissingError
from dvc.repo import Repo
Expand Down Expand Up @@ -214,6 +215,267 @@ def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None):
return fd.read()


def params_show(
    *targets: str,
    repo: Optional[str] = None,
    stages: Optional[Union[str, Iterable[str]]] = None,
    rev: Optional[str] = None,
    deps: bool = False,
) -> Dict:
    """Return the parameters tracked in `repo` as a single dictionary.

    Called without arguments, every parameter from every tracked parameter
    file is retrieved for the current working tree. The options below
    narrow down what is retrieved.

    Parameter names that occur in exactly one file are returned as-is; a
    name that occurs in several files is returned once per file, prefixed
    as ``"<file>:<name>"`` (see the "multiple targets" example).

    Errors reported through the `onerror` hook of `repo.params.show` are
    re-raised instead of being collected into the result.

    Args:
        *targets (str, optional): Names of the parameter files to retrieve
            params from. For example, "params.py, myparams.toml".
            If no `targets` are provided, all parameter files tracked in
            `dvc.yaml` will be used.
            Note that targets don't necessarily have to be defined in
            `dvc.yaml`.
        repo (str, optional): Location of the DVC repository.
            Defaults to the current project (found by walking up from the
            current working directory tree).
            It can be a URL or a file system path.
            Both HTTP and SSH protocols are supported for online Git repos
            (e.g. [user@]server:project.git).
        stages (Union[str, Iterable[str]], optional): Name or names of the
            stages to retrieve parameters from.
            Defaults to `None`.
            If `None`, all parameters from all stages will be retrieved.
        rev (str, optional): Name of the `Git revision`_ to retrieve
            parameters from.
            Defaults to `None`.
            An example of git revision can be a branch or tag name, a
            commit hash or a dvc experiment name.
            If `repo` is not a Git repo, this option is ignored.
            If `None`, the current working tree will be used.
        deps (bool, optional): Whether to retrieve only parameters that
            are stage dependencies or not.
            Defaults to `False`.

    Returns:
        Dict: See Examples below.

    Examples:

        - No arguments.

        Working on https://github.com/iterative/example-get-started

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.params_show()
        >>> print(json.dumps(params, indent=4))
        {
            "prepare": {
                "split": 0.2,
                "seed": 20170428
            },
            "featurize": {
                "max_features": 200,
                "ngrams": 2
            },
            "train": {
                "seed": 20170428,
                "n_est": 50,
                "min_split": 0.01
            }
        }

        ---

        - Filtering with `stages`.

        Working on https://github.com/iterative/example-get-started

        `stages` can be a single string:

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.params_show(stages="prepare")
        >>> print(json.dumps(params, indent=4))
        {
            "prepare": {
                "split": 0.2,
                "seed": 20170428
            }
        }

        Or an iterable of strings:

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.params_show(stages=["prepare", "train"])
        >>> print(json.dumps(params, indent=4))
        {
            "prepare": {
                "split": 0.2,
                "seed": 20170428
            },
            "train": {
                "seed": 20170428,
                "n_est": 50,
                "min_split": 0.01
            }
        }

        ---

        - Using `rev`.

        Working on https://github.com/iterative/example-get-started

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.params_show(rev="tune-hyperparams")
        >>> print(json.dumps(params, indent=4))
        {
            "prepare": {
                "split": 0.2,
                "seed": 20170428
            },
            "featurize": {
                "max_features": 200,
                "ngrams": 2
            },
            "train": {
                "seed": 20170428,
                "n_est": 100,
                "min_split": 8
            }
        }

        ---

        - Using `targets`.

        Working on `multi-params-files` folder of
        https://github.com/iterative/pipeline-conifguration

        You can pass a single target:

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.params_show("params.yaml")
        >>> print(json.dumps(params, indent=4))
        {
            "run_mode": "prod",
            "configs": {
                "dev": "configs/params_dev.yaml",
                "test": "configs/params_test.yaml",
                "prod": "configs/params_prod.yaml"
            },
            "evaluate": {
                "dataset": "micro",
                "size": 5000,
                "metrics": ["f1", "roc-auc"],
                "metrics_file": "reports/metrics.json",
                "plots_cm": "reports/plot_confusion_matrix.png"
            }
        }

        Or multiple targets:

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.params_show(
        ...     "configs/params_dev.yaml", "configs/params_prod.yaml")
        >>> print(json.dumps(params, indent=4))
        {
            "configs/params_prod.yaml:run_mode": "prod",
            "configs/params_prod.yaml:config_file": "configs/params_prod.yaml",
            "configs/params_prod.yaml:data_load": {
                "dataset": "large",
                "sampling": {
                    "enable": true,
                    "size": 50000
                }
            },
            "configs/params_prod.yaml:train": {
                "epochs": 1000
            },
            "configs/params_dev.yaml:run_mode": "dev",
            "configs/params_dev.yaml:config_file": "configs/params_dev.yaml",
            "configs/params_dev.yaml:data_load": {
                "dataset": "development",
                "sampling": {
                    "enable": true,
                    "size": 1000
                }
            },
            "configs/params_dev.yaml:train": {
                "epochs": 10
            }
        }

        ---

        - Git URL as `repo`.

        >>> import json
        >>> import dvc.api
        >>> params = dvc.api.params_show(
        ...     repo="https://github.com/iterative/demo-fashion-mnist")
        >>> print(json.dumps(params, indent=4))
        {
            "train": {
                "batch_size": 128,
                "hidden_units": 64,
                "dropout": 0.4,
                "num_epochs": 10,
                "lr": 0.001,
                "conv_activation": "relu"
            }
        }

    .. _Git revision:
        https://git-scm.com/docs/revisions
    """
    # A lone stage name is normalised to a one-element list so that the
    # downstream filtering always deals with a collection of names.
    if isinstance(stages, str):
        stages = [stages]

    def _reraise(result: Dict, exception: Exception, *args, **kwargs):
        # Error hook handed to `repo.params.show`: fail loudly rather than
        # recording the failure in the result.
        raise exception

    def _flatten(raw):
        # Collapse the {rev: {"data": {file: {"data": {...}}}}} layout into
        # one flat {param: value} mapping for a single revision.
        by_rev = {}
        for rev_name, rev_payload in raw.items():
            per_file = rev_payload["data"]

            # How many files define each top-level key; keys defined more
            # than once need a "<file>:" prefix to stay distinguishable.
            occurrences = Counter(
                key for payload in per_file.values() for key in payload["data"]
            )

            merged = {}
            for file_name, payload in per_file.items():
                for key, value in payload["data"].items():
                    if occurrences[key] == 1:
                        label = key
                    else:
                        label = f"{file_name}:{key}"
                    merged[label] = value
            by_rev[rev_name] = merged

        # The workspace entry is never the one returned when present.
        by_rev.pop("workspace", None)

        # NOTE(review): if no revision is left here, `first` yields None and
        # this lookup raises KeyError -- confirm whether callers can hit it.
        return by_rev[first(by_rev)]

    requested_revs = None if rev is None else [rev]
    with Repo.open(repo) as _repo:
        raw_params = _repo.params.show(
            revs=requested_revs,
            targets=targets,
            deps=deps,
            onerror=_reraise,
            stages=stages,
        )

    return _flatten(raw_params)


def make_checkpoint():
"""
Signal DVC to create a checkpoint experiment.
Expand Down
46 changes: 38 additions & 8 deletions dvc/repo/params/show.py
Expand Up @@ -2,7 +2,15 @@
import os
from collections import defaultdict
from copy import copy
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
)

from scmrepo.exceptions import SCMError

Expand Down Expand Up @@ -63,18 +71,23 @@ def _read_params(
params_fs_paths,
deps=False,
onerror: Optional[Callable] = None,
stages: Optional[Iterable[str]] = None,
):
res: Dict[str, Dict] = defaultdict(lambda: defaultdict(dict))
fs_paths = copy(params_fs_paths)

if deps:
if deps or stages:
for param in params:
if stages and param.stage.name not in stages:
continue
params_dict = error_handler(param.read_params)(
onerror=onerror, flatten=False
)
if params_dict:
name = os.sep.join(repo.fs.path.relparts(param.fs_path))
res[name]["data"].update(params_dict["data"])
if name in fs_paths:
fs_paths.remove(name)
else:
fs_paths += [param.fs_path for param in params]

Expand All @@ -87,11 +100,13 @@ def _read_params(
return res


def _collect_vars(repo, params) -> Dict:
def _collect_vars(repo, params, stages=None) -> Dict:
vars_params: Dict[str, Dict] = defaultdict(dict)

for stage in repo.index.stages:
if isinstance(stage, PipelineStage) and stage.tracked_vars:
if stages and stage.name not in stages:
continue
for file, vars_ in stage.tracked_vars.items():
# `params` file are shown regardless of `tracked` or not
# to reduce noise and duplication, they are skipped
Expand All @@ -104,14 +119,26 @@ def _collect_vars(repo, params) -> Dict:


@locked
def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None):
def show(
repo,
revs=None,
targets=None,
deps=False,
onerror: Callable = None,
stages=None,
):
if onerror is None:
onerror = onerror_collect
res = {}

for branch in repo.brancher(revs=revs):
params = error_handler(_gather_params)(
repo=repo, rev=branch, targets=targets, deps=deps, onerror=onerror
repo=repo,
rev=branch,
targets=targets,
deps=deps,
onerror=onerror,
stages=stages,
)

if params:
Expand All @@ -138,18 +165,21 @@ def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None):
return res


def _gather_params(repo, rev, targets=None, deps=False, onerror=None):
def _gather_params(
repo, rev, targets=None, deps=False, onerror=None, stages=None
):
param_outs, params_fs_paths = _collect_configs(
repo, rev, targets=targets, duplicates=deps
repo, rev, targets=targets, duplicates=deps or stages
)
params = _read_params(
repo,
params=param_outs,
params_fs_paths=params_fs_paths,
deps=deps,
onerror=onerror,
stages=stages,
)
vars_params = _collect_vars(repo, params)
vars_params = _collect_vars(repo, params, stages=stages)

# NOTE: only those that are not added as a ParamDependency are
# included so we don't need to recursively merge them yet.
Expand Down

0 comments on commit 057c9d2

Please sign in to comment.