diff --git a/dvc/api.py b/dvc/api.py index a57a67fefc..61ff23af46 100644 --- a/dvc/api.py +++ b/dvc/api.py @@ -1,8 +1,9 @@ import os +from collections import Counter from contextlib import _GeneratorContextManager as GCM -from typing import Optional +from typing import Dict, Iterable, Optional, Union -from funcy import reraise +from funcy import first, reraise from dvc.exceptions import OutputNotFoundError, PathMissingError from dvc.repo import Repo @@ -214,6 +215,267 @@ def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): return fd.read() +def params_show( + *targets: str, + repo: Optional[str] = None, + stages: Optional[Union[str, Iterable[str]]] = None, + rev: Optional[str] = None, + deps: bool = False, +) -> Dict: + """Get parameters tracked in `repo`. + + Without arguments, this function will retrieve all params from all tracked + parameter files, for the current working tree. + + See the options below to restrict the parameters retrieved. + + Args: + *targets (str, optional): Names of the parameter files to retrieve + params from. For example, "params.py, myparams.toml". + If no `targets` are provided, all parameter files tracked in `dvc.yaml` + will be used. + Note that targets don't necessarily have to be defined in `dvc.yaml`. + repo (str, optional): location of the DVC repository. + Defaults to the current project (found by walking up from the + current working directory tree). + It can be a URL or a file system path. + Both HTTP and SSH protocols are supported for online Git repos + (e.g. [user@]server:project.git). + stages (Union[str, Iterable[str]], optional): Name or names of the + stages to retrieve parameters from. + Defaults to `None`. + If `None`, all parameters from all stages will be retrieved. + rev (str, optional): Name of the `Git revision`_ to retrieve parameters + from. + Defaults to `None`. + An example of git revision can be a branch or tag name, a commit + hash or a dvc experiment name. + If `repo` is not a Git repo, this option is ignored. + If `None`, the current working tree will be used. + deps (bool, optional): Whether to retrieve only parameters that are + stage dependencies or not. + Defaults to `False`. + + Returns: + Dict: See Examples below. + + Examples: + + - No arguments. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show() + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Filtering with `stages`. + + Working on https://github.com/iterative/example-get-started + + `stages` can a single string: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages="prepare") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + } + } + + Or an iterable of strings: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages=["prepare", "train"]) + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Using `rev`. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(rev="tune-hyperparams") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 100, + "min_split": 8 + } + } + + --- + + - Using `targets`. + + Working on `multi-params-files` folder of + https://github.com/iterative/pipeline-conifguration + + You can pass a single target: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show("params.yaml") + >>> print(json.dumps(params, indent=4)) + { + "run_mode": "prod", + "configs": { + "dev": "configs/params_dev.yaml", + "test": "configs/params_test.yaml", + "prod": "configs/params_prod.yaml" + }, + "evaluate": { + "dataset": "micro", + "size": 5000, + "metrics": ["f1", "roc-auc"], + "metrics_file": "reports/metrics.json", + "plots_cm": "reports/plot_confusion_matrix.png" + } + } + + + Or multiple targets: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... "configs/params_dev.yaml", "configs/params_prod.yaml") + >>> print(json.dumps(params, indent=4)) + { + "configs/params_prod.yaml:run_mode": "prod", + "configs/params_prod.yaml:config_file": "configs/params_prod.yaml", + "configs/params_prod.yaml:data_load": { + "dataset": "large", + "sampling": { + "enable": true, + "size": 50000 + } + }, + "configs/params_prod.yaml:train": { + "epochs": 1000 + }, + "configs/params_dev.yaml:run_mode": "dev", + "configs/params_dev.yaml:config_file": "configs/params_dev.yaml", + "configs/params_dev.yaml:data_load": { + "dataset": "development", + "sampling": { + "enable": true, + "size": 1000 + } + }, + "configs/params_dev.yaml:train": { + "epochs": 10 + } + } + + --- + + - Git URL as `repo`. + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... repo="https://github.com/iterative/demo-fashion-mnist") + { + "train": { + "batch_size": 128, + "hidden_units": 64, + "dropout": 0.4, + "num_epochs": 10, + "lr": 0.001, + "conv_activation": "relu" + } + } + + + .. _Git revision: + https://git-scm.com/docs/revisions + + """ + if isinstance(stages, str): + stages = [stages] + + def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs): + raise exception + + def _postprocess(params): + processed = {} + for rev, rev_data in params.items(): + processed[rev] = {} + + counts = Counter() + for file_data in rev_data["data"].values(): + for k in file_data["data"]: + counts[k] += 1 + + for file_name, file_data in rev_data["data"].items(): + to_merge = { + (k if counts[k] == 1 else f"{file_name}:{k}"): v + for k, v in file_data["data"].items() + } + processed[rev] = {**processed[rev], **to_merge} + + if "workspace" in processed: + del processed["workspace"] + + return processed[first(processed)] + + with Repo.open(repo) as _repo: + params = _repo.params.show( + revs=rev if rev is None else [rev], + targets=targets, + deps=deps, + onerror=_onerror_raise, + stages=stages, + ) + + return _postprocess(params) + + def make_checkpoint(): """ Signal DVC to create a checkpoint experiment. diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 668577af2e..2b8d5decac 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -2,7 +2,15 @@ import os from collections import defaultdict from copy import copy -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, +) from scmrepo.exceptions import SCMError @@ -63,18 +71,23 @@ def _read_params( params_fs_paths, deps=False, onerror: Optional[Callable] = None, + stages: Optional[Iterable[str]] = None, ): res: Dict[str, Dict] = defaultdict(lambda: defaultdict(dict)) fs_paths = copy(params_fs_paths) - if deps: + if deps or stages: for param in params: + if stages and param.stage.name not in stages: + continue params_dict = error_handler(param.read_params)( onerror=onerror, flatten=False ) if params_dict: name = os.sep.join(repo.fs.path.relparts(param.fs_path)) res[name]["data"].update(params_dict["data"]) + if name in fs_paths: + fs_paths.remove(name) else: fs_paths += [param.fs_path for param in params] @@ -87,11 +100,13 @@ def _read_params( return res -def _collect_vars(repo, params) -> Dict: +def _collect_vars(repo, params, stages=None) -> Dict: vars_params: Dict[str, Dict] = defaultdict(dict) for stage in repo.index.stages: if isinstance(stage, PipelineStage) and stage.tracked_vars: + if stages and stage.name not in stages: + continue for file, vars_ in stage.tracked_vars.items(): # `params` file are shown regardless of `tracked` or not # to reduce noise and duplication, they are skipped @@ -104,14 +119,26 @@ def _collect_vars(repo, params) -> Dict: @locked -def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None): +def show( + repo, + revs=None, + targets=None, + deps=False, + onerror: Callable = None, + stages=None, +): if onerror is None: onerror = onerror_collect res = {} for branch in repo.brancher(revs=revs): params = error_handler(_gather_params)( - repo=repo, rev=branch, targets=targets, deps=deps, onerror=onerror + repo=repo, + rev=branch, + targets=targets, + deps=deps, + onerror=onerror, + stages=stages, ) if params: @@ -138,9 +165,11 @@ def show(repo, revs=None, targets=None, deps=False, onerror: Callable = None): return res -def _gather_params(repo, rev, targets=None, deps=False, onerror=None): +def _gather_params( + repo, rev, targets=None, deps=False, onerror=None, stages=None +): param_outs, params_fs_paths = _collect_configs( - repo, rev, targets=targets, duplicates=deps + repo, rev, targets=targets, duplicates=deps or stages ) params = _read_params( repo, @@ -148,8 +177,9 @@ def _gather_params(repo, rev, targets=None, deps=False, onerror=None): params_fs_paths=params_fs_paths, deps=deps, onerror=onerror, + stages=stages, ) - vars_params = _collect_vars(repo, params) + vars_params = _collect_vars(repo, params, stages=stages) # NOTE: only those that are not added as a ParamDependency are # included so we don't need to recursively merge them yet. diff --git a/tests/func/test_api.py b/tests/func/test_api.py index 44bfd54f21..116318e3ad 100644 --- a/tests/func/test_api.py +++ b/tests/func/test_api.py @@ -1,4 +1,5 @@ import os +from textwrap import dedent import pytest from funcy import first, get_in @@ -228,3 +229,128 @@ def test_open_from_remote(tmp_dir, erepo_dir, cloud, local_cloud): remote="other", ) as fd: assert fd.read() == "foo content" + + +@pytest.fixture +def params_repo(tmp_dir, scm, dvc): + tmp_dir.gen("params.yaml", "foo: 1") + tmp_dir.gen("params.json", '{"bar": 2, "foobar": 3}') + tmp_dir.gen("other_params.json", '{"foo": {"bar": 4}}') + + dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo", "params.json:bar"], + ) + + dvc.run( + name="stage-2", + cmd="echo stage-2", + params=["other_params.json:foo"], + ) + + dvc.run( + name="stage-3", + cmd="echo stage-2", + params=["params.json:foobar"], + ) + + scm.add( + [ + "params.yaml", + "params.json", + "other_params.json", + "dvc.yaml", + "dvc.lock", + ] + ) + scm.commit("commit dvc files") + + tmp_dir.gen("params.yaml", "foo: 5") + scm.add(["params.yaml"]) + scm.commit("update params.yaml") + + +def test_params_show_no_args(params_repo): + assert api.params_show() == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_targets(params_repo): + assert api.params_show("params.yaml") == {"foo": 5} + assert api.params_show("params.yaml", "params.json") == { + "foo": 5, + "bar": 2, + "foobar": 3, + } + + +def test_params_show_deps(params_repo): + params = api.params_show(deps=True) + assert params == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_stages(params_repo): + assert api.params_show(stages="stage-2") == {"foo": {"bar": 4}} + + assert api.params_show() == api.params_show( + stages=["stage-1", "stage-2", "stage-3"] + ) + + assert api.params_show("params.json", stages="stage-3") == {"foobar": 3} + + +def test_params_show_revs(params_repo): + assert api.params_show(rev="HEAD~1") == { + "params.yaml:foo": 1, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_while_running_stage(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) + (tmp_dir / "params.json").dump({"bar": 2}) + + tmp_dir.gen( + "merge.py", + dedent( + """ + import json + from dvc import api + with open("merged.json", "w") as f: + json.dump(api.params_show(stages="merge"), f) + """ + ), + ) + dvc.stage.add( + name="merge", + cmd="python merge.py", + params=["foo.bar", {"params.json": ["bar"]}], + outs=["merged.json"], + ) + + dvc.reproduce() + + assert (tmp_dir / "merged.json").parse() == {"foo": {"bar": 1}, "bar": 2} + + +def test_params_show_repo(tmp_dir, erepo_dir): + with erepo_dir.chdir(): + erepo_dir.scm_gen("params.yaml", "foo: 1", commit="Create params.yaml") + erepo_dir.dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo"], + ) + assert api.params_show(repo=erepo_dir) == {"foo": 1}