Load GitHub datasets from Hub
albertvillanova committed Mar 30, 2022
1 parent fae366b commit ac61fdd
Showing 1 changed file with 39 additions and 135 deletions.
174 changes: 39 additions & 135 deletions src/datasets/load.py
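In user-facing terms, the change makes bare dataset names (no namespace) resolve through the Hugging Face Hub instead of the huggingface/datasets GitHub repository. A minimal sketch of the affected call path, assuming the public `load_dataset` API of this era; the dataset name "squad" is only an illustration:

```python
from datasets import load_dataset

# After this commit, a bare name (no "user/" prefix) is resolved against
# the Hub (hf.co/datasets/<name>) rather than fetched from GitHub.
ds = load_dataset("squad", split="train")
print(ds[0]["question"])
```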
@@ -482,91 +482,6 @@ def get_module(self) -> MetricModule:
        raise NotImplementedError


-class GithubDatasetModuleFactory(_DatasetModuleFactory):
-    """
-    Get the module of a dataset from GitHub (legacy).
-    The dataset script is downloaded from GitHub.
-    This class will eventually be removed and a HubDatasetModuleFactory will be used instead.
-    """
-
-    def __init__(
-        self,
-        name: str,
-        revision: Optional[Union[str, Version]] = None,
-        download_config: Optional[DownloadConfig] = None,
-        download_mode: Optional[DownloadMode] = None,
-        dynamic_modules_path: Optional[str] = None,
-    ):
-        self.name = name
-        self.revision = revision
-        self.download_config = download_config or DownloadConfig()
-        self.download_mode = download_mode
-        self.dynamic_modules_path = dynamic_modules_path
-        assert self.name.count("/") == 0
-        increase_load_count(name, resource_type="dataset")
-
-    def download_loading_script(self, revision: Optional[str]) -> str:
-        file_path = hf_github_url(path=self.name, name=self.name + ".py", revision=revision)
-        download_config = self.download_config.copy()
-        if download_config.download_desc is None:
-            download_config.download_desc = "Downloading builder script"
-        return cached_path(file_path, download_config=download_config)
-
-    def download_dataset_infos_file(self, revision: Optional[str]) -> str:
-        dataset_infos = hf_github_url(path=self.name, name=config.DATASETDICT_INFOS_FILENAME, revision=revision)
-        # Download the dataset infos file if available
-        download_config = self.download_config.copy()
-        if download_config.download_desc is None:
-            download_config.download_desc = "Downloading metadata"
-        try:
-            return cached_path(
-                dataset_infos,
-                download_config=download_config,
-            )
-        except (FileNotFoundError, ConnectionError):
-            return None
-
-    def get_module(self) -> DatasetModule:
-        # get script and other files
-        revision = self.revision
-        try:
-            local_path = self.download_loading_script(revision)
-        except FileNotFoundError:
-            if revision is not None or os.getenv("HF_SCRIPTS_VERSION", None) is not None:
-                raise
-            else:
-                revision = "master"
-                local_path = self.download_loading_script(revision)
-                logger.warning(
-                    f"Couldn't find a directory or a dataset named '{self.name}' in this version. "
-                    f"It was picked from the master branch on github instead."
-                )
-        dataset_infos_path = self.download_dataset_infos_file(revision)
-        imports = get_imports(local_path)
-        local_imports = _download_additional_modules(
-            name=self.name,
-            base_path=hf_github_url(path=self.name, name="", revision=revision),
-            imports=imports,
-            download_config=self.download_config,
-        )
-        additional_files = [(config.DATASETDICT_INFOS_FILENAME, dataset_infos_path)] if dataset_infos_path else []
-        # copy the script and the files in an importable directory
-        dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
-        module_path, hash = _create_importable_file(
-            local_path=local_path,
-            local_imports=local_imports,
-            additional_files=additional_files,
-            dynamic_modules_path=dynamic_modules_path,
-            module_namespace="datasets",
-            name=self.name,
-            download_mode=self.download_mode,
-        )
-        # make the new module to be noticed by the import system
-        importlib.invalidate_caches()
-        builder_kwargs = {"hash": hash, "base_path": hf_hub_url(self.name, "", revision=self.revision)}
-        return DatasetModule(module_path, hash, builder_kwargs)
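For context on what is retired here: `hf_github_url` pointed at raw files in the huggingface/datasets GitHub monorepo, while the Hub factories resolve files in a dataset repo on the Hub. A rough sketch of the two URL shapes, reconstructed for illustration rather than taken from this diff:

```python
# Illustrative reconstruction (assumed layouts, not part of this diff).
name, revision = "squad", "master"

# Legacy GithubDatasetModuleFactory: raw file in the GitHub monorepo.
github_script_url = (
    f"https://raw.githubusercontent.com/huggingface/datasets/{revision}/datasets/{name}/{name}.py"
)

# Hub factories: the same script served from a Hub dataset repo.
hub_script_url = f"https://huggingface.co/datasets/{name}/resolve/{revision}/{name}.py"
print(github_script_url, hub_script_url, sep="\n")
```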


class GithubMetricModuleFactory(_MetricModuleFactory):
"""Get the module of a metric. The metric script is downloaded from GitHub."""

@@ -895,11 +810,10 @@ def __init__(
        self.download_config = download_config or DownloadConfig()
        self.download_mode = download_mode
        self.dynamic_modules_path = dynamic_modules_path
-        assert self.name.count("/") == 1
        increase_load_count(name, resource_type="dataset")

    def download_loading_script(self) -> str:
-        file_path = hf_hub_url(path=self.name, name=self.name.split("/")[1] + ".py", revision=self.revision)
+        file_path = hf_hub_url(path=self.name, name=self.name.split("/")[-1] + ".py", revision=self.revision)
        download_config = self.download_config.copy()
        if download_config.download_desc is None:
            download_config.download_desc = "Downloading builder script"
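The switch from `split("/")[1]` to `split("/")[-1]` matters because this factory now also serves canonical datasets whose names contain no slash; a quick plain-Python illustration:

```python
# split("/")[-1] yields the script stem for both name shapes:
print("user/squad".split("/")[-1])  # "squad" -> squad.py
print("squad".split("/")[-1])       # "squad" -> squad.py
# The old split("/")[1] assumed a namespaced name and fails on a bare one:
# "squad".split("/")[1]  ->  IndexError: list index out of range
```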
@@ -1170,60 +1084,50 @@ def dataset_module_factory(
    elif is_relative_path(path) and path.count("/") <= 1 and not force_local_path:
        try:
            _raise_if_offline_mode_is_enabled()
-            if path.count("/") == 0:  # even though the dataset is on the Hub, we get it from GitHub for now
-                # TODO(QL): use a Hub dataset module factory instead of GitHub
-                return GithubDatasetModuleFactory(
+            hf_api = HfApi(config.HF_ENDPOINT)
+            try:
+                if isinstance(download_config.use_auth_token, bool):
+                    token = HfFolder.get_token() if download_config.use_auth_token else None
+                else:
+                    token = download_config.use_auth_token
+                dataset_info = hf_api.dataset_info(
+                    repo_id=path,
+                    revision=revision,
+                    token=token,
+                    timeout=100.0,
+                )
+            except Exception as e:  # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
+                if isinstance(
+                    e,
+                    (
+                        OfflineModeIsEnabled,
+                        requests.exceptions.ConnectTimeout,
+                        requests.exceptions.ConnectionError,
+                    ),
+                ):
+                    raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({type(e).__name__})")
+                elif "404" in str(e):
+                    msg = f"Dataset '{path}' doesn't exist on the Hub"
+                    raise FileNotFoundError(msg + f" at revision '{revision}'" if revision else msg)
+                else:
+                    raise e
+            if filename in [sibling.rfilename for sibling in dataset_info.siblings]:
+                return HubDatasetModuleFactoryWithScript(
                    path,
                    revision=revision,
                    download_config=download_config,
                    download_mode=download_mode,
                    dynamic_modules_path=dynamic_modules_path,
                ).get_module()
-            elif path.count("/") == 1:  # community dataset on the Hub
-                hf_api = HfApi(config.HF_ENDPOINT)
-                try:
-                    if isinstance(download_config.use_auth_token, bool):
-                        token = HfFolder.get_token() if download_config.use_auth_token else None
-                    else:
-                        token = download_config.use_auth_token
-                    dataset_info = hf_api.dataset_info(
-                        repo_id=path,
-                        revision=revision,
-                        token=token,
-                        timeout=100.0,
-                    )
-                except Exception as e:  # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
-                    if isinstance(
-                        e,
-                        (
-                            OfflineModeIsEnabled,
-                            requests.exceptions.ConnectTimeout,
-                            requests.exceptions.ConnectionError,
-                        ),
-                    ):
-                        raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({type(e).__name__})")
-                    elif "404" in str(e):
-                        msg = f"Dataset '{path}' doesn't exist on the Hub"
-                        raise FileNotFoundError(msg + f" at revision '{revision}'" if revision else msg)
-                    else:
-                        raise e
-                if filename in [sibling.rfilename for sibling in dataset_info.siblings]:
-                    return HubDatasetModuleFactoryWithScript(
-                        path,
-                        revision=revision,
-                        download_config=download_config,
-                        download_mode=download_mode,
-                        dynamic_modules_path=dynamic_modules_path,
-                    ).get_module()
-                else:
-                    return HubDatasetModuleFactoryWithoutScript(
-                        path,
-                        revision=revision,
-                        data_dir=data_dir,
-                        data_files=data_files,
-                        download_config=download_config,
-                        download_mode=download_mode,
-                    ).get_module()
+            else:
+                return HubDatasetModuleFactoryWithoutScript(
+                    path,
+                    revision=revision,
+                    data_dir=data_dir,
+                    data_files=data_files,
+                    download_config=download_config,
+                    download_mode=download_mode,
+                ).get_module()
except Exception as e1: # noqa: all the attempts failed, before raising the error we should check if the module is already cached.
try:
return CachedDatasetModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
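Putting the new resolution logic together: the factory choice now hinges on whether the Hub repo lists a `<name>.py` script among its files. A condensed sketch of that decision, using `huggingface_hub`'s public `HfApi.dataset_info` and eliding the error handling and the exact `filename` computation shown above:

```python
from typing import Optional

from huggingface_hub import HfApi


def pick_factory(path: str, revision: Optional[str] = None) -> str:
    """Condensed sketch of the post-commit decision in dataset_module_factory."""
    dataset_info = HfApi().dataset_info(repo_id=path, revision=revision)
    filename = path.split("/")[-1] + ".py"  # simplified; see the diff above
    if filename in [sibling.rfilename for sibling in dataset_info.siblings]:
        return "HubDatasetModuleFactoryWithScript"  # repo ships a loading script
    return "HubDatasetModuleFactoryWithoutScript"  # packaged loader over data files


print(pick_factory("squad"))  # canonical datasets now take this same path
```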

1 comment on commit ac61fdd

@github-actions

PyArrow==5.0.0


Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.012677 / 0.011353 (0.001324) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.004774 / 0.011008 (-0.006234) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.039850 / 0.038508 (0.001342) |
| read_batch_unformated after write_array2d | 0.044705 / 0.023109 (0.021595) |
| read_batch_unformated after write_flattened_sequence | 0.365700 / 0.275898 (0.089802) |
| read_batch_unformated after write_nested_sequence | 0.395176 / 0.323480 (0.071696) |
| read_col_formatted_as_numpy after write_array2d | 0.009491 / 0.007986 (0.001506) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004389 / 0.004328 (0.000061) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.010951 / 0.004250 (0.006701) |
| read_col_unformated after write_array2d | 0.046207 / 0.037052 (0.009155) |
| read_col_unformated after write_flattened_sequence | 0.373487 / 0.258489 (0.114998) |
| read_col_unformated after write_nested_sequence | 0.410899 / 0.293841 (0.117058) |
| read_formatted_as_numpy after write_array2d | 0.050375 / 0.128546 (-0.078172) |
| read_formatted_as_numpy after write_flattened_sequence | 0.016336 / 0.075646 (-0.059310) |
| read_formatted_as_numpy after write_nested_sequence | 0.311886 / 0.419271 (-0.107385) |
| read_unformated after write_array2d | 0.066869 / 0.043533 (0.023337) |
| read_unformated after write_flattened_sequence | 0.380650 / 0.255139 (0.125511) |
| read_unformated after write_nested_sequence | 0.384435 / 0.283200 (0.101235) |
| write_array2d | 0.111310 / 0.141683 (-0.030373) |
| write_flattened_sequence | 2.123140 / 1.452155 (0.670985) |
| write_nested_sequence | 2.228797 / 1.492716 (0.736081) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.309449 / 0.018006 (0.291443) |
| get_batch_of_1024_rows | 0.537996 / 0.000490 (0.537506) |
| get_first_row | 0.022647 / 0.000200 (0.022447) |
| get_last_row | 0.000164 / 0.000054 (0.000110) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.028070 / 0.037411 (-0.009341) |
| shard | 0.113880 / 0.014526 (0.099354) |
| shuffle | 0.130931 / 0.176557 (-0.045626) |
| sort | 0.174629 / 0.737135 (-0.562507) |
| train_test_split | 0.131153 / 0.296338 (-0.165186) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.577642 / 0.215209 (0.362433) |
| read 50000 | 5.968018 / 2.077655 (3.890364) |
| read_batch 50000 10 | 2.259830 / 1.504120 (0.755710) |
| read_batch 50000 100 | 1.910023 / 1.541195 (0.368828) |
| read_batch 50000 1000 | 1.928502 / 1.468490 (0.460012) |
| read_formatted numpy 5000 | 0.723872 / 4.584777 (-3.860905) |
| read_formatted pandas 5000 | 6.566814 / 3.745712 (2.821102) |
| read_formatted tensorflow 5000 | 3.319620 / 5.269862 (-1.950241) |
| read_formatted torch 5000 | 1.589559 / 4.565676 (-2.976118) |
| read_formatted_batch numpy 5000 10 | 0.097143 / 0.424275 (-0.327132) |
| read_formatted_batch numpy 5000 1000 | 0.014620 / 0.007607 (0.007013) |
| shuffled read 5000 | 0.745908 / 0.226044 (0.519864) |
| shuffled read 50000 | 7.581438 / 2.268929 (5.312509) |
| shuffled read_batch 50000 10 | 2.961568 / 55.444624 (-52.483056) |
| shuffled read_batch 50000 100 | 2.398666 / 6.876477 (-4.477810) |
| shuffled read_batch 50000 1000 | 2.553461 / 2.142072 (0.411388) |
| shuffled read_formatted numpy 5000 | 0.919799 / 4.805227 (-3.885429) |
| shuffled read_formatted_batch numpy 5000 10 | 0.188240 / 6.500664 (-6.312424) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.077924 / 0.075469 (0.002455) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 2.037164 / 1.841788 (0.195377) |
| map fast-tokenizer batched | 16.420081 / 8.074308 (8.345773) |
| map identity | 42.253090 / 10.191392 (32.061698) |
| map identity batched | 1.108904 / 0.680424 (0.428480) |
| map no-op batched | 0.668605 / 0.534201 (0.134404) |
| map no-op batched numpy | 0.600358 / 0.579283 (0.021075) |
| map no-op batched pandas | 0.737218 / 0.434364 (0.302854) |
| map no-op batched pytorch | 0.403897 / 0.540337 (-0.136440) |
| map no-op batched tensorflow | 0.451412 / 1.386936 (-0.935524) |
PyArrow==latest

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.011191 / 0.011353 (-0.000162) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.005176 / 0.011008 (-0.005832) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.037674 / 0.038508 (-0.000834) |
| read_batch_unformated after write_array2d | 0.040779 / 0.023109 (0.017670) |
| read_batch_unformated after write_flattened_sequence | 0.388168 / 0.275898 (0.112270) |
| read_batch_unformated after write_nested_sequence | 0.434899 / 0.323480 (0.111420) |
| read_col_formatted_as_numpy after write_array2d | 0.008394 / 0.007986 (0.000409) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.006100 / 0.004328 (0.001771) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.009469 / 0.004250 (0.005219) |
| read_col_unformated after write_array2d | 0.043861 / 0.037052 (0.006809) |
| read_col_unformated after write_flattened_sequence | 0.370784 / 0.258489 (0.112295) |
| read_col_unformated after write_nested_sequence | 0.426238 / 0.293841 (0.132397) |
| read_formatted_as_numpy after write_array2d | 0.051173 / 0.128546 (-0.077374) |
| read_formatted_as_numpy after write_flattened_sequence | 0.014787 / 0.075646 (-0.060859) |
| read_formatted_as_numpy after write_nested_sequence | 0.334245 / 0.419271 (-0.085027) |
| read_unformated after write_array2d | 0.070667 / 0.043533 (0.027134) |
| read_unformated after write_flattened_sequence | 0.386597 / 0.255139 (0.131458) |
| read_unformated after write_nested_sequence | 0.411003 / 0.283200 (0.127803) |
| write_array2d | 0.111755 / 0.141683 (-0.029928) |
| write_flattened_sequence | 2.206759 / 1.452155 (0.754605) |
| write_nested_sequence | 2.282500 / 1.492716 (0.789784) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.316889 / 0.018006 (0.298883) |
| get_batch_of_1024_rows | 0.543719 / 0.000490 (0.543229) |
| get_first_row | 0.025931 / 0.000200 (0.025731) |
| get_last_row | 0.000447 / 0.000054 (0.000393) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.031785 / 0.037411 (-0.005626) |
| shard | 0.120884 / 0.014526 (0.106358) |
| shuffle | 0.129051 / 0.176557 (-0.047505) |
| sort | 0.189732 / 0.737135 (-0.547403) |
| train_test_split | 0.130795 / 0.296338 (-0.165543) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.610866 / 0.215209 (0.395657) |
| read 50000 | 6.116104 / 2.077655 (4.038449) |
| read_batch 50000 10 | 2.389571 / 1.504120 (0.885451) |
| read_batch 50000 100 | 2.021205 / 1.541195 (0.480011) |
| read_batch 50000 1000 | 2.057802 / 1.468490 (0.589312) |
| read_formatted numpy 5000 | 0.730684 / 4.584777 (-3.854093) |
| read_formatted pandas 5000 | 6.779492 / 3.745712 (3.033779) |
| read_formatted tensorflow 5000 | 2.855280 / 5.269862 (-2.414582) |
| read_formatted torch 5000 | 1.489573 / 4.565676 (-3.076103) |
| read_formatted_batch numpy 5000 10 | 0.079955 / 0.424275 (-0.344320) |
| read_formatted_batch numpy 5000 1000 | 0.013629 / 0.007607 (0.006022) |
| shuffled read 5000 | 0.766337 / 0.226044 (0.540292) |
| shuffled read 50000 | 7.462526 / 2.268929 (5.193597) |
| shuffled read_batch 50000 10 | 3.063317 / 55.444624 (-52.381307) |
| shuffled read_batch 50000 100 | 2.348684 / 6.876477 (-4.527793) |
| shuffled read_batch 50000 1000 | 2.448304 / 2.142072 (0.306231) |
| shuffled read_formatted numpy 5000 | 0.892524 / 4.805227 (-3.912704) |
| shuffled read_formatted_batch numpy 5000 10 | 0.184002 / 6.500664 (-6.316662) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.075993 / 0.075469 (0.000524) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 2.143943 / 1.841788 (0.302156) |
| map fast-tokenizer batched | 17.055055 / 8.074308 (8.980747) |
| map identity | 41.833109 / 10.191392 (31.641717) |
| map identity batched | 1.112527 / 0.680424 (0.432103) |
| map no-op batched | 0.680060 / 0.534201 (0.145859) |
| map no-op batched numpy | 0.633664 / 0.579283 (0.054381) |
| map no-op batched pandas | 0.753380 / 0.434364 (0.319016) |
| map no-op batched pytorch | 0.465359 / 0.540337 (-0.074979) |
| map no-op batched tensorflow | 0.472593 / 1.386936 (-0.914343) |
