-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
_hf_hub_fixes.py
174 lines (161 loc) 路 6.51 KB
/
_hf_hub_fixes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import List, Optional, Union
import huggingface_hub
from huggingface_hub import HfApi, HfFolder
from huggingface_hub.hf_api import DatasetInfo
from packaging import version
def create_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    private: Optional[bool] = None,
    repo_type: Optional[str] = None,
    exist_ok: Optional[bool] = False,
    space_sdk: Optional[str] = None,
) -> str:
    """
    Create a repository on the Hub, whatever the installed huggingface_hub version.

    The huggingface_hub.HfApi.create_repo parameters changed in 0.5.0 and some of them were deprecated.
    This function checks the huggingface_hub version to call the right parameters.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization name.
            By default it uses the namespace associated to the token used.
        private (`bool`, *optional*):
            Whether the model repo should be private.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if uploading to a dataset or
            space, `None` or `"model"` if uploading to a model. Default is
            `None`.
        exist_ok (`bool`, *optional*, defaults to `False`):
            If `True`, do not raise an error if repo already exists.
        space_sdk (`str`, *optional*):
            Choice of SDK to use if repo_type is "space". Can be
            "streamlit", "gradio", or "static".

    Returns:
        `str`: URL to the newly created repo.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.create_repo(
            name=name,
            organization=organization,
            token=token,
            private=private,
            repo_type=repo_type,
            exist_ok=exist_ok,
            space_sdk=space_sdk,
        )

    # The `organization` parameter is deprecated in huggingface_hub>=0.5.0: the
    # namespace is now part of `repo_id`. Only prefix it when one was actually
    # given; otherwise the id would be the literal string "None/<name>" instead
    # of falling back to the namespace associated with the token.
    repo_id = f"{organization}/{name}" if organization else name
    return hf_api.create_repo(
        repo_id=repo_id,
        token=token,
        private=private,
        repo_type=repo_type,
        exist_ok=exist_ok,
        space_sdk=space_sdk,
    )
def delete_repo(
    hf_api: HfApi,
    name: str,
    token: Optional[str] = None,
    organization: Optional[str] = None,
    repo_type: Optional[str] = None,
) -> str:
    """
    Delete a repository from the Hub, whatever the installed huggingface_hub version.

    The huggingface_hub.HfApi.delete_repo parameters changed in 0.5.0 and some of them were deprecated.
    This function checks the huggingface_hub version to call the right parameters.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        name (`str`): name of the repository (without the namespace)
        token (`str`, *optional*): user or organization token. Defaults to None.
        organization (`str`, *optional*): namespace for the repository: the username or organization name.
            By default it uses the namespace associated to the token used.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if the repository is a dataset or
            a space, `None` or `"model"` if it is a model. Default is
            `None`.

    Returns:
        The value returned by `hf_api.delete_repo`, forwarded unchanged.
        NOTE(review): the `str` return annotation is kept for backward
        compatibility; confirm what `HfApi.delete_repo` actually returns.
    """
    if version.parse(huggingface_hub.__version__) < version.parse("0.5.0"):
        return hf_api.delete_repo(
            name=name,
            organization=organization,
            token=token,
            repo_type=repo_type,
        )

    # The `organization` parameter is deprecated in huggingface_hub>=0.5.0: the
    # namespace is now part of `repo_id`. Only prefix it when one was actually
    # given; otherwise the id would be the literal string "None/<name>" instead
    # of falling back to the namespace associated with the token.
    repo_id = f"{organization}/{name}" if organization else name
    return hf_api.delete_repo(
        repo_id=repo_id,
        token=token,
        repo_type=repo_type,
    )
def dataset_info(
    hf_api: HfApi,
    repo_id: str,
    *,
    revision: Optional[str] = None,
    timeout: Optional[float] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
) -> DatasetInfo:
    """
    Fetch the metadata of a single dataset repository on huggingface.co.

    Private datasets are reachable when an acceptable token is supplied.
    The call is adapted to the installed huggingface_hub version: releases
    older than 0.10.0 take a `token` argument, newer ones take
    `use_auth_token`.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        revision (`str`, *optional*):
            The revision of the dataset repository from which to get the
            information.
        timeout (`float`, *optional*):
            Timeout forwarded to the request made to the Hub.
        use_auth_token (`bool` or `str`, *optional*):
            Whether to use the `auth_token` provided from the
            `huggingface_hub` cli. If not logged in, a valid `auth_token`
            can be passed in as a string.

    Returns:
        [`hf_api.DatasetInfo`]: The dataset repository information.

    <Tip>

    Raises the following errors:

        - [`~utils.RepositoryNotFoundError`]
          If the repository to download from cannot be found. This may be because it doesn't exist,
          or because it is set to `private` and you do not have access.
        - [`~utils.RevisionNotFoundError`]
          If the revision to download from cannot be found.

    </Tip>
    """
    installed = version.parse(huggingface_hub.__version__)

    # huggingface_hub>=0.10.0 deprecates `token` and accepts `use_auth_token` directly.
    if not installed < version.parse("0.10.0"):
        return hf_api.dataset_info(repo_id, revision=revision, timeout=timeout, use_auth_token=use_auth_token)

    # Older releases: translate `use_auth_token` into an explicit `token` value.
    if use_auth_token is False:
        # Authentication explicitly disabled by the caller.
        resolved_token = "no-token"
    else:
        # A string is used as-is; anything else falls back to the locally
        # stored token, or to the "no-token" placeholder when none is stored.
        resolved_token = use_auth_token if isinstance(use_auth_token, str) else (HfFolder.get_token() or "no-token")

    return hf_api.dataset_info(
        repo_id,
        revision=revision,
        token=resolved_token,
        timeout=timeout,
    )
def list_repo_files(
    hf_api: HfApi,
    repo_id: str,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
    token: Optional[str] = None,
    timeout: Optional[float] = None,
) -> List[str]:
    """
    Return the list of files of a given repo.

    The authentication value is passed as `token` on huggingface_hub<0.10.0
    and as `use_auth_token` on later releases, where `token` is deprecated.

    Args:
        hf_api (`huggingface_hub.HfApi`): Hub client
        repo_id (`str`): identifier of the repository, forwarded unchanged.
        revision (`str`, *optional*): revision forwarded to the client.
        repo_type (`str`, *optional*): repository type forwarded to the client.
        token (`str`, *optional*): authentication token forwarded to the client.
        timeout (`float`, *optional*): timeout forwarded to the client.

    Returns:
        `List[str]`: the value returned by `hf_api.list_repo_files`.
    """
    is_legacy_hub = version.parse(huggingface_hub.__version__) < version.parse("0.10.0")
    # Only the keyword carrying the token differs between the two API generations.
    auth_keyword = "token" if is_legacy_hub else "use_auth_token"
    return hf_api.list_repo_files(
        repo_id,
        revision=revision,
        repo_type=repo_type,
        timeout=timeout,
        **{auth_keyword: token},
    )