diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 488bd3f8c0dd7..ece5ca058d9f5 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -116,8 +116,9 @@ Changelog - |Enhancement| :func:`datasets.fetch_kddcup99` raises a better message when the cached file is invalid. :pr:`19669` `Thomas Fan`_. -- |Feature| load_files now allows users to include a blocklist and - an allowlist. :pr:`19747` by :user:`Tony Attalla `. +- |Feature| :func:`datasets.load_files` now accepts a ignore list and + an allow list based on file extensions. :pr:`19747` by + :user:`Tony Attalla `. :mod:`sklearn.decomposition` ............................ diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index a9a2d23b2809f..d5dcf5ed96dfd 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -202,7 +202,7 @@ def load_files(container_path, *, description=None, categories=None, try: assert not (allowed_extensions and ignored_extensions) except AssertionError: - raise AssertionError("Ignored extensions and allowed extensions cannot" + raise ValueError("Ignored extensions and allowed extensions cannot" " both be present. Please choose one or the" " other.") diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index f489035de914a..8ea17b23552ca 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -66,16 +66,6 @@ def test_category_dir_2(load_files_root): _remove_dir(test_category_dir2) -@pytest.fixture -def test_category_dir_3(load_files_root): - test_category_dir_3 = tempfile.mkdtemp(dir=load_files_root) - sample_file = tempfile.NamedTemporaryFile(dir=test_category_dir_3, - delete=False, suffix='.txt') - sample_file.close() - yield str(test_category_dir_3) - _remove_dir(test_category_dir_3) - - def test_data_home(data_home): # get_data_home will point to a pre-existing folder data_home = get_data_home(data_home=data_home) @@ -135,12 +125,13 @@ def test_load_files_w_allowed_and_ignored_extensions(load_files_root): msg = ("Ignored extensions and allowed extensions cannot both be present." " Please choose one or the other.") - with pytest.raises(AssertionError, match=msg): + with pytest.raises(ValueError, match=msg): load_files(load_files_root, allowed_extensions=[".txt"], ignored_extensions=[".txt"]) def test_load_files_w_ignore_list(tmp_path): + """Test load_files with ignore_extensions.""" d = tmp_path / "sub" d.mkdir() p1 = d / "file1.txt" @@ -150,11 +141,13 @@ def test_load_files_w_ignore_list(tmp_path): p3 = d / "file3.json" p3.touch() res = load_files(tmp_path, ignored_extensions=[".txt"]) - assert len(res.filenames) == 2 - assert all([re.search(r".*\.txt$", f) for f in res.filenames]) is False + assert str(p1) not in res.filenames + assert str(p2) in res.filenames + assert str(p3) in res.filenames def test_load_files_w_allow_list(tmp_path): + """Test load_files with_allow_extensions.""" d = tmp_path / "sub" d.mkdir() p1 = d / "file1.txt" @@ -164,8 +157,9 @@ def test_load_files_w_allow_list(tmp_path): p3 = d / "file3.json" p3.touch() res = load_files(tmp_path, allowed_extensions=[".txt"]) - assert len(res.filenames) == 1 - assert all([re.search(r".*\.txt$", f) for f in res.filenames]) is True + assert str(p1) in res.filenames + assert str(p2) not in res.filenames + assert str(p3) not in res.filenames def test_load_sample_images():