diff --git a/dvc/api.py b/dvc/api.py deleted file mode 100644 index 61ff23af46..0000000000 --- a/dvc/api.py +++ /dev/null @@ -1,508 +0,0 @@ -import os -from collections import Counter -from contextlib import _GeneratorContextManager as GCM -from typing import Dict, Iterable, Optional, Union - -from funcy import first, reraise - -from dvc.exceptions import OutputNotFoundError, PathMissingError -from dvc.repo import Repo - - -def get_url(path, repo=None, rev=None, remote=None): - """ - Returns the URL to the storage location of a data file or directory tracked - in a DVC repo. For Git repos, HEAD is used unless a rev argument is - supplied. The default remote is tried unless a remote argument is supplied. - - Raises OutputNotFoundError if the file is not tracked by DVC. - - NOTE: This function does not check for the actual existence of the file or - directory in the remote storage. - """ - with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: - fs_path = _repo.dvcfs.from_os_path(path) - with reraise(FileNotFoundError, PathMissingError(path, repo)): - info = _repo.dvcfs.info(fs_path) - - dvc_info = info.get("dvc_info") - if not dvc_info: - raise OutputNotFoundError(path, repo) - - dvc_repo = info["repo"] - md5 = dvc_info["md5"] - - return dvc_repo.cloud.get_url_for(remote, checksum=md5) - - -class _OpenContextManager(GCM): - def __init__( - self, func, args, kwds - ): # pylint: disable=super-init-not-called - self.gen = func(*args, **kwds) - self.func, self.args, self.kwds = func, args, kwds - - def __getattr__(self, name): - raise AttributeError( - "dvc.api.open() should be used in a with statement." - ) - - -def open( # noqa, pylint: disable=redefined-builtin - path: str, - repo: Optional[str] = None, - rev: Optional[str] = None, - remote: Optional[str] = None, - mode: str = "r", - encoding: Optional[str] = None, -): - """ - Opens a file tracked in a DVC project. 
- - This function may only be used as a context manager (using the `with` - keyword, as shown in the examples). - - This function makes a direct connection to the remote storage, so the file - contents can be streamed. Your code can process the data buffer as it's - streamed, which optimizes memory usage. - - Note: - Use dvc.api.read() to load the complete file contents - in a single function call, no context manager involved. - Neither function utilizes disc space. - - Args: - path (str): location and file name of the target to open, - relative to the root of `repo`. - repo (str, optional): location of the DVC project or Git Repo. - Defaults to the current DVC project (found by walking up from the - current working directory tree). - It can be a URL or a file system path. - Both HTTP and SSH protocols are supported for online Git repos - (e.g. [user@]server:project.git). - rev (str, optional): Any `Git revision`_ such as a branch or tag name, - a commit hash or a dvc experiment name. - Defaults to HEAD. - If `repo` is not a Git repo, this option is ignored. - remote (str, optional): Name of the `DVC remote`_ used to form the - returned URL string. - Defaults to the `default remote`_ of `repo`. - For local projects, the cache is tried before the default remote. - mode (str, optional): Specifies the mode in which the file is opened. - Defaults to "r" (read). - Mirrors the namesake parameter in builtin `open()`_. - Only reading `mode` is supported. - encoding(str, optional): `Codec`_ used to decode the file contents. - Defaults to None. - This should only be used in text mode. - Mirrors the namesake parameter in builtin `open()`_. - - Returns: - _OpenContextManager: A context manager that generatse a corresponding - `file object`_. - The exact type of file object depends on the mode used. - For more details, please refer to Python's `open()`_ built-in, - which is used under the hood. - - Raises: - AttributeError: If this method is not used as a context manager. 
- ValueError: If non-read `mode` is used. - - Examples: - - - Use data or models from a DVC repository. - - Any file tracked in a DVC project (and stored remotely) can be - processed directly in your Python code with this API. - For example, an XML file tracked in a public DVC repo on GitHub can be - processed like this: - - >>> from xml.sax import parse - >>> import dvc.api - >>> from mymodule import mySAXHandler - - >>> with dvc.api.open( - ... 'get-started/data.xml', - ... repo='https://github.com/iterative/dataset-registry' - ... ) as fd: - ... parse(fd, mySAXHandler) - - We use a SAX XML parser here because dvc.api.open() is able to stream - the data from remote storage. - The mySAXHandler object should handle the event-driven parsing of the - document in this case. - This increases the performance of the code (minimizing memory usage), - and is typically faster than loading the whole data into memory. - - - Accessing private repos - - This is just a matter of using the right repo argument, for example an - SSH URL (requires that the credentials are configured locally): - - >>> import dvc.api - - >>> with dvc.api.open( - ... 'features.dat', - ... repo='git@server.com:path/to/repo.git' - ... ) as fd: - ... # ... Process 'features' - ... pass - - - Use different versions of data - - Any git revision (see `rev`) can be accessed programmatically. - For example, if your DVC repo has tagged releases of a CSV dataset: - - >>> import csv - >>> import dvc.api - >>> with dvc.api.open( - ... 'clean.csv', - ... rev='v1.1.0' - ... ) as fd: - ... reader = csv.reader(fd) - ... # ... Process 'clean' data from version 1.1.0 - - .. _Git revision: - https://git-scm.com/docs/revisions - - .. _DVC remote: - https://dvc.org/doc/command-reference/remote - - .. _default remote: - https://dvc.org/doc/command-reference/remote/default - - .. _open(): - https://docs.python.org/3/library/functions.html#open - - .. 
_Codec: - https://docs.python.org/3/library/codecs.html#standard-encodings - - .. _file object: - https://docs.python.org/3/glossary.html#term-file-object - - """ - if "r" not in mode: - raise ValueError("Only reading `mode` is supported.") - - args = (path,) - kwargs = { - "repo": repo, - "remote": remote, - "rev": rev, - "mode": mode, - "encoding": encoding, - } - return _OpenContextManager(_open, args, kwargs) - - -def _open(path, repo=None, rev=None, remote=None, mode="r", encoding=None): - with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: - with _repo.open_by_relpath( - path, remote=remote, mode=mode, encoding=encoding - ) as fd: - yield fd - - -def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): - """ - Returns the contents of a tracked file (by DVC or Git). For Git repos, HEAD - is used unless a rev argument is supplied. The default remote is tried - unless a remote argument is supplied. - """ - with open( - path, repo=repo, rev=rev, remote=remote, mode=mode, encoding=encoding - ) as fd: - return fd.read() - - -def params_show( - *targets: str, - repo: Optional[str] = None, - stages: Optional[Union[str, Iterable[str]]] = None, - rev: Optional[str] = None, - deps: bool = False, -) -> Dict: - """Get parameters tracked in `repo`. - - Without arguments, this function will retrieve all params from all tracked - parameter files, for the current working tree. - - See the options below to restrict the parameters retrieved. - - Args: - *targets (str, optional): Names of the parameter files to retrieve - params from. For example, "params.py, myparams.toml". - If no `targets` are provided, all parameter files tracked in `dvc.yaml` - will be used. - Note that targets don't necessarily have to be defined in `dvc.yaml`. - repo (str, optional): location of the DVC repository. - Defaults to the current project (found by walking up from the - current working directory tree). - It can be a URL or a file system path. 
- Both HTTP and SSH protocols are supported for online Git repos - (e.g. [user@]server:project.git). - stages (Union[str, Iterable[str]], optional): Name or names of the - stages to retrieve parameters from. - Defaults to `None`. - If `None`, all parameters from all stages will be retrieved. - rev (str, optional): Name of the `Git revision`_ to retrieve parameters - from. - Defaults to `None`. - An example of git revision can be a branch or tag name, a commit - hash or a dvc experiment name. - If `repo` is not a Git repo, this option is ignored. - If `None`, the current working tree will be used. - deps (bool, optional): Whether to retrieve only parameters that are - stage dependencies or not. - Defaults to `False`. - - Returns: - Dict: See Examples below. - - Examples: - - - No arguments. - - Working on https://github.com/iterative/example-get-started - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show() - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - }, - "featurize": { - "max_features": 200, - "ngrams": 2 - }, - "train": { - "seed": 20170428, - "n_est": 50, - "min_split": 0.01 - } - } - - --- - - - Filtering with `stages`. - - Working on https://github.com/iterative/example-get-started - - `stages` can a single string: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show(stages="prepare") - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - } - } - - Or an iterable of strings: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show(stages=["prepare", "train"]) - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - }, - "train": { - "seed": 20170428, - "n_est": 50, - "min_split": 0.01 - } - } - - --- - - - Using `rev`. 
- - Working on https://github.com/iterative/example-get-started - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show(rev="tune-hyperparams") - >>> print(json.dumps(params, indent=4)) - { - "prepare": { - "split": 0.2, - "seed": 20170428 - }, - "featurize": { - "max_features": 200, - "ngrams": 2 - }, - "train": { - "seed": 20170428, - "n_est": 100, - "min_split": 8 - } - } - - --- - - - Using `targets`. - - Working on `multi-params-files` folder of - https://github.com/iterative/pipeline-conifguration - - You can pass a single target: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show("params.yaml") - >>> print(json.dumps(params, indent=4)) - { - "run_mode": "prod", - "configs": { - "dev": "configs/params_dev.yaml", - "test": "configs/params_test.yaml", - "prod": "configs/params_prod.yaml" - }, - "evaluate": { - "dataset": "micro", - "size": 5000, - "metrics": ["f1", "roc-auc"], - "metrics_file": "reports/metrics.json", - "plots_cm": "reports/plot_confusion_matrix.png" - } - } - - - Or multiple targets: - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show( - ... "configs/params_dev.yaml", "configs/params_prod.yaml") - >>> print(json.dumps(params, indent=4)) - { - "configs/params_prod.yaml:run_mode": "prod", - "configs/params_prod.yaml:config_file": "configs/params_prod.yaml", - "configs/params_prod.yaml:data_load": { - "dataset": "large", - "sampling": { - "enable": true, - "size": 50000 - } - }, - "configs/params_prod.yaml:train": { - "epochs": 1000 - }, - "configs/params_dev.yaml:run_mode": "dev", - "configs/params_dev.yaml:config_file": "configs/params_dev.yaml", - "configs/params_dev.yaml:data_load": { - "dataset": "development", - "sampling": { - "enable": true, - "size": 1000 - } - }, - "configs/params_dev.yaml:train": { - "epochs": 10 - } - } - - --- - - - Git URL as `repo`. - - >>> import json - >>> import dvc.api - >>> params = dvc.api.params_show( - ... 
repo="https://github.com/iterative/demo-fashion-mnist") - { - "train": { - "batch_size": 128, - "hidden_units": 64, - "dropout": 0.4, - "num_epochs": 10, - "lr": 0.001, - "conv_activation": "relu" - } - } - - - .. _Git revision: - https://git-scm.com/docs/revisions - - """ - if isinstance(stages, str): - stages = [stages] - - def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs): - raise exception - - def _postprocess(params): - processed = {} - for rev, rev_data in params.items(): - processed[rev] = {} - - counts = Counter() - for file_data in rev_data["data"].values(): - for k in file_data["data"]: - counts[k] += 1 - - for file_name, file_data in rev_data["data"].items(): - to_merge = { - (k if counts[k] == 1 else f"{file_name}:{k}"): v - for k, v in file_data["data"].items() - } - processed[rev] = {**processed[rev], **to_merge} - - if "workspace" in processed: - del processed["workspace"] - - return processed[first(processed)] - - with Repo.open(repo) as _repo: - params = _repo.params.show( - revs=rev if rev is None else [rev], - targets=targets, - deps=deps, - onerror=_onerror_raise, - stages=stages, - ) - - return _postprocess(params) - - -def make_checkpoint(): - """ - Signal DVC to create a checkpoint experiment. - - If the current process is being run from DVC, this function will block - until DVC has finished creating the checkpoint. Otherwise, this function - will return immediately. 
- """ - import builtins - from time import sleep - - from dvc.env import DVC_CHECKPOINT, DVC_ROOT - from dvc.stage.monitor import CheckpointTask - - if os.getenv(DVC_CHECKPOINT) is None: - return - - root_dir = os.getenv(DVC_ROOT, Repo.find_root()) - signal_file = os.path.join( - root_dir, Repo.DVC_DIR, "tmp", CheckpointTask.SIGNAL_FILE - ) - - with builtins.open(signal_file, "w", encoding="utf-8") as fobj: - # NOTE: force flushing/writing empty file to disk, otherwise when - # run in certain contexts (pytest) file may not actually be written - fobj.write("") - fobj.flush() - os.fsync(fobj.fileno()) - while os.path.exists(signal_file): - sleep(0.1) diff --git a/dvc/api/__init__.py b/dvc/api/__init__.py new file mode 100644 index 0000000000..dde993ec76 --- /dev/null +++ b/dvc/api/__init__.py @@ -0,0 +1,9 @@ +from .data import ( # noqa, pylint: disable=redefined-builtin + get_url, + open, + read, +) +from .experiments import make_checkpoint +from .params import params_show + +__all__ = ["get_url", "make_checkpoint", "open", "params_show", "read"] diff --git a/dvc/api/data.py b/dvc/api/data.py new file mode 100644 index 0000000000..a063612f10 --- /dev/null +++ b/dvc/api/data.py @@ -0,0 +1,213 @@ +from contextlib import _GeneratorContextManager as GCM +from typing import Optional + +from funcy import reraise + +from dvc.exceptions import OutputNotFoundError, PathMissingError +from dvc.repo import Repo + + +def get_url(path, repo=None, rev=None, remote=None): + """ + Returns the URL to the storage location of a data file or directory tracked + in a DVC repo. For Git repos, HEAD is used unless a rev argument is + supplied. The default remote is tried unless a remote argument is supplied. + + Raises OutputNotFoundError if the file is not tracked by DVC. + + NOTE: This function does not check for the actual existence of the file or + directory in the remote storage. 
+ """ + with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: + fs_path = _repo.dvcfs.from_os_path(path) + with reraise(FileNotFoundError, PathMissingError(path, repo)): + info = _repo.dvcfs.info(fs_path) + + dvc_info = info.get("dvc_info") + if not dvc_info: + raise OutputNotFoundError(path, repo) + + dvc_repo = info["repo"] + md5 = dvc_info["md5"] + + return dvc_repo.cloud.get_url_for(remote, checksum=md5) + + +class _OpenContextManager(GCM): + def __init__( + self, func, args, kwds + ): # pylint: disable=super-init-not-called + self.gen = func(*args, **kwds) + self.func, self.args, self.kwds = func, args, kwds + + def __getattr__(self, name): + raise AttributeError( + "dvc.api.open() should be used in a with statement." + ) + + +def open( # noqa, pylint: disable=redefined-builtin + path: str, + repo: Optional[str] = None, + rev: Optional[str] = None, + remote: Optional[str] = None, + mode: str = "r", + encoding: Optional[str] = None, +): + """ + Opens a file tracked in a DVC project. + + This function may only be used as a context manager (using the `with` + keyword, as shown in the examples). + + This function makes a direct connection to the remote storage, so the file + contents can be streamed. Your code can process the data buffer as it's + streamed, which optimizes memory usage. + + Note: + Use dvc.api.read() to load the complete file contents + in a single function call, no context manager involved. + Neither function utilizes disc space. + + Args: + path (str): location and file name of the target to open, + relative to the root of `repo`. + repo (str, optional): location of the DVC project or Git Repo. + Defaults to the current DVC project (found by walking up from the + current working directory tree). + It can be a URL or a file system path. + Both HTTP and SSH protocols are supported for online Git repos + (e.g. [user@]server:project.git). 
+ rev (str, optional): Any `Git revision`_ such as a branch or tag name, + a commit hash or a dvc experiment name. + Defaults to HEAD. + If `repo` is not a Git repo, this option is ignored. + remote (str, optional): Name of the `DVC remote`_ used to form the + returned URL string. + Defaults to the `default remote`_ of `repo`. + For local projects, the cache is tried before the default remote. + mode (str, optional): Specifies the mode in which the file is opened. + Defaults to "r" (read). + Mirrors the namesake parameter in builtin `open()`_. + Only reading `mode` is supported. + encoding (str, optional): `Codec`_ used to decode the file contents. + Defaults to None. + This should only be used in text mode. + Mirrors the namesake parameter in builtin `open()`_. + + Returns: + _OpenContextManager: A context manager that generates a corresponding + `file object`_. + The exact type of file object depends on the mode used. + For more details, please refer to Python's `open()`_ built-in, + which is used under the hood. + + Raises: + AttributeError: If this method is not used as a context manager. + ValueError: If non-read `mode` is used. + + Examples: + + - Use data or models from a DVC repository. + + Any file tracked in a DVC project (and stored remotely) can be + processed directly in your Python code with this API. + For example, an XML file tracked in a public DVC repo on GitHub can be + processed like this: + + >>> from xml.sax import parse + >>> import dvc.api + >>> from mymodule import mySAXHandler + + >>> with dvc.api.open( + ... 'get-started/data.xml', + ... repo='https://github.com/iterative/dataset-registry' + ... ) as fd: + ... parse(fd, mySAXHandler) + + We use a SAX XML parser here because dvc.api.open() is able to stream + the data from remote storage. + The mySAXHandler object should handle the event-driven parsing of the + document in this case. 
+ This increases the performance of the code (minimizing memory usage), + and is typically faster than loading the whole data into memory. + + - Accessing private repos + + This is just a matter of using the right repo argument, for example an + SSH URL (requires that the credentials are configured locally): + + >>> import dvc.api + + >>> with dvc.api.open( + ... 'features.dat', + ... repo='git@server.com:path/to/repo.git' + ... ) as fd: + ... # ... Process 'features' + ... pass + + - Use different versions of data + + Any git revision (see `rev`) can be accessed programmatically. + For example, if your DVC repo has tagged releases of a CSV dataset: + + >>> import csv + >>> import dvc.api + >>> with dvc.api.open( + ... 'clean.csv', + ... rev='v1.1.0' + ... ) as fd: + ... reader = csv.reader(fd) + ... # ... Process 'clean' data from version 1.1.0 + + .. _Git revision: + https://git-scm.com/docs/revisions + + .. _DVC remote: + https://dvc.org/doc/command-reference/remote + + .. _default remote: + https://dvc.org/doc/command-reference/remote/default + + .. _open(): + https://docs.python.org/3/library/functions.html#open + + .. _Codec: + https://docs.python.org/3/library/codecs.html#standard-encodings + + .. _file object: + https://docs.python.org/3/glossary.html#term-file-object + + """ + if "r" not in mode: + raise ValueError("Only reading `mode` is supported.") + + args = (path,) + kwargs = { + "repo": repo, + "remote": remote, + "rev": rev, + "mode": mode, + "encoding": encoding, + } + return _OpenContextManager(_open, args, kwargs) + + +def _open(path, repo=None, rev=None, remote=None, mode="r", encoding=None): + with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo: + with _repo.open_by_relpath( + path, remote=remote, mode=mode, encoding=encoding + ) as fd: + yield fd + + +def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None): + """ + Returns the contents of a tracked file (by DVC or Git). 
For Git repos, HEAD + is used unless a rev argument is supplied. The default remote is tried + unless a remote argument is supplied. + """ + with open( + path, repo=repo, rev=rev, remote=remote, mode=mode, encoding=encoding + ) as fd: + return fd.read() diff --git a/dvc/api/experiments.py b/dvc/api/experiments.py new file mode 100644 index 0000000000..f6448eedb0 --- /dev/null +++ b/dvc/api/experiments.py @@ -0,0 +1,33 @@ +import builtins +import os +from time import sleep + +from dvc.env import DVC_CHECKPOINT, DVC_ROOT +from dvc.repo import Repo +from dvc.stage.monitor import CheckpointTask + + +def make_checkpoint(): + """ + Signal DVC to create a checkpoint experiment. + + If the current process is being run from DVC, this function will block + until DVC has finished creating the checkpoint. Otherwise, this function + will return immediately. + """ + if os.getenv(DVC_CHECKPOINT) is None: + return + + root_dir = os.getenv(DVC_ROOT, Repo.find_root()) + signal_file = os.path.join( + root_dir, Repo.DVC_DIR, "tmp", CheckpointTask.SIGNAL_FILE + ) + + with builtins.open(signal_file, "w", encoding="utf-8") as fobj: + # NOTE: force flushing/writing empty file to disk, otherwise when + # run in certain contexts (pytest) file may not actually be written + fobj.write("") + fobj.flush() + os.fsync(fobj.fileno()) + while os.path.exists(signal_file): + sleep(0.1) diff --git a/dvc/api/params.py b/dvc/api/params.py new file mode 100644 index 0000000000..d66b19a591 --- /dev/null +++ b/dvc/api/params.py @@ -0,0 +1,267 @@ +from collections import Counter +from typing import Dict, Iterable, Optional, Union + +from funcy import first + +from dvc.repo import Repo + + +def params_show( + *targets: str, + repo: Optional[str] = None, + stages: Optional[Union[str, Iterable[str]]] = None, + rev: Optional[str] = None, + deps: bool = False, +) -> Dict: + """Get parameters tracked in `repo`. 
+ + Without arguments, this function will retrieve all params from all tracked + parameter files, for the current working tree. + + See the options below to restrict the parameters retrieved. + + Args: + *targets (str, optional): Names of the parameter files to retrieve + params from. For example, "params.py, myparams.toml". + If no `targets` are provided, all parameter files tracked in `dvc.yaml` + will be used. + Note that targets don't necessarily have to be defined in `dvc.yaml`. + repo (str, optional): location of the DVC repository. + Defaults to the current project (found by walking up from the + current working directory tree). + It can be a URL or a file system path. + Both HTTP and SSH protocols are supported for online Git repos + (e.g. [user@]server:project.git). + stages (Union[str, Iterable[str]], optional): Name or names of the + stages to retrieve parameters from. + Defaults to `None`. + If `None`, all parameters from all stages will be retrieved. + rev (str, optional): Name of the `Git revision`_ to retrieve parameters + from. + Defaults to `None`. + An example of git revision can be a branch or tag name, a commit + hash or a dvc experiment name. + If `repo` is not a Git repo, this option is ignored. + If `None`, the current working tree will be used. + deps (bool, optional): Whether to retrieve only parameters that are + stage dependencies or not. + Defaults to `False`. + + Returns: + Dict: See Examples below. + + Examples: + + - No arguments. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show() + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Filtering with `stages`. 
+ + Working on https://github.com/iterative/example-get-started + + `stages` can be a single string: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages="prepare") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + } + } + + Or an iterable of strings: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(stages=["prepare", "train"]) + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "train": { + "seed": 20170428, + "n_est": 50, + "min_split": 0.01 + } + } + + --- + + - Using `rev`. + + Working on https://github.com/iterative/example-get-started + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show(rev="tune-hyperparams") + >>> print(json.dumps(params, indent=4)) + { + "prepare": { + "split": 0.2, + "seed": 20170428 + }, + "featurize": { + "max_features": 200, + "ngrams": 2 + }, + "train": { + "seed": 20170428, + "n_est": 100, + "min_split": 8 + } + } + + --- + + - Using `targets`. + + Working on `multi-params-files` folder of + https://github.com/iterative/pipeline-conifguration + + You can pass a single target: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show("params.yaml") + >>> print(json.dumps(params, indent=4)) + { + "run_mode": "prod", + "configs": { + "dev": "configs/params_dev.yaml", + "test": "configs/params_test.yaml", + "prod": "configs/params_prod.yaml" + }, + "evaluate": { + "dataset": "micro", + "size": 5000, + "metrics": ["f1", "roc-auc"], + "metrics_file": "reports/metrics.json", + "plots_cm": "reports/plot_confusion_matrix.png" + } + } + + + Or multiple targets: + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... 
"configs/params_dev.yaml", "configs/params_prod.yaml") + >>> print(json.dumps(params, indent=4)) + { + "configs/params_prod.yaml:run_mode": "prod", + "configs/params_prod.yaml:config_file": "configs/params_prod.yaml", + "configs/params_prod.yaml:data_load": { + "dataset": "large", + "sampling": { + "enable": true, + "size": 50000 + } + }, + "configs/params_prod.yaml:train": { + "epochs": 1000 + }, + "configs/params_dev.yaml:run_mode": "dev", + "configs/params_dev.yaml:config_file": "configs/params_dev.yaml", + "configs/params_dev.yaml:data_load": { + "dataset": "development", + "sampling": { + "enable": true, + "size": 1000 + } + }, + "configs/params_dev.yaml:train": { + "epochs": 10 + } + } + + --- + + - Git URL as `repo`. + + >>> import json + >>> import dvc.api + >>> params = dvc.api.params_show( + ... repo="https://github.com/iterative/demo-fashion-mnist") + { + "train": { + "batch_size": 128, + "hidden_units": 64, + "dropout": 0.4, + "num_epochs": 10, + "lr": 0.001, + "conv_activation": "relu" + } + } + + + .. 
_Git revision: + https://git-scm.com/docs/revisions + + """ + if isinstance(stages, str): + stages = [stages] + + def _onerror_raise(result: Dict, exception: Exception, *args, **kwargs): + raise exception + + def _postprocess(params): + processed = {} + for rev, rev_data in params.items(): + processed[rev] = {} + + counts = Counter() + for file_data in rev_data["data"].values(): + for k in file_data["data"]: + counts[k] += 1 + + for file_name, file_data in rev_data["data"].items(): + to_merge = { + (k if counts[k] == 1 else f"{file_name}:{k}"): v + for k, v in file_data["data"].items() + } + processed[rev] = {**processed[rev], **to_merge} + + if "workspace" in processed: + del processed["workspace"] + + return processed[first(processed)] + + with Repo.open(repo) as _repo: + params = _repo.params.show( + revs=rev if rev is None else [rev], + targets=targets, + deps=deps, + onerror=_onerror_raise, + stages=stages, + ) + + return _postprocess(params) diff --git a/dvc/testing/tmp_dir.py b/dvc/testing/tmp_dir.py index 2517166eb4..ea4cd0d777 100644 --- a/dvc/testing/tmp_dir.py +++ b/dvc/testing/tmp_dir.py @@ -259,6 +259,15 @@ def modify(self, *args, **kwargs): dump_toml = partialmethod(serialize.dump_toml) +def make_subrepo(dir_: TmpDir, scm, config=None): + dir_.mkdir(parents=True, exist_ok=True) + with dir_.chdir(): + dir_.scm = scm + dir_.init(dvc=True, subdir=True) + if config: + dir_.add_remote(config=config) + + def _coerce_filenames(filenames): if isinstance(filenames, (str, bytes, pathlib.PurePath)): filenames = [filenames] diff --git a/tests/func/test_api.py b/tests/func/api/test_data.py similarity index 69% rename from tests/func/test_api.py rename to tests/func/api/test_data.py index 116318e3ad..d73226175d 100644 --- a/tests/func/test_api.py +++ b/tests/func/api/test_data.py @@ -1,5 +1,4 @@ import os -from textwrap import dedent import pytest from funcy import first, get_in @@ -11,8 +10,8 @@ PathMissingError, ) from dvc.testing.test_api import TestAPI # noqa, 
pylint: disable=unused-import +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import remove -from tests.unit.fs.test_dvc import make_subrepo def test_get_url_external(tmp_dir, erepo_dir, cloud): @@ -229,128 +228,3 @@ def test_open_from_remote(tmp_dir, erepo_dir, cloud, local_cloud): remote="other", ) as fd: assert fd.read() == "foo content" - - -@pytest.fixture -def params_repo(tmp_dir, scm, dvc): - tmp_dir.gen("params.yaml", "foo: 1") - tmp_dir.gen("params.json", '{"bar": 2, "foobar": 3}') - tmp_dir.gen("other_params.json", '{"foo": {"bar": 4}}') - - dvc.run( - name="stage-1", - cmd="echo stage-1", - params=["foo", "params.json:bar"], - ) - - dvc.run( - name="stage-2", - cmd="echo stage-2", - params=["other_params.json:foo"], - ) - - dvc.run( - name="stage-3", - cmd="echo stage-2", - params=["params.json:foobar"], - ) - - scm.add( - [ - "params.yaml", - "params.json", - "other_params.json", - "dvc.yaml", - "dvc.lock", - ] - ) - scm.commit("commit dvc files") - - tmp_dir.gen("params.yaml", "foo: 5") - scm.add(["params.yaml"]) - scm.commit("update params.yaml") - - -def test_params_show_no_args(params_repo): - assert api.params_show() == { - "params.yaml:foo": 5, - "bar": 2, - "foobar": 3, - "other_params.json:foo": {"bar": 4}, - } - - -def test_params_show_targets(params_repo): - assert api.params_show("params.yaml") == {"foo": 5} - assert api.params_show("params.yaml", "params.json") == { - "foo": 5, - "bar": 2, - "foobar": 3, - } - - -def test_params_show_deps(params_repo): - params = api.params_show(deps=True) - assert params == { - "params.yaml:foo": 5, - "bar": 2, - "foobar": 3, - "other_params.json:foo": {"bar": 4}, - } - - -def test_params_show_stages(params_repo): - assert api.params_show(stages="stage-2") == {"foo": {"bar": 4}} - - assert api.params_show() == api.params_show( - stages=["stage-1", "stage-2", "stage-3"] - ) - - assert api.params_show("params.json", stages="stage-3") == {"foobar": 3} - - -def 
test_params_show_revs(params_repo): - assert api.params_show(rev="HEAD~1") == { - "params.yaml:foo": 1, - "bar": 2, - "foobar": 3, - "other_params.json:foo": {"bar": 4}, - } - - -def test_params_show_while_running_stage(tmp_dir, dvc): - (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) - (tmp_dir / "params.json").dump({"bar": 2}) - - tmp_dir.gen( - "merge.py", - dedent( - """ - import json - from dvc import api - with open("merged.json", "w") as f: - json.dump(api.params_show(stages="merge"), f) - """ - ), - ) - dvc.stage.add( - name="merge", - cmd="python merge.py", - params=["foo.bar", {"params.json": ["bar"]}], - outs=["merged.json"], - ) - - dvc.reproduce() - - assert (tmp_dir / "merged.json").parse() == {"foo": {"bar": 1}, "bar": 2} - - -def test_params_show_repo(tmp_dir, erepo_dir): - with erepo_dir.chdir(): - erepo_dir.scm_gen("params.yaml", "foo: 1", commit="Create params.yaml") - erepo_dir.dvc.run( - name="stage-1", - cmd="echo stage-1", - params=["foo"], - ) - assert api.params_show(repo=erepo_dir) == {"foo": 1} diff --git a/tests/func/api/test_params.py b/tests/func/api/test_params.py new file mode 100644 index 0000000000..9a2fb0f83b --- /dev/null +++ b/tests/func/api/test_params.py @@ -0,0 +1,130 @@ +from textwrap import dedent + +import pytest + +from dvc import api + + +@pytest.fixture +def params_repo(tmp_dir, scm, dvc): + tmp_dir.gen("params.yaml", "foo: 1") + tmp_dir.gen("params.json", '{"bar": 2, "foobar": 3}') + tmp_dir.gen("other_params.json", '{"foo": {"bar": 4}}') + + dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo", "params.json:bar"], + ) + + dvc.run( + name="stage-2", + cmd="echo stage-2", + params=["other_params.json:foo"], + ) + + dvc.run( + name="stage-3", + cmd="echo stage-2", + params=["params.json:foobar"], + ) + + scm.add( + [ + "params.yaml", + "params.json", + "other_params.json", + "dvc.yaml", + "dvc.lock", + ] + ) + scm.commit("commit dvc files") + + tmp_dir.gen("params.yaml", "foo: 5") + 
scm.add(["params.yaml"]) + scm.commit("update params.yaml") + + +def test_params_show_no_args(params_repo): + assert api.params_show() == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_targets(params_repo): + assert api.params_show("params.yaml") == {"foo": 5} + assert api.params_show("params.yaml", "params.json") == { + "foo": 5, + "bar": 2, + "foobar": 3, + } + + +def test_params_show_deps(params_repo): + params = api.params_show(deps=True) + assert params == { + "params.yaml:foo": 5, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_stages(params_repo): + assert api.params_show(stages="stage-2") == {"foo": {"bar": 4}} + + assert api.params_show() == api.params_show( + stages=["stage-1", "stage-2", "stage-3"] + ) + + assert api.params_show("params.json", stages="stage-3") == {"foobar": 3} + + +def test_params_show_revs(params_repo): + assert api.params_show(rev="HEAD~1") == { + "params.yaml:foo": 1, + "bar": 2, + "foobar": 3, + "other_params.json:foo": {"bar": 4}, + } + + +def test_params_show_while_running_stage(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump({"foo": {"bar": 1}}) + (tmp_dir / "params.json").dump({"bar": 2}) + + tmp_dir.gen( + "merge.py", + dedent( + """ + import json + from dvc import api + with open("merged.json", "w") as f: + json.dump(api.params_show(stages="merge"), f) + """ + ), + ) + dvc.stage.add( + name="merge", + cmd="python merge.py", + params=["foo.bar", {"params.json": ["bar"]}], + outs=["merged.json"], + ) + + dvc.reproduce() + + assert (tmp_dir / "merged.json").parse() == {"foo": {"bar": 1}, "bar": 2} + + +def test_params_show_repo(tmp_dir, erepo_dir): + with erepo_dir.chdir(): + erepo_dir.scm_gen("params.yaml", "foo: 1", commit="Create params.yaml") + erepo_dir.dvc.run( + name="stage-1", + cmd="echo stage-1", + params=["foo"], + ) + assert api.params_show(repo=erepo_dir) == {"foo": 1} diff --git 
a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 6ee62ff72e..3988e44db0 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -544,7 +544,7 @@ def test_subdir(tmp_dir, scm, dvc, workspace): @pytest.mark.parametrize("workspace", [True, False]) def test_subrepo(tmp_dir, scm, workspace): - from tests.unit.fs.test_dvc import make_subrepo + from dvc.testing.tmp_dir import make_subrepo subrepo = tmp_dir / "dir" / "repo" make_subrepo(subrepo, scm) diff --git a/tests/func/test_external_repo.py b/tests/func/test_external_repo.py index 36e2bad56e..6e31edef48 100644 --- a/tests/func/test_external_repo.py +++ b/tests/func/test_external_repo.py @@ -4,11 +4,11 @@ from scmrepo.git import Git from dvc.external_repo import CLONES, external_repo +from dvc.testing.tmp_dir import make_subrepo from dvc.utils import relpath from dvc.utils.fs import makedirs, remove from dvc_data.stage import stage from dvc_data.transfer import transfer -from tests.unit.fs.test_dvc import make_subrepo def test_external_repo(erepo_dir, mocker): diff --git a/tests/func/test_get.py b/tests/func/test_get.py index fb477dfe0b..d2d7eb3cfb 100644 --- a/tests/func/test_get.py +++ b/tests/func/test_get.py @@ -8,8 +8,8 @@ from dvc.odbmgr import ODBManager from dvc.repo import Repo from dvc.repo.get import GetDVCFileError +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import makedirs -from tests.unit.fs.test_dvc import make_subrepo def test_get_repo_file(tmp_dir, erepo_dir): diff --git a/tests/func/test_import.py b/tests/func/test_import.py index a0c160f960..6de58fe4c5 100644 --- a/tests/func/test_import.py +++ b/tests/func/test_import.py @@ -12,8 +12,8 @@ from dvc.fs import system from dvc.odbmgr import ODBManager from dvc.stage.exceptions import StagePathNotFoundError +from dvc.testing.tmp_dir import make_subrepo from dvc.utils.fs import makedirs, remove -from tests.unit.fs.test_dvc import 
make_subrepo def test_import(tmp_dir, scm, dvc, erepo_dir): diff --git a/tests/func/test_update.py b/tests/func/test_update.py index 94ba2c9d36..37f98f314a 100644 --- a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -4,7 +4,7 @@ from dvc.dvcfile import Dvcfile from dvc.exceptions import InvalidArgumentError -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo @pytest.mark.parametrize("cached", [True, False]) diff --git a/tests/unit/fs/test_dvc.py b/tests/unit/fs/test_dvc.py index 06e48f1b3e..f7c6dec7bb 100644 --- a/tests/unit/fs/test_dvc.py +++ b/tests/unit/fs/test_dvc.py @@ -6,6 +6,7 @@ import pytest from dvc.fs.dvc import DvcFileSystem +from dvc.testing.tmp_dir import make_subrepo from dvc_data.hashfile.hash_info import HashInfo from dvc_data.stage import stage @@ -318,15 +319,6 @@ def test_isdvc(tmp_dir, dvc): assert fs.isdvc("dir/baz", recursive=True) -def make_subrepo(dir_, scm, config=None): - dir_.mkdir(parents=True, exist_ok=True) - with dir_.chdir(): - dir_.scm = scm - dir_.init(dvc=True, subdir=True) - if config: - dir_.add_remote(config=config) - - def test_subrepos(tmp_dir, scm, dvc, mocker): tmp_dir.scm_gen( {"dir": {"repo.txt": "file to confuse DvcFileSystem"}}, diff --git a/tests/unit/fs/test_dvc_info.py b/tests/unit/fs/test_dvc_info.py index 03bdec1f7c..dc175662e5 100644 --- a/tests/unit/fs/test_dvc_info.py +++ b/tests/unit/fs/test_dvc_info.py @@ -3,7 +3,7 @@ import pytest from dvc.fs.dvc import DvcFileSystem -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo @pytest.fixture diff --git a/tests/unit/test_external_repo.py b/tests/unit/test_external_repo.py index 7f2f680b71..c69b46355b 100644 --- a/tests/unit/test_external_repo.py +++ b/tests/unit/test_external_repo.py @@ -4,7 +4,7 @@ import pytest from dvc.external_repo import external_repo -from tests.unit.fs.test_dvc import make_subrepo +from dvc.testing.tmp_dir import make_subrepo def 
test_hook_is_called(tmp_dir, erepo_dir, mocker):