-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
param.py
171 lines (133 loc) · 4.7 KB
/
param.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import logging
import os
import typing
from collections import defaultdict
from typing import Dict
import dpath.util
from voluptuous import Any
from dvc.exceptions import DvcException
from dvc.utils.serialize import ParseError, load_path
from dvc_data.hashfile.hash_info import HashInfo
from .base import Dependency
logger = logging.getLogger(__name__)
class MissingParamsError(DvcException):
    """Tracked parameters are absent from the parameters file.

    Raised by ``ParamsDependency.get_hash`` when at least one of the
    requested parameter names cannot be found in the values that were read.
    """
class MissingParamsFile(DvcException):
    """The parameters file does not exist.

    Raised by ``ParamsDependency.validate_filepath``; ``read_params``
    catches it and treats the file as an empty configuration.
    """
class ParamsIsADirectoryError(DvcException):
    """The parameters path points at a directory instead of a file.

    Raised by ``ParamsDependency.validate_filepath``.
    """
class BadParamFileError(DvcException):
    """The parameters file could not be parsed.

    Raised by ``ParamsDependency.read_file`` when loading the file fails
    with a ``ParseError``; the original error is chained as the cause.
    """
class ParamsDependency(Dependency):
    """Dependency on values stored in a parameters file.

    ``self.params`` holds the names of the tracked parameters (looked up in
    the parsed file with ``.`` as the path separator). When it is empty the
    whole file is tracked.
    """

    PARAM_PARAMS = "params"
    PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)}
    DEFAULT_PARAMS_FILE = "params.yaml"

    def __init__(self, stage, path, params=None, repo=None):
        """Create the dependency.

        Args:
            stage: owning stage; ``stage.repo`` is used when ``repo`` is
                not given.
            path: path of the parameters file. A falsy value falls back to
                ``DEFAULT_PARAMS_FILE`` inside ``repo.root_dir``.
            params: names of the parameters to track. When a dict is
                given, its keys become the tracked names and the dict
                itself is forwarded to the base class as ``info`` under
                the ``"params"`` key.
            repo: repository object; defaults to ``stage.repo``.
        """
        # list() over a dict keeps only its keys, i.e. the param names.
        self.params = list(params) if params else []
        info = (
            {self.PARAM_PARAMS: params} if isinstance(params, dict) else None
        )
        repo = repo or stage.repo
        path = path or os.path.join(repo.root_dir, self.DEFAULT_PARAMS_FILE)
        super().__init__(stage, path, info=info, repo=repo)

    def dumpd(self):
        """Return the serialized form produced by the base class.

        When ``hash_info`` is falsy, the tracked parameter names (or ``{}``
        when there are none) are added under the ``"params"`` key.
        """
        ret = super().dumpd()
        if not self.hash_info:
            ret[self.PARAM_PARAMS] = self.params or {}
        return ret

    def fill_values(self, values=None):
        """Load params values dynamically.

        Sets ``hash_info`` from ``values``: every entry is kept when no
        parameter names are tracked, otherwise only the tracked names that
        are present in ``values``. Does nothing when ``values`` is None.
        """
        if values is None:
            return

        info = {}
        if not self.params:
            info.update(values)
        for param in self.params:
            if param in values:
                info[param] = values[param]
        self.hash_info = HashInfo(self.PARAM_PARAMS, info)

    def read_params(
        self, flatten: bool = True, **kwargs: typing.Any
    ) -> Dict[str, typing.Any]:
        """Read the tracked parameters from the parameters file.

        A missing file is treated as an empty configuration; other errors
        from ``read_file`` propagate.

        Args:
            flatten: when true, return ``{name: value}`` keyed by the
                tracked names exactly as given, silently skipping names
                whose lookup raises ``KeyError``. When false, merge the
                ``dpath.util.search`` result for every tracked name into
                one dict (presumably mirroring the file's nesting).
            **kwargs: accepted but not used by this implementation.

        Returns:
            The whole parsed configuration when no names are tracked,
            otherwise the selected values as described above.
        """
        try:
            config = self.read_file()
        except MissingParamsFile:
            config = {}

        if not self.params:
            return config

        ret = {}
        if flatten:
            for param in self.params:
                try:
                    ret[param] = dpath.util.get(config, param, separator=".")
                except KeyError:
                    continue
            return ret

        for param in self.params:
            dpath.util.merge(
                ret,
                dpath.util.search(config, param, separator="."),
                separator=".",
            )
        return ret

    def workspace_status(self):
        """Compare the parameters file with the recorded values.

        Returns:
            ``{str(self): "deleted"}`` when the file does not exist,
            ``{str(self): "new"}`` when no values are recorded yet, and
            otherwise a mapping ``{str(self): {param: state}}`` listing
            only the params whose state is ``"deleted"``, ``"new"`` or
            ``"modified"`` (empty when nothing changed).
        """
        if not self.exists:
            return {str(self): "deleted"}
        if self.hash_info.value is None:
            return {str(self): "new"}

        from funcy import ldistinct

        status = defaultdict(dict)
        info = self.hash_info.value if self.hash_info else {}
        actual = self.read_params()

        # NOTE: we want to preserve the order of params as specified in the
        # status. In case of tracking the whole file, the order is top-level
        # keys in the file and then the keys in the `info` from `dvc.lock`
        # (which are alphabetically sorted).
        params = self.params or ldistinct([*actual.keys(), *info.keys()])
        for param in params:
            if param not in actual:
                st = "deleted"
            elif param not in info:
                st = "new"
            elif actual[param] != info[param]:
                st = "modified"
            else:
                # Present on both sides and not different: nothing to report.
                continue
            status[str(self)][param] = st

        return status

    def status(self):
        """Return the same result as ``workspace_status``."""
        return self.workspace_status()

    def validate_filepath(self):
        """Check that the parameters path exists and is not a directory.

        Raises:
            MissingParamsFile: the path does not exist.
            ParamsIsADirectoryError: the path is a directory.
        """
        if not self.exists:
            raise MissingParamsFile(f"Parameters file '{self}' does not exist")
        if self.isdir():
            raise ParamsIsADirectoryError(
                f"'{self}' is a directory, expected a parameters file"
            )

    def read_file(self):
        """Validate the path and return the parsed parameters file.

        Raises:
            MissingParamsFile: the path does not exist.
            ParamsIsADirectoryError: the path is a directory.
            BadParamFileError: loading failed with a ``ParseError``.
        """
        self.validate_filepath()
        try:
            return load_path(self.fs_path, self.repo.fs)
        except ParseError as exc:
            raise BadParamFileError(
                f"Unable to read parameters from '{self}'"
            ) from exc

    def get_hash(self):
        """Build the ``HashInfo`` for the current parameter values.

        Raises:
            MissingParamsError: one or more tracked names were not found
                in the values returned by ``read_params``.
        """
        info = self.read_params()

        missing_params = set(self.params) - set(info.keys())
        if missing_params:
            # Sort so the message does not depend on set iteration order.
            names = ", ".join(sorted(missing_params))
            raise MissingParamsError(
                f"Parameters '{names}' are missing from '{self}'."
            )

        return HashInfo(self.PARAM_PARAMS, info)

    def save(self):
        """Record the current parameter values in ``hash_info``.

        Raises:
            DoesNotExistError: the path does not exist.
            IsNotFileOrDirError: the path is neither a file nor a
                directory.
        """
        if not self.exists:
            raise self.DoesNotExistError(self)

        # `isdir` is invoked as a method in `validate_filepath`; testing the
        # bound methods themselves is always truthy, which made this guard
        # unreachable. NOTE(review): `isfile` is assumed to be a method like
        # `isdir` - confirm against the base class.
        if not self.isfile() and not self.isdir():
            raise self.IsNotFileOrDirError(self)

        self.ignore()
        self.hash_info = self.get_hash()