Skip to content

Commit

Permalink
parsing: Support dict unpacking in cmd.
Browse files Browse the repository at this point in the history
Allow to use dictionaries as values for template interpolation but only inside the `cmd` key.

See tests/func/parsing/test_interpolated_entry.py::test_cmd_dict for detailed syntax.

Add `config.parsing` section for configuring behavior of ambiguous data types
like booleans and lists.
  • Loading branch information
daavoo committed Jul 4, 2022
1 parent 735b563 commit 7fac181
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 15 deletions.
4 changes: 4 additions & 0 deletions dvc/config_schema.py
Expand Up @@ -280,4 +280,8 @@ class RelPath(str):
"plots": str,
"live": str,
},
"parsing": {
"bool": All(Lower, Choices("store_true", "boolean_optional")),
"list": All(Lower, Choices("nargs", "append")),
},
}
2 changes: 1 addition & 1 deletion dvc/parsing/__init__.py
Expand Up @@ -303,7 +303,7 @@ def _resolve(
) -> DictStr:
try:
return context.resolve(
value, skip_interpolation_checks=skip_checks
value, skip_interpolation_checks=skip_checks, key=key
)
except (ParseError, KeyNotInContext) as exc:
format_and_raise(
Expand Down
13 changes: 8 additions & 5 deletions dvc/parsing/context.py
Expand Up @@ -18,6 +18,7 @@
normalize_key,
recurse,
str_interpolate,
validate_value,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -506,7 +507,7 @@ def set_temporarily(self, to_set: DictStr, reserve: bool = False):
self.data.pop(key, None)

def resolve(
self, src, unwrap=True, skip_interpolation_checks=False
self, src, unwrap=True, skip_interpolation_checks=False, key=None
) -> Any:
"""Recursively resolves interpolation and returns resolved data.
Expand All @@ -522,10 +523,10 @@ def resolve(
{'lst': [1, 2, 3]}
"""
func = recurse(self.resolve_str)
return func(src, unwrap, skip_interpolation_checks)
return func(src, unwrap, skip_interpolation_checks, key)

def resolve_str(
self, src: str, unwrap=True, skip_interpolation_checks=False
self, src: str, unwrap=True, skip_interpolation_checks=False, key=None
) -> str:
"""Resolves interpolated string to it's original value,
or in case of multiple interpolations, a combined string.
Expand All @@ -543,10 +544,12 @@ def resolve_str(
expr = get_expression(
matches[0], skip_checks=skip_interpolation_checks
)
return self.select(expr, unwrap=unwrap)
value = self.select(expr, unwrap=unwrap)
validate_value(value, key)
return value
# but not "${num} days"
return str_interpolate(
src, matches, self, skip_checks=skip_interpolation_checks
src, matches, self, skip_checks=skip_interpolation_checks, key=key
)


Expand Down
64 changes: 57 additions & 7 deletions dvc/parsing/interpolate.py
@@ -1,11 +1,12 @@
import re
import typing
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from functools import singledispatch

from funcy import memoize, rpartial

from dvc.exceptions import DvcException
from dvc.utils.flatten import flatten

if typing.TYPE_CHECKING:
from typing import List, Match
Expand Down Expand Up @@ -80,6 +81,45 @@ def _(obj: bool):
return "true" if obj else "false"


@to_str.register(dict)
def _(obj: dict):
from dvc.config import Config

config = Config().get("parsing", {})

result = ""
for k, v in flatten(obj).items():

if isinstance(v, bool):
if v:
result += f"--{k} "
else:
if config.get("bool", "store_true") == "boolean_optional":
result += f"--no-{k} "

elif isinstance(v, str):
result += f"--{k} '{v}' "

elif isinstance(v, Iterable):
for n, i in enumerate(v):
if isinstance(i, str):
i = f"'{i}'"
elif isinstance(i, Iterable):
raise ParseError(
f"Cannot interpolate nested iterable in '{k}'"
)

if config.get("list", "nargs") == "append":
result += f"--{k} {i} "
else:
result += f"{i} " if n > 0 else f"--{k} {i} "

else:
result += f"--{k} {v} "

return result.rstrip()


def _format_exc_msg(exc: "ParseException"):
from pyparsing import ParseException

Expand Down Expand Up @@ -148,23 +188,33 @@ def get_expression(match: "Match", skip_checks: bool = False):
return inner if skip_checks else parse_expr(inner)


def validate_value(value, key):
from .context import PRIMITIVES

not_primitive = value is not None and not isinstance(value, PRIMITIVES)
not_foreach = key is not None and "foreach" not in key
if not_primitive and not_foreach:
if isinstance(value, dict):
if key == "cmd":
return True
raise ParseError(
f"Cannot interpolate data of type '{type(value).__name__}'"
)


def str_interpolate(
template: str,
matches: "List[Match]",
context: "Context",
skip_checks: bool = False,
key=None,
):
from .context import PRIMITIVES

index, buf = 0, ""
for match in matches:
start, end = match.span(0)
expr = get_expression(match, skip_checks=skip_checks)
value = context.select(expr, unwrap=True)
if value is not None and not isinstance(value, PRIMITIVES):
raise ParseError(
f"Cannot interpolate data of type '{type(value).__name__}'"
)
validate_value(value, key)
buf += template[index:start] + to_str(value)
index = end
buf += template[index:]
Expand Down
20 changes: 18 additions & 2 deletions tests/func/parsing/test_errors.py
Expand Up @@ -119,18 +119,34 @@ def test_wdir_failed_to_interpolate(tmp_dir, dvc, wdir, expected_msg):

def test_interpolate_non_string(tmp_dir, dvc):
definition = make_entry_definition(
tmp_dir, "build", {"cmd": "echo ${models}"}, Context(models={})
tmp_dir, "build", {"outs": "${models}"}, Context(models={})
)
with pytest.raises(ResolveError) as exc_info:
definition.resolve()

assert str(exc_info.value) == (
"failed to parse 'stages.build.cmd' in 'dvc.yaml':\n"
"failed to parse 'stages.build.outs' in 'dvc.yaml':\n"
"Cannot interpolate data of type 'dict'"
)
assert definition.context == {"models": {}}


def test_interpolate_nested_iterable(tmp_dir, dvc):
definition = make_entry_definition(
tmp_dir,
"build",
{"cmd": "echo ${models}"},
Context(models={"list": [1, [2, 3]]}),
)
with pytest.raises(ResolveError) as exc_info:
definition.resolve()

assert str(exc_info.value) == (
"failed to parse 'stages.build.cmd' in 'dvc.yaml':\n"
"Cannot interpolate nested iterable in 'list'"
)


def test_partial_vars_doesnot_exist(tmp_dir, dvc):
(tmp_dir / "test_params.yaml").dump({"sub1": "sub1", "sub2": "sub2"})

Expand Down
56 changes: 56 additions & 0 deletions tests/func/parsing/test_interpolated_entry.py
Expand Up @@ -259,3 +259,59 @@ def test_vars_load_partial(tmp_dir, dvc, local, vars_):
d["vars"] = vars_
resolver = DataResolver(dvc, tmp_dir.fs_path, d)
resolver.resolve()


@pytest.mark.parametrize(
"bool_config, list_config",
[(None, None), ("store_true", "nargs"), ("boolean_optional", "append")],
)
def test_cmd_dict(tmp_dir, dvc, bool_config, list_config):
with dvc.config.edit() as conf:
if bool_config:
conf["parsing"]["bool"] = bool_config
if list_config:
conf["parsing"]["list"] = list_config

data = {
"dict": {
"foo": "foo",
"bar": 2,
"string": "spaced string",
"bool": True,
"bool-false": False,
"list": [1, 2, "foo"],
"nested": {"foo": "foo"},
}
}
(tmp_dir / DEFAULT_PARAMS_FILE).dump(data)
resolver = DataResolver(
dvc,
tmp_dir.fs_path,
{"stages": {"stage1": {"cmd": "python script.py ${dict}"}}},
)

if bool_config is None or bool_config == "store_true":
bool_resolved = " --bool"
else:
bool_resolved = " --bool --no-bool-false"

if list_config is None or list_config == "nargs":
list_resolved = " --list 1 2 'foo'"
else:
list_resolved = " --list 1 --list 2 --list 'foo'"

assert_stage_equal(
resolver.resolve(),
{
"stages": {
"stage1": {
"cmd": "python script.py"
" --foo 'foo' --bar 2"
" --string 'spaced string'"
f"{bool_resolved}"
f"{list_resolved}"
" --nested.foo 'foo'"
}
}
},
)

0 comments on commit 7fac181

Please sign in to comment.