Skip to content

Commit

Permalink
Re-add support for single url files in objects download (#19014)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Sep 13, 2022
1 parent ad5045e commit f89f16a
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 2 deletions.
7 changes: 6 additions & 1 deletion src/transformers/configuration_utils.py
Expand Up @@ -32,7 +32,9 @@
PushToHubMixin,
cached_file,
copy_func,
download_url,
extract_commit_hash,
is_remote_url,
is_torch_available,
logging,
)
Expand Down Expand Up @@ -592,9 +594,12 @@ def _get_config_dict(

is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Soecial case when pretrained_model_name_or_path is a local file
# Special case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
configuration_file = pretrained_model_name_or_path
resolved_config_file = download_url(pretrained_model_name_or_path)
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)

Expand Down
5 changes: 5 additions & 0 deletions src/transformers/feature_extraction_utils.py
Expand Up @@ -31,8 +31,10 @@
TensorType,
cached_file,
copy_func,
download_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_torch_available,
logging,
Expand Down Expand Up @@ -386,6 +388,9 @@ def get_feature_extractor_dict(
if os.path.isfile(pretrained_model_name_or_path):
resolved_feature_extractor_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
feature_extractor_file = pretrained_model_name_or_path
resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
else:
feature_extractor_file = FEATURE_EXTRACTOR_NAME
try:
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/modeling_flax_utils.py
Expand Up @@ -47,8 +47,10 @@
add_start_docstrings_to_model_forward,
cached_file,
copy_func,
download_url,
has_file,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
Expand Down Expand Up @@ -677,6 +679,9 @@ def from_pretrained(
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
try:
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/modeling_tf_utils.py
Expand Up @@ -54,9 +54,11 @@
ModelOutput,
PushToHubMixin,
cached_file,
download_url,
find_labels,
has_file,
is_offline_mode,
is_remote_url,
logging,
requires_backends,
working_or_temp_dir,
Expand Down Expand Up @@ -2345,6 +2347,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/modeling_utils.py
Expand Up @@ -59,10 +59,12 @@
PushToHubMixin,
cached_file,
copy_func,
download_url,
has_file,
is_accelerate_available,
is_bitsandbytes_available,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
Expand Down Expand Up @@ -1998,6 +2000,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename
if from_tf:
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/tokenization_utils_base.py
Expand Up @@ -42,9 +42,11 @@
add_end_docstrings,
cached_file,
copy_func,
download_url,
extract_commit_hash,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
Expand Down Expand Up @@ -1680,6 +1682,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
FutureWarning,
)
file_id = list(cls.vocab_files_names.keys())[0]

vocab_files[file_id] = pretrained_model_name_or_path
else:
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
Expand Down Expand Up @@ -1723,6 +1726,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
elif is_remote_url(file_path):
resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies)
else:
resolved_vocab_files[file_id] = cached_file(
pretrained_model_name_or_path,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/utils/__init__.py
Expand Up @@ -63,13 +63,15 @@
cached_file,
default_cache_path,
define_sagemaker_information,
download_url,
extract_commit_hash,
get_cached_models,
get_file_from_repo,
get_full_repo_name,
has_file,
http_user_agent,
is_offline_mode,
is_remote_url,
move_cache,
send_example_telemetry,
)
Expand Down
35 changes: 34 additions & 1 deletion src/transformers/utils/hub.py
Expand Up @@ -19,10 +19,12 @@
import re
import shutil
import sys
import tempfile
import traceback
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4

import huggingface_hub
Expand All @@ -37,7 +39,7 @@
whoami,
)
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import (
EntryNotFoundError,
LocalEntryNotFoundError,
Expand Down Expand Up @@ -124,6 +126,11 @@ def is_offline_mode():
_CACHED_NO_EXIST = object()


def is_remote_url(url_or_filename):
    """
    Tell whether `url_or_filename` is an HTTP(S) url rather than a local path or identifier.

    Args:
        url_or_filename (`str`): The string to inspect.

    Returns:
        `bool`: `True` when the parsed scheme is `http` or `https`, `False` otherwise.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"


def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
Expand Down Expand Up @@ -541,6 +548,32 @@ def get_file_from_repo(
)


def download_url(url, proxies=None):
    """
    Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
    for deprecated behavior allowing to download config/models with a single url instead of using the Hub.

    Args:
        url (`str`): The url of the file to download.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.

    Returns:
        `str`: The location of the temporary file where the url was downloaded. The file is not deleted
        automatically; the caller owns it.
    """
    warnings.warn(
        f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
        " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
        " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
        " multiple processes (each process will download the file in a different temporary file)."
    )
    # `tempfile.mktemp()` only picks a name without creating the file, which leaves a window in which another
    # process can create (or symlink) that path. `NamedTemporaryFile` creates and opens the file atomically;
    # `delete=False` keeps it on disk after closing so the returned path stays usable by the caller.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
        tmp_file = f.name
        http_get(url, f, proxies=proxies)
    return tmp_file


def has_file(
path_or_repo: Union[str, os.PathLike],
filename: str,
Expand Down

0 comments on commit f89f16a

Please sign in to comment.