Support hfh 0.10 implicit auth #5031

Merged: 9 commits, Sep 30, 2022
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
@@ -38,7 +38,7 @@ jobs:
       matrix:
         test: ['unit', 'integration']
         os: [ubuntu-latest, windows-latest]
-        pyarrow_version: [latest, 6.0.1]
+        deps_versions: [latest, minimum]
     continue-on-error: ${{ matrix.test == 'integration' }}
     runs-on: ${{ matrix.os }}
     steps:
@@ -63,12 +63,12 @@ jobs:
       run: |
         pip install .[tests]
         pip install -r additional-tests-requirements.txt --no-deps
-    - name: Install latest PyArrow
-      if: ${{ matrix.pyarrow_version == 'latest' }}
-      run: pip install pyarrow --upgrade
-    - name: Install PyArrow ${{ matrix.pyarrow_version }}
-      if: ${{ matrix.pyarrow_version != 'latest' }}
-      run: pip install pyarrow==${{ matrix.pyarrow_version }}
+    - name: Install dependencies (latest versions)
+      if: ${{ matrix.deps_versions == 'latest' }}
+      run: pip install --upgrade pyarrow huggingface-hub
+    - name: Install dependencies (minimum versions)
+      if: ${{ matrix.deps_versions != 'latest' }}
+      run: pip install pyarrow==6.0.1 huggingface-hub==0.2.0 transformers
     - name: Test with pytest
       run: |
         python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
3 changes: 2 additions & 1 deletion setup.py
@@ -89,7 +89,8 @@
     # for data streaming via http
     "aiohttp",
     # To get datasets from the Datasets Hub on huggingface.co
-    "huggingface-hub>=0.1.0,<1.0.0",
+    # minimum 0.2.0 for set_access_token
+    "huggingface-hub>=0.2.0,<1.0.0",
     # Utilities from PyPA to e.g., compare versions
     "packaging",
     "responses<0.19",
3 changes: 2 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -102,6 +102,7 @@
 from .tasks import TaskTemplate
 from .utils import logging
 from .utils._hf_hub_fixes import create_repo
+from .utils._hf_hub_fixes import list_repo_files as hf_api_list_repo_files
 from .utils.file_utils import _retry, cached_path, estimate_dataset_size, hf_hub_url
 from .utils.info_utils import is_small_dataset
 from .utils.py_utils import asdict, convert_file_size_to_int, unique_values
@@ -4243,7 +4244,7 @@ def shards_with_embedded_external_files(shards):
 
         shards = shards_with_embedded_external_files(shards)
 
-        files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token)
+        files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
         data_files = [file for file in files if file.startswith("data/")]
 
         def path_in_repo(_index, shard):
21 changes: 8 additions & 13 deletions src/datasets/load.py
@@ -29,7 +29,7 @@
 
 import fsspec
 import requests
-from huggingface_hub import HfApi, HfFolder
+from huggingface_hub import HfApi
 
 from . import config
 from .arrow_dataset import Dataset
@@ -62,6 +62,7 @@
 )
 from .splits import Split
 from .tasks import TaskTemplate
+from .utils._hf_hub_fixes import dataset_info as hf_api_dataset_info
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import (
     OfflineModeIsEnabled,
@@ -744,14 +745,11 @@ def __init__(
         increase_load_count(name, resource_type="dataset")
 
     def get_module(self) -> DatasetModule:
-        if isinstance(self.download_config.use_auth_token, bool):
-            token = HfFolder.get_token() if self.download_config.use_auth_token else None
-        else:
-            token = self.download_config.use_auth_token
-        hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
+        hfh_dataset_info = hf_api_dataset_info(
+            HfApi(config.HF_ENDPOINT),
             self.name,
             revision=self.revision,
-            token=token if token else "no-token",
+            use_auth_token=self.download_config.use_auth_token,
             timeout=100.0,
         )
         patterns = (
@@ -1112,14 +1110,11 @@ def dataset_module_factory(
     _raise_if_offline_mode_is_enabled()
     hf_api = HfApi(config.HF_ENDPOINT)
Review comment (Member), on the line `hf_api = HfApi(config.HF_ENDPOINT)`:
I think you can remove this line because hf_api is no longer used in this method?

Reply (Member, Author):
It's still used in hf_api_dataset_info(hf_api, ...) a few lines later ;)
     try:
-        if isinstance(download_config.use_auth_token, bool):
-            token = HfFolder.get_token() if download_config.use_auth_token else None
-        else:
-            token = download_config.use_auth_token
-        dataset_info = hf_api.dataset_info(
+        dataset_info = hf_api_dataset_info(
+            hf_api,
             repo_id=path,
             revision=revision,
-            token=token if token else "no-token",
+            use_auth_token=download_config.use_auth_token,
             timeout=100.0,
         )
     except Exception as e:  # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
79 changes: 77 additions & 2 deletions src/datasets/utils/_hf_hub_fixes.py
@@ -1,7 +1,8 @@
from typing import Optional
from typing import List, Optional, Union

import huggingface_hub
from huggingface_hub import HfApi
from huggingface_hub import HfApi, HfFolder
from huggingface_hub.hf_api import DatasetInfo
from packaging import version


@@ -99,3 +100,77 @@ def delete_repo(
         token=token,
         repo_type=repo_type,
     )
+
+
+def dataset_info(
+    hf_api: HfApi,
+    repo_id: str,
+    *,
+    revision: Optional[str] = None,
+    timeout: Optional[float] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+) -> DatasetInfo:
+    """
+    The huggingface_hub.HfApi.dataset_info parameters changed in 0.10.0 and some of them were deprecated.
+    This function checks the installed huggingface_hub version and calls it with the right parameters.
+
+    Args:
+        hf_api (`huggingface_hub.HfApi`): Hub client
+        repo_id (`str`):
+            A namespace (user or an organization) and a repo name separated
+            by a `/`.
+        revision (`str`, *optional*):
+            The revision of the dataset repository from which to get the
+            information.
+        timeout (`float`, *optional*):
+            Timeout in seconds for the request to the Hub.
+        use_auth_token (`bool` or `str`, *optional*):
+            Whether to use the `auth_token` provided by the
+            `huggingface_hub` cli. If not logged in, a valid `auth_token`
+            can be passed in as a string.
+    Returns:
+        [`hf_api.DatasetInfo`]: The dataset repository information.
+    <Tip>
+    Raises the following errors:
+        - [`~utils.RepositoryNotFoundError`]
+          If the repository to download from cannot be found. This may be because it doesn't exist,
+          or because it is set to `private` and you do not have access.
+        - [`~utils.RevisionNotFoundError`]
+          If the revision to download from cannot be found.
+    </Tip>
+    """
+    if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
+        if use_auth_token is False:
+            token = "no-token"
+        elif isinstance(use_auth_token, str):
+            token = use_auth_token
+        else:
+            token = HfFolder.get_token() or "no-token"
+        return hf_api.dataset_info(
+            repo_id,
+            revision=revision,
+            token=token,
+            timeout=timeout,
+        )
+    else:  # the `token` parameter is deprecated in huggingface_hub>=0.10.0
+        return hf_api.dataset_info(repo_id, revision=revision, timeout=timeout, use_auth_token=use_auth_token)
+
+
+def list_repo_files(
+    hf_api: HfApi,
+    repo_id: str,
+    revision: Optional[str] = None,
+    repo_type: Optional[str] = None,
+    token: Optional[str] = None,
+    timeout: Optional[float] = None,
+) -> List[str]:
+    """
+    The huggingface_hub.HfApi.list_repo_files parameters changed in 0.10.0 and some of them were deprecated.
+    This function checks the installed huggingface_hub version and calls it with the right parameters.
+    """
+    if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
+        return hf_api.list_repo_files(repo_id, revision=revision, repo_type=repo_type, token=token, timeout=timeout)
+    else:  # the `token` parameter is deprecated in huggingface_hub>=0.10.0
+        return hf_api.list_repo_files(
+            repo_id, revision=revision, repo_type=repo_type, use_auth_token=token, timeout=timeout
+        )
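
To make the intent of these shims concrete, here is a hedged usage sketch; the repo id is hypothetical, and the calls assume the shim signatures shown in the diff above:

```python
# Usage sketch for the _hf_hub_fixes shims (hypothetical repo id).
# The same call works on huggingface_hub both below and at/after 0.10.0,
# because each shim dispatches on huggingface_hub.__version__ internally.
from huggingface_hub import HfApi

from datasets import config
from datasets.utils._hf_hub_fixes import dataset_info, list_repo_files

api = HfApi(config.HF_ENDPOINT)

# use_auth_token=None -> implicit auth (fall back to the locally stored token);
# use_auth_token=False -> anonymous; a str -> explicit token.
info = dataset_info(api, "username/some-private-dataset", use_auth_token=None, timeout=100.0)
files = list_repo_files(api, "username/some-private-dataset", repo_type="dataset")
print(info.id, files)
```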
16 changes: 10 additions & 6 deletions src/datasets/utils/file_utils.py
@@ -22,7 +22,9 @@
 from typing import List, Optional, Type, TypeVar, Union
 from urllib.parse import urljoin, urlparse
 
+import huggingface_hub
 import requests
+from huggingface_hub import HfFolder
 
 from .. import __version__, config
 from ..download.download_config import DownloadConfig
@@ -218,7 +220,9 @@ def cached_path(
 
 
 def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
-    ua = f"datasets/{__version__}; python/{config.PY_VERSION}"
+    ua = f"datasets/{__version__}"
+    ua += f"; python/{config.PY_VERSION}"
+    ua += f"; huggingface_hub/{huggingface_hub.__version__}"
Review comment (Contributor), on lines 222 to 225:
Would love to hear if you have ideas on how to make this easy to configure. We have the same logic in huggingface_hub and transformers as well, but nothing very convenient.
(Not a suggestion to change something here, more of a general question.)

Reply (Member, Author):
Since it's only a few lines and pretty easy to modify, I don't think we can make something significantly better than that.
ua += f"; pyarrow/{config.PYARROW_VERSION}"
if config.TORCH_AVAILABLE:
ua += f"; torch/{config.TORCH_VERSION}"
@@ -239,13 +243,13 @@ def get_authentication_headers_for_url(url: str, use_auth_token: Optional[Union[
     """Handle the HF authentication"""
     headers = {}
     if url.startswith(config.HF_ENDPOINT):
-        token = None
-        if isinstance(use_auth_token, str):
+        if use_auth_token is False:
+            token = None
+        elif isinstance(use_auth_token, str):
             token = use_auth_token
-        elif bool(use_auth_token):
-            from huggingface_hub import hf_api
-
-            token = hf_api.HfFolder.get_token()
+        else:
+            token = HfFolder.get_token()
         if token:
             headers["authorization"] = f"Bearer {token}"
     return headers
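The three-way `use_auth_token` semantics above are the heart of the implicit-auth change: `None` now means "use the stored token if any" rather than "no token". A small sketch of the resulting behavior, with a hypothetical URL and a placeholder token:

```python
# Sketch of the new token resolution in get_authentication_headers_for_url
# (hypothetical URL, placeholder token).
from datasets.utils.file_utils import get_authentication_headers_for_url

url = "https://huggingface.co/datasets/username/private-dataset/resolve/main/data.txt"

# None (the default): implicit auth -- falls back to HfFolder.get_token()
headers = get_authentication_headers_for_url(url, use_auth_token=None)

# False: explicitly anonymous -- no Authorization header is added
assert get_authentication_headers_for_url(url, use_auth_token=False) == {}

# str: the given token is used verbatim
headers = get_authentication_headers_for_url(url, use_auth_token="hf_xxx")
assert headers["authorization"] == "Bearer hf_xxx"
```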
3 changes: 2 additions & 1 deletion tests/test_filesystem.py
@@ -12,6 +12,7 @@
     extract_path_from_uri,
     is_remote_filesystem,
 )
+from datasets.utils._hf_hub_fixes import dataset_info as hf_api_dataset_info
 
 from .utils import require_lz4, require_zstandard
 
@@ -93,7 +94,7 @@ def test_fs_isfile(protocol, zip_jsonl_path, jsonl_gz_path):
 
 @pytest.mark.integration
 def test_hf_filesystem(hf_token, hf_api, hf_private_dataset_repo_txt_data, text_file):
-    repo_info = hf_api.dataset_info(hf_private_dataset_repo_txt_data, token=hf_token)
+    repo_info = hf_api_dataset_info(hf_api, hf_private_dataset_repo_txt_data, use_auth_token=hf_token)
     hffs = HfFileSystem(repo_info=repo_info, token=hf_token)
     assert sorted(hffs.glob("*")) == [".gitattributes", "data"]
     assert hffs.isdir("data")
36 changes: 18 additions & 18 deletions tests/test_load.py
@@ -756,18 +756,6 @@ def test_load_dataset_streaming_csv(path_extension, streaming, csv_path, bz2_csv
     assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}
 
 
-@require_pil
-@pytest.mark.integration
-@pytest.mark.parametrize("streaming", [False, True])
-def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
-    ds = load_dataset(
-        hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=hf_token
-    )
-    assert isinstance(ds, IterableDataset if streaming else Dataset)
-    ds_items = list(ds)
-    assert len(ds_items) == 2
-
-
 @pytest.mark.parametrize("streaming", [False, True])
 @pytest.mark.parametrize("data_file", ["zip_csv_path", "zip_csv_with_dir_path", "csv_path"])
 def test_load_dataset_zip_csv(data_file, streaming, zip_csv_path, zip_csv_with_dir_path, csv_path):
@@ -876,20 +864,32 @@ def assert_auth(url, *args, headers, **kwargs):
 
 @pytest.mark.integration
 def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
-    with pytest.raises(FileNotFoundError):
-        load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
-    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True, use_auth_token=hf_token)
+    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
     assert next(iter(ds)) is not None
 
 
 @pytest.mark.integration
 def test_load_streaming_private_dataset_with_zipped_data(hf_token, hf_private_dataset_repo_zipped_txt_data):
-    with pytest.raises(FileNotFoundError):
-        load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
-    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True, use_auth_token=hf_token)
+    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
     assert next(iter(ds)) is not None
 
 
+@require_pil
+@pytest.mark.integration
+@pytest.mark.parametrize("implicit_token", [False, True])
+@pytest.mark.parametrize("streaming", [False, True])
+def test_load_dataset_private_zipped_images(
+    hf_private_dataset_repo_zipped_img_data, hf_token, streaming, implicit_token
+):
+    use_auth_token = None if implicit_token else hf_token
+    ds = load_dataset(
+        hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=use_auth_token
+    )
+    assert isinstance(ds, IterableDataset if streaming else Dataset)
+    ds_items = list(ds)
+    assert len(ds_items) == 2
+
+
 def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir, tmp_path, caplog):
     cache_dir1 = tmp_path / "cache1"
     cache_dir2 = tmp_path / "cache2"
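From the user's perspective, the rewritten tests capture the point of this PR: once logged in via huggingface-cli, private datasets load without an explicit token. A hedged sketch with a hypothetical private repo:

```python
# End-user view of implicit auth (hypothetical private repo).
# Assumes `huggingface-cli login` has stored a token locally.
from datasets import load_dataset

# Before this PR, omitting use_auth_token raised FileNotFoundError for private repos.
ds = load_dataset("username/private-dataset", split="train", streaming=True)

# Passing a token explicitly still works, e.g. for CI environments.
ds = load_dataset("username/private-dataset", split="train", streaming=True, use_auth_token=True)
```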
18 changes: 17 additions & 1 deletion tests/test_metric_common.py
@@ -38,6 +38,9 @@
 UNSUPPORTED_ON_WINDOWS = {"code_eval"}
 _on_windows = os.name == "nt"
 
+REQUIRE_TRANSFORMERS = {"bertscore", "frugalscore", "perplexity"}
+_has_transformers = importlib.util.find_spec("transformers") is not None
+
 
 def skip_if_metric_requires_fairseq(test_case):
     @wraps(test_case)
@@ -50,6 +53,17 @@ def wrapper(self, metric_name):
     return wrapper
 
 
+def skip_if_metric_requires_transformers(test_case):
+    @wraps(test_case)
+    def wrapper(self, metric_name):
+        if not _has_transformers and metric_name in REQUIRE_TRANSFORMERS:
+            self.skipTest('"test requires transformers"')
+        else:
+            test_case(self, metric_name)
+
+    return wrapper
+
+
Review comment (Member), on lines 56 to 66:
Why have you added this?

Reply (Member, Author):
I had an env without transformers and failing tests x) so I did the same as for the other optional deps: the tests are excluded when it's not installed.
 def skip_on_windows_if_not_windows_compatible(test_case):
     @wraps(test_case)
     def wrapper(self, metric_name):
@@ -67,7 +81,9 @@ def get_local_metric_names():
 
 
 @parameterized.named_parameters(get_local_metric_names())
-@for_all_test_methods(skip_if_metric_requires_fairseq, skip_on_windows_if_not_windows_compatible)
+@for_all_test_methods(
+    skip_if_metric_requires_fairseq, skip_if_metric_requires_transformers, skip_on_windows_if_not_windows_compatible
+)
 @local
 @pytest.mark.integration
 class LocalMetricTest(parameterized.TestCase):
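
For readers unfamiliar with the harness, `for_all_test_methods` applies each given decorator to every test method of the class; a minimal sketch of such a helper (an assumption, not necessarily the repo's exact implementation) might look like this:

```python
# Minimal sketch of a for_all_test_methods-style class decorator (assumption:
# not the repo's exact implementation). Each passed decorator wraps every
# method whose name starts with "test", as skip_if_metric_requires_transformers
# above expects.
def for_all_test_methods(*decorators):
    def decorate(cls):
        for name, attr in list(vars(cls).items()):
            if callable(attr) and name.startswith("test"):
                for decorator in decorators:
                    attr = decorator(attr)
                setattr(cls, name, attr)
        return cls

    return decorate
```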
Expand Down