Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numpy check #221

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion edflow/data/dataset_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from chainer.dataset import DatasetMixin as DatasetMixin_
import numpy as np
from edflow.util import walk, update
from edflow.util import walk, update, LabelsDict

# handle bug with mocked chainer.dataset.DatasetMixin import
if hasattr(DatasetMixin_, "_mock_name"):
Expand Down Expand Up @@ -287,6 +287,7 @@ def labels(self):

@labels.setter
def labels(self, labels):
labels = LabelsDict(labels)
if hasattr(self, "data"):
self.data.labels = labels
else:
Expand Down
4 changes: 2 additions & 2 deletions edflow/edsetup_files/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, config):
self.example_names = config.get("example_names")

self.inputs = np.random.rand(self.num_example, self.feature_dimension)
self.labels = self.inputs
self.labels = {"random": self.inputs}

def get_example(self, idx):
"""
Expand All @@ -31,7 +31,7 @@ def get_example(self, idx):
example (dict): These will be retrieved by their respective keys in the step_op method of the iterator.
"""
inputs = self.inputs[idx]
labels = self.labels[idx]
labels = self.labels["random"][idx]

example = {"inputs": inputs, "labels": labels}
return example
Expand Down
39 changes: 39 additions & 0 deletions edflow/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,45 @@ def __init__(self, config):
self.config = config


class LabelsDict(dict):
    """Dict which asserts all values are numpy arrays or memmaps.

    Every write path visible here (the constructor, item assignment,
    ``update`` and ``setdefault``) is routed through ``__setitem__``, which
    validates the value before storing it.

    Code from
    https://stackoverflow.com/questions/2060972/subclassing-python-dictionary-to-override-setitem
    """

    def __init__(self, *args, **kwargs):
        """Build the dict from the same arguments ``update`` accepts."""
        # Go through our own ``update`` (and hence ``__setitem__``) so that
        # the initial content is validated as well.
        self.update(*args, **kwargs)

    @staticmethod
    def _assert_valid_label(val):
        """Raise ``AssertionError`` unless ``val`` is a numpy array or memmap."""
        # Raised explicitly instead of through an ``assert`` statement so
        # that the check is still performed when Python runs with ``-O``.
        if not isinstance(val, (np.ndarray, np.memmap)):
            raise AssertionError(
                "Labels must be dict s of numpy arrays or memmaps and not {}!".format(
                    type(val)
                )
            )

    def __setitem__(self, key, value):
        """Validate ``value`` and store it under ``key``.

        Raises:
            AssertionError: if the validator rejects ``value`` or anything
                ``walk`` hands to it.
        """
        # NOTE(review): ``walk`` is defined elsewhere in this module;
        # presumably it applies the callback to every leaf of nested
        # containers -- confirm against its definition.
        walk(value, self._assert_valid_label)
        super(LabelsDict, self).__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Insert items from at most one positional mapping/iterable of
        pairs, then from the keyword arguments, validating each value.

        Raises:
            TypeError: if more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(
                "update expected at most 1 arguments, got %d" % len(args)
            )
        if args:
            # ``dict(...)`` normalises both mappings and iterables of pairs.
            for key, value in dict(args[0]).items():
                self[key] = value
        # Keyword arguments are applied last and therefore win on clashes.
        for key, value in kwargs.items():
            self[key] = value

    def setdefault(self, key, value=None):
        """Return ``self[key]``, first storing (and validating) ``value``
        under ``key`` if the key is missing."""
        if key not in self:
            self[key] = value
        return self[key]


if __name__ == "__main__":
from edflow.data.util import plot_datum

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
extras_require={
"explore": ["streamlit > 0.49"],
"docs": ["sphinx >= 1.4", "sphinx_rtd_theme", "better-apidoc"],
"test": ["pytest", "pytest-cov", "coveralls"],
"test": ["pytest", "pytest-cov", "coveralls", "coverage < 5.0"],
},
zip_safe=False,
scripts=[
Expand Down
25 changes: 25 additions & 0 deletions tests/test_data/test_datasetmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,34 @@ def __len__(self):
assert d2["val"] == d1["val"]()


def test_labels_numpy():
    """Labels of a ``DatasetMixin`` accept numpy arrays and reject lists,
    both when assigned as a whole and when added through ``update``."""

    class MyDset(DatasetMixin):
        def __init__(self):
            self.labels = {}

    as_array = np.arange(10)
    as_list = list(np.arange(10))

    # Assigning a dict of arrays is accepted.
    accepting = MyDset()
    accepting.labels = {"l": as_array}

    # Assigning a dict that holds a list is rejected.
    rejecting = MyDset()
    with pytest.raises(AssertionError):
        rejecting.labels = {"l": as_list}

    # Updating the existing labels with an array is accepted.
    accepting = MyDset()
    accepting.labels.update({"l2": as_array})

    # Updating the existing labels with a list is rejected.
    rejecting = MyDset()
    with pytest.raises(AssertionError):
        rejecting.labels.update({"l2": as_list})


if __name__ == "__main__":
    # Allow running this module directly: execute every test in the same
    # order in which they were previously listed.
    for run_test in (
        test_dset_mxin,
        test_dset_mxin_ops,
        test_dset_mxin_data_attr,
        test_dset_mxin_app_labels,
        test_dset_mxin_data_attr_app_labels,
        test_labels_numpy,
    ):
        run_test()
24 changes: 24 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
contains_key,
KeyNotFoundError,
get_leaf_names,
LabelsDict,
)
from edflow import util
from itertools import product
import numpy as np

# ================= set_value ====================

Expand Down Expand Up @@ -792,3 +794,25 @@ def test_get_leaf_name():
ref = sorted(["a/0", "a/1", "b/c/d", "e"])

assert names == ref


# ================= LabelsDict ===================


def test_LabelsDict():
    """Check that ``LabelsDict`` supports the same write paths as ``dict``.

    Code adapted from https://stackoverflow.com/questions/2060972/subclassing-python-dictionary-to-override-setitem
    """

    def test_updates(dictish):
        # Exercise item assignment, every calling convention of ``update``
        # and ``setdefault`` with numpy-array values only.
        dictish["abc"] = np.array(123)
        dictish.update({"def": np.array(234)})
        dictish.update(red=np.array(1), blue=np.array(2))
        dictish.update([("orange", np.array(3)), ("green", np.array(4))])
        dictish.update({"hello": np.array("kitty")}, black=np.array("white"))
        dictish.update({"yellow": np.array(5)}, yellow=np.array(6))
        dictish.setdefault("brown", np.array(7))

    # Reference run on a plain dict.
    python_dict = dict([("b", 2), ("c", 3)], a=1)
    test_updates(python_dict)

    my_dict = LabelsDict([("b", np.array(2)), ("c", np.array(3))], a=np.array(1))
    test_updates(my_dict)

    # Beyond "no exception was raised": both containers must end up with
    # the same keys ...
    assert set(my_dict) == set(python_dict)
    # ... the constructor arguments must have been stored ...
    assert my_dict["a"] == 1
    assert my_dict["b"] == 2
    assert my_dict["c"] == 3
    # ... keyword arguments must override the positional mapping ...
    assert my_dict["yellow"] == 6
    assert python_dict["yellow"] == 6
    # ... and ``setdefault`` must have inserted the missing key.
    assert my_dict["brown"] == 7