From ac61fddfcbc422fe04435aec69b8914b88260e42 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 30 Mar 2022 10:49:31 +0200 Subject: [PATCH] Load GitHub datasets from Hub --- src/datasets/load.py | 174 ++++++++++--------------------------------- 1 file changed, 39 insertions(+), 135 deletions(-) diff --git a/src/datasets/load.py b/src/datasets/load.py index 445c585aa89..ed06e9cf82b 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -482,91 +482,6 @@ def get_module(self) -> MetricModule: raise NotImplementedError -class GithubDatasetModuleFactory(_DatasetModuleFactory): - """ - Get the module of a dataset from GitHub (legacy). - The dataset script is downloaded from GitHub. - This class will eventually be removed and a HubDatasetModuleFactory will be used instead. - """ - - def __init__( - self, - name: str, - revision: Optional[Union[str, Version]] = None, - download_config: Optional[DownloadConfig] = None, - download_mode: Optional[DownloadMode] = None, - dynamic_modules_path: Optional[str] = None, - ): - self.name = name - self.revision = revision - self.download_config = download_config or DownloadConfig() - self.download_mode = download_mode - self.dynamic_modules_path = dynamic_modules_path - assert self.name.count("/") == 0 - increase_load_count(name, resource_type="dataset") - - def download_loading_script(self, revision: Optional[str]) -> str: - file_path = hf_github_url(path=self.name, name=self.name + ".py", revision=revision) - download_config = self.download_config.copy() - if download_config.download_desc is None: - download_config.download_desc = "Downloading builder script" - return cached_path(file_path, download_config=download_config) - - def download_dataset_infos_file(self, revision: Optional[str]) -> str: - dataset_infos = hf_github_url(path=self.name, name=config.DATASETDICT_INFOS_FILENAME, revision=revision) - # Download the dataset infos file if available - download_config = self.download_config.copy() - if download_config.download_desc is None: - download_config.download_desc = "Downloading metadata" - try: - return cached_path( - dataset_infos, - download_config=download_config, - ) - except (FileNotFoundError, ConnectionError): - return None - - def get_module(self) -> DatasetModule: - # get script and other files - revision = self.revision - try: - local_path = self.download_loading_script(revision) - except FileNotFoundError: - if revision is not None or os.getenv("HF_SCRIPTS_VERSION", None) is not None: - raise - else: - revision = "master" - local_path = self.download_loading_script(revision) - logger.warning( - f"Couldn't find a directory or a dataset named '{self.name}' in this version. " - f"It was picked from the master branch on github instead." - ) - dataset_infos_path = self.download_dataset_infos_file(revision) - imports = get_imports(local_path) - local_imports = _download_additional_modules( - name=self.name, - base_path=hf_github_url(path=self.name, name="", revision=revision), - imports=imports, - download_config=self.download_config, - ) - additional_files = [(config.DATASETDICT_INFOS_FILENAME, dataset_infos_path)] if dataset_infos_path else [] - # copy the script and the files in an importable directory - dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules() - module_path, hash = _create_importable_file( - local_path=local_path, - local_imports=local_imports, - additional_files=additional_files, - dynamic_modules_path=dynamic_modules_path, - module_namespace="datasets", - name=self.name, - download_mode=self.download_mode, - ) - # make the new module to be noticed by the import system - importlib.invalidate_caches() - builder_kwargs = {"hash": hash, "base_path": hf_hub_url(self.name, "", revision=self.revision)} - return DatasetModule(module_path, hash, builder_kwargs) - - class GithubMetricModuleFactory(_MetricModuleFactory): """Get the module of a metric. The metric script is downloaded from GitHub.""" @@ -895,11 +810,10 @@ def __init__( self.download_config = download_config or DownloadConfig() self.download_mode = download_mode self.dynamic_modules_path = dynamic_modules_path - assert self.name.count("/") == 1 increase_load_count(name, resource_type="dataset") def download_loading_script(self) -> str: - file_path = hf_hub_url(path=self.name, name=self.name.split("/")[1] + ".py", revision=self.revision) + file_path = hf_hub_url(path=self.name, name=self.name.split("/")[-1] + ".py", revision=self.revision) download_config = self.download_config.copy() if download_config.download_desc is None: download_config.download_desc = "Downloading builder script" @@ -1170,60 +1084,50 @@ def dataset_module_factory( elif is_relative_path(path) and path.count("/") <= 1 and not force_local_path: try: _raise_if_offline_mode_is_enabled() - if path.count("/") == 0: # even though the dataset is on the Hub, we get it from GitHub for now - # TODO(QL): use a Hub dataset module factory instead of GitHub - return GithubDatasetModuleFactory( + hf_api = HfApi(config.HF_ENDPOINT) + try: + if isinstance(download_config.use_auth_token, bool): + token = HfFolder.get_token() if download_config.use_auth_token else None + else: + token = download_config.use_auth_token + dataset_info = hf_api.dataset_info( + repo_id=path, + revision=revision, + token=token, + timeout=100.0, + ) + except Exception as e: # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist + if isinstance( + e, + ( + OfflineModeIsEnabled, + requests.exceptions.ConnectTimeout, + requests.exceptions.ConnectionError, + ), + ): + raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({type(e).__name__})") + elif "404" in str(e): + msg = f"Dataset '{path}' doesn't exist on the Hub" + raise FileNotFoundError(msg + f" at revision '{revision}'" if revision else msg) + else: + raise e + if filename in [sibling.rfilename for sibling in dataset_info.siblings]: + return HubDatasetModuleFactoryWithScript( path, revision=revision, download_config=download_config, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path, ).get_module() - elif path.count("/") == 1: # community dataset on the Hub - hf_api = HfApi(config.HF_ENDPOINT) - try: - if isinstance(download_config.use_auth_token, bool): - token = HfFolder.get_token() if download_config.use_auth_token else None - else: - token = download_config.use_auth_token - dataset_info = hf_api.dataset_info( - repo_id=path, - revision=revision, - token=token, - timeout=100.0, - ) - except Exception as e: # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist - if isinstance( - e, - ( - OfflineModeIsEnabled, - requests.exceptions.ConnectTimeout, - requests.exceptions.ConnectionError, - ), - ): - raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({type(e).__name__})") - elif "404" in str(e): - msg = f"Dataset '{path}' doesn't exist on the Hub" - raise FileNotFoundError(msg + f" at revision '{revision}'" if revision else msg) - else: - raise e - if filename in [sibling.rfilename for sibling in dataset_info.siblings]: - return HubDatasetModuleFactoryWithScript( - path, - revision=revision, - download_config=download_config, - download_mode=download_mode, - dynamic_modules_path=dynamic_modules_path, - ).get_module() - else: - return HubDatasetModuleFactoryWithoutScript( - path, - revision=revision, - data_dir=data_dir, - data_files=data_files, - download_config=download_config, - download_mode=download_mode, - ).get_module() + else: + return HubDatasetModuleFactoryWithoutScript( + path, + revision=revision, + data_dir=data_dir, + data_files=data_files, + download_config=download_config, + download_mode=download_mode, + ).get_module() except Exception as e1: # noqa: all the attempts failed, before raising the error we should check if the module is already cached. try: return CachedDatasetModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()