Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Sep 27, 2022
1 parent 3c95246 commit 49bb38f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
3 changes: 2 additions & 1 deletion tests/test_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
extract_path_from_uri,
is_remote_filesystem,
)
from datasets.utils._hf_hub_fixes import dataset_info as hf_api_dataset_info

from .utils import require_lz4, require_zstandard

Expand Down Expand Up @@ -93,7 +94,7 @@ def test_fs_isfile(protocol, zip_jsonl_path, jsonl_gz_path):

@pytest.mark.integration
def test_hf_filesystem(hf_token, hf_api, hf_private_dataset_repo_txt_data, text_file):
repo_info = hf_api.dataset_info(hf_private_dataset_repo_txt_data, token=hf_token)
repo_info = hf_api_dataset_info(hf_api, hf_private_dataset_repo_txt_data, token=hf_token)
hffs = HfFileSystem(repo_info=repo_info, token=hf_token)
assert sorted(hffs.glob("*")) == [".gitattributes", "data"]
assert hffs.isdir("data")
Expand Down
28 changes: 16 additions & 12 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,18 +756,6 @@ def test_load_dataset_streaming_csv(path_extension, streaming, csv_path, bz2_csv
assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}


@require_pil
@pytest.mark.integration
@pytest.mark.parametrize("streaming", [False, True])
def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
ds = load_dataset(
hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=hf_token
)
assert isinstance(ds, IterableDataset if streaming else Dataset)
ds_items = list(ds)
assert len(ds_items) == 2


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("data_file", ["zip_csv_path", "zip_csv_with_dir_path", "csv_path"])
def test_load_dataset_zip_csv(data_file, streaming, zip_csv_path, zip_csv_with_dir_path, csv_path):
Expand Down Expand Up @@ -874,6 +862,22 @@ def assert_auth(url, *args, headers, **kwargs):
mock_head.assert_called()


@require_pil
@pytest.mark.integration
@pytest.mark.parametrize("implicit_token", [False, True])
@pytest.mark.parametrize("streaming", [False, True])
def test_load_dataset_private_zipped_images(
hf_private_dataset_repo_zipped_img_data, hf_token, streaming, implicit_token
):
use_auth_token = None if implicit_token else hf_token
ds = load_dataset(
hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=use_auth_token
)
assert isinstance(ds, IterableDataset if streaming else Dataset)
ds_items = list(ds)
assert len(ds_items) == 2


@pytest.mark.integration
def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
with pytest.raises(FileNotFoundError):
Expand Down

0 comments on commit 49bb38f

Please sign in to comment.