diff --git a/dvc/config_schema.py b/dvc/config_schema.py index 3891660c77..96152c9020 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -280,4 +280,8 @@ class RelPath(str): "plots": str, "live": str, }, + "parsing": { + "bool": All(Lower, Choices("store_true", "boolean_optional")), + "list": All(Lower, Choices("nargs", "append")), + }, } diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index 041863a55c..4f86b932c6 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -303,7 +303,7 @@ def _resolve( ) -> DictStr: try: return context.resolve( - value, skip_interpolation_checks=skip_checks + value, skip_interpolation_checks=skip_checks, key=key ) except (ParseError, KeyNotInContext) as exc: format_and_raise( diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index 899ee85646..8c6510c9d9 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -18,6 +18,7 @@ normalize_key, recurse, str_interpolate, + validate_value, ) logger = logging.getLogger(__name__) @@ -506,7 +507,7 @@ def set_temporarily(self, to_set: DictStr, reserve: bool = False): self.data.pop(key, None) def resolve( - self, src, unwrap=True, skip_interpolation_checks=False + self, src, unwrap=True, skip_interpolation_checks=False, key=None ) -> Any: """Recursively resolves interpolation and returns resolved data. @@ -522,10 +523,10 @@ def resolve( {'lst': [1, 2, 3]} """ func = recurse(self.resolve_str) - return func(src, unwrap, skip_interpolation_checks) + return func(src, unwrap, skip_interpolation_checks, key) def resolve_str( - self, src: str, unwrap=True, skip_interpolation_checks=False + self, src: str, unwrap=True, skip_interpolation_checks=False, key=None ) -> str: """Resolves interpolated string to it's original value, or in case of multiple interpolations, a combined string. @@ -543,10 +544,12 @@ def resolve_str( expr = get_expression( matches[0], skip_checks=skip_interpolation_checks ) - return self.select(expr, unwrap=unwrap) + value = self.select(expr, unwrap=unwrap) + validate_value(value, key) + return value # but not "${num} days" return str_interpolate( - src, matches, self, skip_checks=skip_interpolation_checks + src, matches, self, skip_checks=skip_interpolation_checks, key=key ) diff --git a/dvc/parsing/interpolate.py b/dvc/parsing/interpolate.py index 6bb92c81a0..4c29875a72 100644 --- a/dvc/parsing/interpolate.py +++ b/dvc/parsing/interpolate.py @@ -1,11 +1,12 @@ import re import typing -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from functools import singledispatch from funcy import memoize, rpartial from dvc.exceptions import DvcException +from dvc.utils.flatten import flatten if typing.TYPE_CHECKING: from typing import List, Match @@ -80,6 +81,45 @@ def _(obj: bool): return "true" if obj else "false" +@to_str.register(dict) +def _(obj: dict): + from dvc.config import Config + + config = Config().get("parsing", {}) + + result = "" + for k, v in flatten(obj).items(): + + if isinstance(v, bool): + if v: + result += f"--{k} " + else: + if config.get("bool", "store_true") == "boolean_optional": + result += f"--no-{k} " + + elif isinstance(v, str): + result += f"--{k} '{v}' " + + elif isinstance(v, Iterable): + for n, i in enumerate(v): + if isinstance(i, str): + i = f"'{i}'" + elif isinstance(i, Iterable): + raise ParseError( + f"Cannot interpolate nested iterable in '{k}'" + ) + + if config.get("list", "nargs") == "append": + result += f"--{k} {i} " + else: + result += f"{i} " if n > 0 else f"--{k} {i} " + + else: + result += f"--{k} {v} " + + return result.rstrip() + + def _format_exc_msg(exc: "ParseException"): from pyparsing import ParseException @@ -148,23 +188,33 @@ def get_expression(match: "Match", skip_checks: bool = False): return inner if skip_checks else parse_expr(inner) +def validate_value(value, key): + from .context import PRIMITIVES + + not_primitive = value is not None and not isinstance(value, PRIMITIVES) + not_foreach = key is not None and "foreach" not in key + if not_primitive and not_foreach: + if isinstance(value, dict): + if key == "cmd": + return True + raise ParseError( + f"Cannot interpolate data of type '{type(value).__name__}'" + ) + + def str_interpolate( template: str, matches: "List[Match]", context: "Context", skip_checks: bool = False, + key=None, ): - from .context import PRIMITIVES - index, buf = 0, "" for match in matches: start, end = match.span(0) expr = get_expression(match, skip_checks=skip_checks) value = context.select(expr, unwrap=True) - if value is not None and not isinstance(value, PRIMITIVES): - raise ParseError( - f"Cannot interpolate data of type '{type(value).__name__}'" - ) + validate_value(value, key) buf += template[index:start] + to_str(value) index = end buf += template[index:] diff --git a/tests/func/parsing/test_errors.py b/tests/func/parsing/test_errors.py index d49fb08698..559e734c43 100644 --- a/tests/func/parsing/test_errors.py +++ b/tests/func/parsing/test_errors.py @@ -119,18 +119,34 @@ def test_wdir_failed_to_interpolate(tmp_dir, dvc, wdir, expected_msg): def test_interpolate_non_string(tmp_dir, dvc): definition = make_entry_definition( - tmp_dir, "build", {"cmd": "echo ${models}"}, Context(models={}) + tmp_dir, "build", {"outs": "${models}"}, Context(models={}) ) with pytest.raises(ResolveError) as exc_info: definition.resolve() assert str(exc_info.value) == ( - "failed to parse 'stages.build.cmd' in 'dvc.yaml':\n" + "failed to parse 'stages.build.outs' in 'dvc.yaml':\n" "Cannot interpolate data of type 'dict'" ) assert definition.context == {"models": {}} +def test_interpolate_nested_iterable(tmp_dir, dvc): + definition = make_entry_definition( + tmp_dir, + "build", + {"cmd": "echo ${models}"}, + Context(models={"list": [1, [2, 3]]}), + ) + with pytest.raises(ResolveError) as exc_info: + definition.resolve() + + assert str(exc_info.value) == ( + "failed to parse 'stages.build.cmd' in 'dvc.yaml':\n" + "Cannot interpolate nested iterable in 'list'" + ) + + def test_partial_vars_doesnot_exist(tmp_dir, dvc): (tmp_dir / "test_params.yaml").dump({"sub1": "sub1", "sub2": "sub2"}) diff --git a/tests/func/parsing/test_interpolated_entry.py b/tests/func/parsing/test_interpolated_entry.py index 6c9650d221..3be23b6d50 100644 --- a/tests/func/parsing/test_interpolated_entry.py +++ b/tests/func/parsing/test_interpolated_entry.py @@ -259,3 +259,59 @@ def test_vars_load_partial(tmp_dir, dvc, local, vars_): d["vars"] = vars_ resolver = DataResolver(dvc, tmp_dir.fs_path, d) resolver.resolve() + + +@pytest.mark.parametrize( + "bool_config, list_config", + [(None, None), ("store_true", "nargs"), ("boolean_optional", "append")], +) +def test_cmd_dict(tmp_dir, dvc, bool_config, list_config): + with dvc.config.edit() as conf: + if bool_config: + conf["parsing"]["bool"] = bool_config + if list_config: + conf["parsing"]["list"] = list_config + + data = { + "dict": { + "foo": "foo", + "bar": 2, + "string": "spaced string", + "bool": True, + "bool-false": False, + "list": [1, 2, "foo"], + "nested": {"foo": "foo"}, + } + } + (tmp_dir / DEFAULT_PARAMS_FILE).dump(data) + resolver = DataResolver( + dvc, + tmp_dir.fs_path, + {"stages": {"stage1": {"cmd": "python script.py ${dict}"}}}, + ) + + if bool_config is None or bool_config == "store_true": + bool_resolved = " --bool" + else: + bool_resolved = " --bool --no-bool-false" + + if list_config is None or list_config == "nargs": + list_resolved = " --list 1 2 'foo'" + else: + list_resolved = " --list 1 --list 2 --list 'foo'" + + assert_stage_equal( + resolver.resolve(), + { + "stages": { + "stage1": { + "cmd": "python script.py" + " --foo 'foo' --bar 2" + " --string 'spaced string'" + f"{bool_resolved}" + f"{list_resolved}" + " --nested.foo 'foo'" + } + } + }, + )