From 365884d6197442207f03261495b9e88fc6e0f1d4 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Fri, 30 Sep 2022 11:15:59 +0200
Subject: [PATCH] Support hfh 0.10 implicit auth (#5031)

* support hfh 0.10 implicit auth
* update tests
* Bump minimum hfh to 0.2.0 and test minimum version
* style
* fix test
* fix tests
* again
* lucain's comment
* fix ci
---
 .github/workflows/ci.yml            | 14 ++---
 setup.py                            |  3 +-
 src/datasets/arrow_dataset.py       |  3 +-
 src/datasets/load.py                | 21 +++-----
 src/datasets/utils/_hf_hub_fixes.py | 79 ++++++++++++++++++++++++++++-
 src/datasets/utils/file_utils.py    | 16 +++---
 tests/test_filesystem.py            |  3 +-
 tests/test_load.py                  | 36 ++++++-------
 tests/test_metric_common.py         | 18 ++++++-
 9 files changed, 143 insertions(+), 50 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2dff4e6bd0c..ccadd893ccd 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -38,7 +38,7 @@ jobs:
       matrix:
         test: ['unit', 'integration']
         os: [ubuntu-latest, windows-latest]
-        pyarrow_version: [latest, 6.0.1]
+        deps_versions: [latest, minimum]
     continue-on-error: ${{ matrix.test == 'integration' }}
     runs-on: ${{ matrix.os }}
     steps:
@@ -63,12 +63,12 @@ jobs:
       run: |
         pip install .[tests]
         pip install -r additional-tests-requirements.txt --no-deps
-    - name: Install latest PyArrow
-      if: ${{ matrix.pyarrow_version == 'latest' }}
-      run: pip install pyarrow --upgrade
-    - name: Install PyArrow ${{ matrix.pyarrow_version }}
-      if: ${{ matrix.pyarrow_version != 'latest' }}
-      run: pip install pyarrow==${{ matrix.pyarrow_version }}
+    - name: Install dependencies (latest versions)
+      if: ${{ matrix.deps_versions == 'latest' }}
+      run: pip install --upgrade pyarrow huggingface-hub
+    - name: Install dependencies (minimum versions)
+      if: ${{ matrix.deps_versions != 'latest' }}
+      run: pip install pyarrow==6.0.1 huggingface-hub==0.2.0 transformers
     - name: Test with pytest
       run: |
         python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
diff --git a/setup.py b/setup.py
index 509042e18c6..60fd6b89cef 100644
--- a/setup.py
+++ b/setup.py
@@ -89,7 +89,8 @@
     # for data streaming via http
     "aiohttp",
     # To get datasets from the Datasets Hub on huggingface.co
-    "huggingface-hub>=0.1.0,<1.0.0",
+    # minimum 0.2.0 for set_access_token
+    "huggingface-hub>=0.2.0,<1.0.0",
     # Utilities from PyPA to e.g., compare versions
     "packaging",
     "responses<0.19",
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 42479e42091..55c0324e581 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -102,6 +102,7 @@
 from .tasks import TaskTemplate
 from .utils import logging
 from .utils._hf_hub_fixes import create_repo
+from .utils._hf_hub_fixes import list_repo_files as hf_api_list_repo_files
 from .utils.file_utils import _retry, cached_path, estimate_dataset_size, hf_hub_url
 from .utils.info_utils import is_small_dataset
 from .utils.py_utils import asdict, convert_file_size_to_int, unique_values
@@ -4288,7 +4289,7 @@ def shards_with_embedded_external_files(shards):
 
         shards = shards_with_embedded_external_files(shards)
 
-        files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token)
+        files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
         data_files = [file for file in files if file.startswith("data/")]
 
         def path_in_repo(_index, shard):
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 159c6e4dec3..ff1aa01d640 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -29,7 +29,7 @@
 
 import fsspec
 import requests
-from huggingface_hub import HfApi, HfFolder
+from huggingface_hub import HfApi
 
 from . import config
 from .arrow_dataset import Dataset
@@ -62,6 +62,7 @@
 )
 from .splits import Split
 from .tasks import TaskTemplate
+from .utils._hf_hub_fixes import dataset_info as hf_api_dataset_info
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import (
     OfflineModeIsEnabled,
@@ -736,14 +737,11 @@ def __init__(
         increase_load_count(name, resource_type="dataset")
 
     def get_module(self) -> DatasetModule:
-        if isinstance(self.download_config.use_auth_token, bool):
-            token = HfFolder.get_token() if self.download_config.use_auth_token else None
-        else:
-            token = self.download_config.use_auth_token
-        hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
+        hfh_dataset_info = hf_api_dataset_info(
+            HfApi(config.HF_ENDPOINT),
             self.name,
             revision=self.revision,
-            token=token if token else "no-token",
+            use_auth_token=self.download_config.use_auth_token,
             timeout=100.0,
         )
         patterns = (
@@ -1104,14 +1102,11 @@ def dataset_module_factory(
         _raise_if_offline_mode_is_enabled()
         hf_api = HfApi(config.HF_ENDPOINT)
         try:
-            if isinstance(download_config.use_auth_token, bool):
-                token = HfFolder.get_token() if download_config.use_auth_token else None
-            else:
-                token = download_config.use_auth_token
-            dataset_info = hf_api.dataset_info(
+            dataset_info = hf_api_dataset_info(
+                hf_api,
                 repo_id=path,
                 revision=revision,
-                token=token if token else "no-token",
+                use_auth_token=download_config.use_auth_token,
                 timeout=100.0,
             )
         except Exception as e:  # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
diff --git a/src/datasets/utils/_hf_hub_fixes.py b/src/datasets/utils/_hf_hub_fixes.py
index c02fec828d4..85718d72f81 100644
--- a/src/datasets/utils/_hf_hub_fixes.py
+++ b/src/datasets/utils/_hf_hub_fixes.py
@@ -1,7 +1,8 @@
-from typing import Optional
+from typing import List, Optional, Union
 
 import huggingface_hub
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, HfFolder
+from huggingface_hub.hf_api import DatasetInfo
 from packaging import version
 
 
@@ -99,3 +100,77 @@ def delete_repo(
             token=token,
             repo_type=repo_type,
         )
+
+
+def dataset_info(
+    hf_api: HfApi,
+    repo_id: str,
+    *,
+    revision: Optional[str] = None,
+    timeout: Optional[float] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+) -> DatasetInfo:
+    """
+    The huggingface_hub.HfApi.dataset_info parameters changed in 0.10.0 and some of them were deprecated.
+    This function checks the huggingface_hub version to call the right parameters.
+
+    Args:
+        hf_api (`huggingface_hub.HfApi`): Hub client.
+        repo_id (`str`):
+            A namespace (user or an organization) and a repo name separated
+            by a `/`.
+        revision (`str`, *optional*):
+            The revision of the dataset repository from which to get the
+            information.
+        timeout (`float`, *optional*):
+            How many seconds to wait for the Hub before raising a timeout error.
+        use_auth_token (`bool` or `str`, *optional*):
+            Whether to use the `auth_token` provided by the
+            `huggingface_hub` CLI. If not logged in, a valid `auth_token`
+            can be passed in as a string.
+    Returns:
+        [`hf_api.DatasetInfo`]: The dataset repository information.
+
+    Raises the following errors:
+        - [`~utils.RepositoryNotFoundError`]
+          If the repository to download from cannot be found. This may be because it doesn't exist,
+          or because it is set to `private` and you do not have access.
+        - [`~utils.RevisionNotFoundError`]
+          If the revision to download from cannot be found.
+
+    """
+    if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
+        if use_auth_token is False:
+            token = "no-token"
+        elif isinstance(use_auth_token, str):
+            token = use_auth_token
+        else:
+            token = HfFolder.get_token() or "no-token"
+        return hf_api.dataset_info(
+            repo_id,
+            revision=revision,
+            token=token,
+            timeout=timeout,
+        )
+    else:  # the `token` parameter is deprecated in huggingface_hub>=0.10.0
+        return hf_api.dataset_info(repo_id, revision=revision, timeout=timeout, use_auth_token=use_auth_token)
+
+
+def list_repo_files(
+    hf_api: HfApi,
+    repo_id: str,
+    revision: Optional[str] = None,
+    repo_type: Optional[str] = None,
+    token: Optional[str] = None,
+    timeout: Optional[float] = None,
+) -> List[str]:
+    """
+    The huggingface_hub.HfApi.list_repo_files parameters changed in 0.10.0 and some of them were deprecated.
+    This function checks the huggingface_hub version to call the right parameters.
+    """
+    if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
+        return hf_api.list_repo_files(repo_id, revision=revision, repo_type=repo_type, token=token, timeout=timeout)
+    else:  # the `token` parameter is deprecated in huggingface_hub>=0.10.0
+        return hf_api.list_repo_files(
+            repo_id, revision=revision, repo_type=repo_type, use_auth_token=token, timeout=timeout
+        )
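
A minimal usage sketch of the two shims above (not part of the patch; the repo name is hypothetical). With use_auth_token=None, huggingface_hub>=0.10.0 resolves the token stored by `huggingface-cli login` itself, while older versions fall back to HfFolder.get_token():

    from huggingface_hub import HfApi

    from datasets.utils._hf_hub_fixes import dataset_info, list_repo_files

    api = HfApi("https://huggingface.co")
    # Version-gated call: the shim passes `use_auth_token` on hfh>=0.10.0
    # and the older `token` parameter on hfh<0.10.0.
    info = dataset_info(api, "username/my-private-dataset", use_auth_token=None)
    files = list_repo_files(api, "username/my-private-dataset", repo_type="dataset")
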
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index 7f2024d585f..e8e90617544 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -22,7 +22,9 @@
 from typing import List, Optional, Type, TypeVar, Union
 from urllib.parse import urljoin, urlparse
 
+import huggingface_hub
 import requests
+from huggingface_hub import HfFolder
 
 from .. import __version__, config
 from ..download.download_config import DownloadConfig
@@ -218,7 +220,9 @@ def cached_path(
 
 
 def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
-    ua = f"datasets/{__version__}; python/{config.PY_VERSION}"
+    ua = f"datasets/{__version__}"
+    ua += f"; python/{config.PY_VERSION}"
+    ua += f"; huggingface_hub/{huggingface_hub.__version__}"
     ua += f"; pyarrow/{config.PYARROW_VERSION}"
     if config.TORCH_AVAILABLE:
         ua += f"; torch/{config.TORCH_VERSION}"
@@ -239,13 +243,13 @@ def get_authentication_headers_for_url(url: str, use_auth_token: Optional[Union[
     """Handle the HF authentication"""
     headers = {}
     if url.startswith(config.HF_ENDPOINT):
-        token = None
-        if isinstance(use_auth_token, str):
+        if use_auth_token is False:
+            token = None
+        elif isinstance(use_auth_token, str):
             token = use_auth_token
-        elif bool(use_auth_token):
-            from huggingface_hub import hf_api
+        else:
+            token = HfFolder.get_token()
 
-            token = hf_api.HfFolder.get_token()
         if token:
             headers["authorization"] = f"Bearer {token}"
     return headers
diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py
index e9ccb92dfa5..faafe79dd54 100644
--- a/tests/test_filesystem.py
+++ b/tests/test_filesystem.py
@@ -12,6 +12,7 @@
     extract_path_from_uri,
     is_remote_filesystem,
 )
+from datasets.utils._hf_hub_fixes import dataset_info as hf_api_dataset_info
 
 from .utils import require_lz4, require_zstandard
 
@@ -93,7 +94,7 @@ def test_fs_isfile(protocol, zip_jsonl_path, jsonl_gz_path):
 
 @pytest.mark.integration
 def test_hf_filesystem(hf_token, hf_api, hf_private_dataset_repo_txt_data, text_file):
-    repo_info = hf_api.dataset_info(hf_private_dataset_repo_txt_data, token=hf_token)
+    repo_info = hf_api_dataset_info(hf_api, hf_private_dataset_repo_txt_data, use_auth_token=hf_token)
     hffs = HfFileSystem(repo_info=repo_info, token=hf_token)
     assert sorted(hffs.glob("*")) == [".gitattributes", "data"]
     assert hffs.isdir("data")
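
A short sketch of the implicit-auth lookup that the file_utils.py change above enables (not part of the patch; the URL is hypothetical, and a token is assumed to have been saved with `huggingface-cli login`):

    from datasets.utils.file_utils import get_authentication_headers_for_url

    url = "https://huggingface.co/datasets/username/private-repo/resolve/main/data.txt"
    # use_auth_token=None now falls through to HfFolder.get_token();
    # passing use_auth_token=False is how callers opt out of the stored token.
    headers = get_authentication_headers_for_url(url, use_auth_token=None)
    # -> {"authorization": "Bearer hf_..."} if a token is stored, {} otherwise
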
diff --git a/tests/test_load.py b/tests/test_load.py
index 65446896f36..f4315355920 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -756,18 +756,6 @@ def test_load_dataset_streaming_csv(path_extension, streaming, csv_path, bz2_csv
         assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
 
 
-@require_pil
-@pytest.mark.integration
-@pytest.mark.parametrize("streaming", [False, True])
-def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
-    ds = load_dataset(
-        hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=hf_token
-    )
-    assert isinstance(ds, IterableDataset if streaming else Dataset)
-    ds_items = list(ds)
-    assert len(ds_items) == 2
-
-
 @pytest.mark.parametrize("streaming", [False, True])
 @pytest.mark.parametrize("data_file", ["zip_csv_path", "zip_csv_with_dir_path", "csv_path"])
 def test_load_dataset_zip_csv(data_file, streaming, zip_csv_path, zip_csv_with_dir_path, csv_path):
@@ -876,20 +864,32 @@ def assert_auth(url, *args, headers, **kwargs):
 
 @pytest.mark.integration
 def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
-    with pytest.raises(FileNotFoundError):
-        load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
-    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True, use_auth_token=hf_token)
+    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
     assert next(iter(ds)) is not None
 
 
 @pytest.mark.integration
 def test_load_streaming_private_dataset_with_zipped_data(hf_token, hf_private_dataset_repo_zipped_txt_data):
-    with pytest.raises(FileNotFoundError):
-        load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
-    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True, use_auth_token=hf_token)
+    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
     assert next(iter(ds)) is not None
 
 
+@require_pil
+@pytest.mark.integration
+@pytest.mark.parametrize("implicit_token", [False, True])
+@pytest.mark.parametrize("streaming", [False, True])
+def test_load_dataset_private_zipped_images(
+    hf_private_dataset_repo_zipped_img_data, hf_token, streaming, implicit_token
+):
+    use_auth_token = None if implicit_token else hf_token
+    ds = load_dataset(
+        hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=use_auth_token
+    )
+    assert isinstance(ds, IterableDataset if streaming else Dataset)
+    ds_items = list(ds)
+    assert len(ds_items) == 2
+
+
 def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir, tmp_path, caplog):
     cache_dir1 = tmp_path / "cache1"
     cache_dir2 = tmp_path / "cache2"
diff --git a/tests/test_metric_common.py b/tests/test_metric_common.py
index 8fce96668f8..88edeb94932 100644
--- a/tests/test_metric_common.py
+++ b/tests/test_metric_common.py
@@ -38,6 +38,9 @@
 UNSUPPORTED_ON_WINDOWS = {"code_eval"}
 _on_windows = os.name == "nt"
 
+REQUIRE_TRANSFORMERS = {"bertscore", "frugalscore", "perplexity"}
+_has_transformers = importlib.util.find_spec("transformers") is not None
+
 
 def skip_if_metric_requires_fairseq(test_case):
     @wraps(test_case)
@@ -50,6 +53,17 @@ def wrapper(self, metric_name):
     return wrapper
 
 
+def skip_if_metric_requires_transformers(test_case):
+    @wraps(test_case)
+    def wrapper(self, metric_name):
+        if not _has_transformers and metric_name in REQUIRE_TRANSFORMERS:
+            self.skipTest('"test requires transformers"')
+        else:
+            test_case(self, metric_name)
+
+    return wrapper
+
+
 def skip_on_windows_if_not_windows_compatible(test_case):
     @wraps(test_case)
     def wrapper(self, metric_name):
@@ -67,7 +81,9 @@ def get_local_metric_names():
 
 
 @parameterized.named_parameters(get_local_metric_names())
-@for_all_test_methods(skip_if_metric_requires_fairseq, skip_on_windows_if_not_windows_compatible)
+@for_all_test_methods(
+    skip_if_metric_requires_fairseq, skip_if_metric_requires_transformers, skip_on_windows_if_not_windows_compatible
+)
 @local
 @pytest.mark.integration
 class LocalMetricTest(parameterized.TestCase):
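
The user-facing effect exercised by the updated tests, as a sketch (the repo name is hypothetical): after `huggingface-cli login`, private datasets load without an explicit use_auth_token:

    from datasets import load_dataset

    # The token stored by `huggingface-cli login` is picked up implicitly;
    # use_auth_token=False would disable the implicit lookup.
    ds = load_dataset("username/my-private-dataset", streaming=True)
    print(next(iter(ds)))
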