Skip to content

Commit

Permalink
Merge pull request #2 from TonyAttalla/tests-pr-feedback
Browse files Browse the repository at this point in the history
pr feedback
  • Loading branch information
TonyAttalla committed May 23, 2021
2 parents 8889db8 + 4b73a45 commit e12051f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 18 deletions.
5 changes: 3 additions & 2 deletions doc/whats_new/v1.0.rst
Expand Up @@ -116,8 +116,9 @@ Changelog
- |Enhancement| :func:`datasets.fetch_kddcup99` raises a better message
when the cached file is invalid. :pr:`19669` `Thomas Fan`_.

- |Feature| load_files now allows users to include a blocklist and
an allowlist. :pr:`19747` by :user:`Tony Attalla <tonyattalla>`.
- |Feature| :func:`datasets.load_files` now accepts an ignore list and
an allow list based on file extensions. :pr:`19747` by
:user:`Tony Attalla <tonyattalla>`.

:mod:`sklearn.decomposition`
............................
Expand Down
2 changes: 1 addition & 1 deletion sklearn/datasets/_base.py
Expand Up @@ -202,7 +202,7 @@ def load_files(container_path, *, description=None, categories=None,
try:
assert not (allowed_extensions and ignored_extensions)
except AssertionError:
raise AssertionError("Ignored extensions and allowed extensions cannot"
raise ValueError("Ignored extensions and allowed extensions cannot"
" both be present. Please choose one or the"
" other.")

Expand Down
24 changes: 9 additions & 15 deletions sklearn/datasets/tests/test_base.py
Expand Up @@ -66,16 +66,6 @@ def test_category_dir_2(load_files_root):
_remove_dir(test_category_dir2)


@pytest.fixture
def test_category_dir_3(load_files_root):
test_category_dir_3 = tempfile.mkdtemp(dir=load_files_root)
sample_file = tempfile.NamedTemporaryFile(dir=test_category_dir_3,
delete=False, suffix='.txt')
sample_file.close()
yield str(test_category_dir_3)
_remove_dir(test_category_dir_3)


def test_data_home(data_home):
# get_data_home will point to a pre-existing folder
data_home = get_data_home(data_home=data_home)
Expand Down Expand Up @@ -135,12 +125,13 @@ def test_load_files_w_allowed_and_ignored_extensions(load_files_root):
msg = ("Ignored extensions and allowed extensions cannot both be present."
" Please choose one or the other.")

with pytest.raises(AssertionError, match=msg):
with pytest.raises(ValueError, match=msg):
load_files(load_files_root, allowed_extensions=[".txt"],
ignored_extensions=[".txt"])


def test_load_files_w_ignore_list(tmp_path):
"""Test load_files with ignore_extensions."""
d = tmp_path / "sub"
d.mkdir()
p1 = d / "file1.txt"
Expand All @@ -150,11 +141,13 @@ def test_load_files_w_ignore_list(tmp_path):
p3 = d / "file3.json"
p3.touch()
res = load_files(tmp_path, ignored_extensions=[".txt"])
assert len(res.filenames) == 2
assert all([re.search(r".*\.txt$", f) for f in res.filenames]) is False
assert str(p1) not in res.filenames
assert str(p2) in res.filenames
assert str(p3) in res.filenames


def test_load_files_w_allow_list(tmp_path):
"""Test load_files with_allow_extensions."""
d = tmp_path / "sub"
d.mkdir()
p1 = d / "file1.txt"
Expand All @@ -164,8 +157,9 @@ def test_load_files_w_allow_list(tmp_path):
p3 = d / "file3.json"
p3.touch()
res = load_files(tmp_path, allowed_extensions=[".txt"])
assert len(res.filenames) == 1
assert all([re.search(r".*\.txt$", f) for f in res.filenames]) is True
assert str(p1) in res.filenames
assert str(p2) not in res.filenames
assert str(p3) not in res.filenames


def test_load_sample_images():
Expand Down

0 comments on commit e12051f

Please sign in to comment.