Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hfh 0.10 implicit auth #5031

Merged
merged 9 commits into from Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
Expand Up @@ -38,7 +38,7 @@ jobs:
matrix:
test: ['unit', 'integration']
os: [ubuntu-latest, windows-latest]
pyarrow_version: [latest, 6.0.1]
deps_versions: [latest, minimum]
continue-on-error: ${{ matrix.test == 'integration' }}
runs-on: ${{ matrix.os }}
steps:
Expand All @@ -63,12 +63,12 @@ jobs:
run: |
pip install .[tests]
pip install -r additional-tests-requirements.txt --no-deps
- name: Install latest PyArrow
if: ${{ matrix.pyarrow_version == 'latest' }}
run: pip install pyarrow --upgrade
- name: Install PyArrow ${{ matrix.pyarrow_version }}
if: ${{ matrix.pyarrow_version != 'latest' }}
run: pip install pyarrow==${{ matrix.pyarrow_version }}
- name: Install dependencies (latest versions)
if: ${{ matrix.deps_versions == 'latest' }}
run: pip install --upgrade pyarrow huggingface-hub
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'latest' }}
run: pip install pyarrow==6.0.1 huggingface-hub==0.2.0
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Expand Up @@ -89,7 +89,8 @@
# for data streaming via http
"aiohttp",
# To get datasets from the Datasets Hub on huggingface.co
"huggingface-hub>=0.1.0,<1.0.0",
# minimum 0.2.0 for set_access_token
"huggingface-hub>=0.2.0,<1.0.0",
# Utilities from PyPA to e.g., compare versions
"packaging",
"responses<0.19",
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/arrow_dataset.py
Expand Up @@ -102,6 +102,7 @@
from .tasks import TaskTemplate
from .utils import logging
from .utils._hf_hub_fixes import create_repo
from .utils._hf_hub_fixes import list_repo_files as hf_api_list_repo_files
from .utils.file_utils import _retry, cached_path, estimate_dataset_size, hf_hub_url
from .utils.info_utils import is_small_dataset
from .utils.py_utils import asdict, convert_file_size_to_int, unique_values
Expand Down Expand Up @@ -4243,7 +4244,7 @@ def shards_with_embedded_external_files(shards):

shards = shards_with_embedded_external_files(shards)

files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token)
files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
data_files = [file for file in files if file.startswith("data/")]

def path_in_repo(_index, shard):
Expand Down
21 changes: 8 additions & 13 deletions src/datasets/load.py
Expand Up @@ -29,7 +29,7 @@

import fsspec
import requests
from huggingface_hub import HfApi, HfFolder
from huggingface_hub import HfApi

from . import config
from .arrow_dataset import Dataset
Expand Down Expand Up @@ -62,6 +62,7 @@
)
from .splits import Split
from .tasks import TaskTemplate
from .utils._hf_hub_fixes import dataset_info as hf_api_dataset_info
from .utils.deprecation_utils import deprecated
from .utils.file_utils import (
OfflineModeIsEnabled,
Expand Down Expand Up @@ -744,14 +745,11 @@ def __init__(
increase_load_count(name, resource_type="dataset")

def get_module(self) -> DatasetModule:
if isinstance(self.download_config.use_auth_token, bool):
token = HfFolder.get_token() if self.download_config.use_auth_token else None
else:
token = self.download_config.use_auth_token
hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
hfh_dataset_info = hf_api_dataset_info(
HfApi(config.HF_ENDPOINT),
self.name,
revision=self.revision,
token=token if token else "no-token",
use_auth_token=self.download_config.use_auth_token,
timeout=100.0,
)
patterns = (
Expand Down Expand Up @@ -1112,14 +1110,11 @@ def dataset_module_factory(
_raise_if_offline_mode_is_enabled()
hf_api = HfApi(config.HF_ENDPOINT)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can remove this line because hf_api is no longer used in this method?

hf_api = HfApi(config.HF_ENDPOINT)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still used in hf_api_dataset_info(hf_api, ...) a few lines later ;)

try:
if isinstance(download_config.use_auth_token, bool):
token = HfFolder.get_token() if download_config.use_auth_token else None
else:
token = download_config.use_auth_token
dataset_info = hf_api.dataset_info(
dataset_info = hf_api_dataset_info(
hf_api,
repo_id=path,
revision=revision,
token=token if token else "no-token",
use_auth_token=download_config.use_auth_token,
timeout=100.0,
)
except Exception as e: # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
Expand Down
77 changes: 75 additions & 2 deletions src/datasets/utils/_hf_hub_fixes.py
@@ -1,7 +1,8 @@
from typing import Optional
from typing import List, Optional, Union

import huggingface_hub
from huggingface_hub import HfApi
from huggingface_hub import HfApi, HfFolder
from huggingface_hub.hf_api import DatasetInfo
from packaging import version


Expand Down Expand Up @@ -99,3 +100,75 @@ def delete_repo(
token=token,
repo_type=repo_type,
)


def dataset_info(
    hf_api: HfApi,
    repo_id: str,
    *,
    revision: Optional[str] = None,
    timeout: Optional[float] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
) -> DatasetInfo:
    """
    Get info on one specific dataset on huggingface.co.

    Dataset can be private if you pass an acceptable token.

    Compatibility wrapper around `huggingface_hub.HfApi.dataset_info` (the docstring below is
    adapted from it, see
    https://huggingface.co/docs/huggingface_hub/v0.10.0/en/package_reference/hf_api#huggingface_hub.HfApi.dataset_info):
    `huggingface_hub>=0.10.0` deprecates the `token` parameter in favor of `use_auth_token`,
    while older versions only accept `token`. This helper translates between the two.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        revision (`str`, *optional*):
            The revision of the dataset repository from which to get the
            information.
        timeout (`float`, *optional*):
            Whether to set a timeout for the request to the Hub.
        use_auth_token (`bool` or `str`, *optional*):
            Whether to use the `auth_token` provided from the
            `huggingface_hub` cli. If not logged in, a valid `auth_token`
            can be passed in as a string.

    Returns:
        [`hf_api.DatasetInfo`]: The dataset repository information.

    <Tip>
    Raises the following errors:
        - [`~utils.RepositoryNotFoundError`]
          If the repository to download from cannot be found. This may be because it doesn't exist,
          or because it is set to `private` and you do not have access.
        - [`~utils.RevisionNotFoundError`]
          If the revision to download from cannot be found.
    </Tip>
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
        if use_auth_token is False:
            token = "no-token"
        elif isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            # implicit auth: fall back to the token stored by `huggingface-cli login`, if any
            token = HfFolder.get_token() or "no-token"
        return hf_api.dataset_info(
            repo_id,
            revision=revision,
            token=token,
            timeout=timeout,
        )
    else:  # the `token` parameter is deprecated in huggingface_hub>=0.10.0
        return hf_api.dataset_info(repo_id, revision=revision, timeout=timeout, use_auth_token=use_auth_token)


def list_repo_files(
    hf_api: HfApi,
    repo_id: str,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
    token: Optional[str] = None,
    timeout: Optional[float] = None,
) -> List[str]:
    """
    Get the list of files in a given repo.

    Compatibility wrapper around `huggingface_hub.HfApi.list_repo_files`:
    `huggingface_hub>=0.10.0` deprecates the `token` parameter in favor of
    `use_auth_token`, so the token is forwarded under whichever name the
    installed version expects.
    """
    kwargs = {"revision": revision, "repo_type": repo_type, "timeout": timeout}
    if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
        kwargs["token"] = token
    else:  # the `token` parameter is deprecated in huggingface_hub>=0.10.0
        kwargs["use_auth_token"] = token
    return hf_api.list_repo_files(repo_id, **kwargs)
16 changes: 10 additions & 6 deletions src/datasets/utils/file_utils.py
Expand Up @@ -22,7 +22,9 @@
from typing import List, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse

import huggingface_hub
import requests
from huggingface_hub import HfFolder

from .. import __version__, config
from ..download.download_config import DownloadConfig
Expand Down Expand Up @@ -218,7 +220,9 @@ def cached_path(


def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
ua = f"datasets/{__version__}; python/{config.PY_VERSION}"
ua = f"datasets/{__version__}"
ua += f"; python/{config.PY_VERSION}"
ua += f"; huggingface_hub/{huggingface_hub.__version__}"
Comment on lines 222 to +225
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would love to hear if you have ideas on how to make this easy to configure. We have the same logic in huggingface_hub and transformers as well but nothing very convenient.

(not a suggestion to change something here, more as a general question)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's only a few lines and pretty easy to modify it, I don't think we can make something significantly better than that

ua += f"; pyarrow/{config.PYARROW_VERSION}"
if config.TORCH_AVAILABLE:
ua += f"; torch/{config.TORCH_VERSION}"
Expand All @@ -239,13 +243,13 @@ def get_authentication_headers_for_url(url: str, use_auth_token: Optional[Union[
"""Handle the HF authentication"""
headers = {}
if url.startswith(config.HF_ENDPOINT):
token = None
if isinstance(use_auth_token, str):
if use_auth_token is False:
token = None
elif isinstance(use_auth_token, str):
token = use_auth_token
elif bool(use_auth_token):
from huggingface_hub import hf_api
else:
token = HfFolder.get_token()

token = hf_api.HfFolder.get_token()
if token:
headers["authorization"] = f"Bearer {token}"
return headers
Expand Down
3 changes: 2 additions & 1 deletion tests/test_filesystem.py
Expand Up @@ -12,6 +12,7 @@
extract_path_from_uri,
is_remote_filesystem,
)
from datasets.utils._hf_hub_fixes import dataset_info as hf_api_dataset_info

from .utils import require_lz4, require_zstandard

Expand Down Expand Up @@ -93,7 +94,7 @@ def test_fs_isfile(protocol, zip_jsonl_path, jsonl_gz_path):

@pytest.mark.integration
def test_hf_filesystem(hf_token, hf_api, hf_private_dataset_repo_txt_data, text_file):
repo_info = hf_api.dataset_info(hf_private_dataset_repo_txt_data, token=hf_token)
repo_info = hf_api_dataset_info(hf_api, hf_private_dataset_repo_txt_data, use_auth_token=hf_token)
hffs = HfFileSystem(repo_info=repo_info, token=hf_token)
assert sorted(hffs.glob("*")) == [".gitattributes", "data"]
assert hffs.isdir("data")
Expand Down
28 changes: 16 additions & 12 deletions tests/test_load.py
Expand Up @@ -756,18 +756,6 @@ def test_load_dataset_streaming_csv(path_extension, streaming, csv_path, bz2_csv
assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}


@require_pil
@pytest.mark.integration
@pytest.mark.parametrize("streaming", [False, True])
def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
ds = load_dataset(
hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=hf_token
)
assert isinstance(ds, IterableDataset if streaming else Dataset)
ds_items = list(ds)
assert len(ds_items) == 2


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("data_file", ["zip_csv_path", "zip_csv_with_dir_path", "csv_path"])
def test_load_dataset_zip_csv(data_file, streaming, zip_csv_path, zip_csv_with_dir_path, csv_path):
Expand Down Expand Up @@ -874,6 +862,22 @@ def assert_auth(url, *args, headers, **kwargs):
mock_head.assert_called()


@require_pil
@pytest.mark.integration
@pytest.mark.parametrize("implicit_token", [False, True])
@pytest.mark.parametrize("streaming", [False, True])
def test_load_dataset_private_zipped_images(
    hf_private_dataset_repo_zipped_img_data, hf_token, streaming, implicit_token
):
    # With `implicit_token`, rely on the locally stored CLI token instead of passing one explicitly.
    if implicit_token:
        use_auth_token = None
    else:
        use_auth_token = hf_token
    ds = load_dataset(
        hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=use_auth_token
    )
    expected_type = IterableDataset if streaming else Dataset
    assert isinstance(ds, expected_type)
    assert len(list(ds)) == 2


@pytest.mark.integration
def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
with pytest.raises(FileNotFoundError):
Expand Down