Slow dataloading with big datasets issue persists #2252

Closed
hwijeen opened this issue Apr 23, 2021 · 70 comments · Fixed by #6464

@hwijeen

hwijeen commented Apr 23, 2021

Hi,

I reported slow data fetching when the data is large (#2210) a couple of weeks ago, and @lhoestq referred me to the fix (#2122).
However, the problem seems to persist. Here are the profiling results:

  1. Running with 60GB
Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  517.96         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
model_backward                     	|  0.26144        	|100            	|  26.144         	|  5.0475         	|
model_forward                      	|  0.11123        	|100            	|  11.123         	|  2.1474         	|
get_train_batch                    	|  0.097121       	|100            	|  9.7121         	|  1.8751         	|
  2. Running with 600GB, datasets==1.6.0
Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  4563.2         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
get_train_batch                    	|  5.1279         	|100            	|  512.79         	|  11.237         	|
model_backward                     	|  4.8394         	|100            	|  483.94         	|  10.605         	|
model_forward                      	|  0.12162        	|100            	|  12.162         	|  0.26652        	|

I see that get_train_batch lags when the data is large. Could this be related to a different issue?
I would be happy to provide any information needed to investigate.

@lhoestq
Member

lhoestq commented Apr 23, 2021

Hi ! Sorry to hear that. This may come from another issue then.

First, can we check whether this latency comes from the dataset itself?
Can you try loading your dataset and benchmarking the speed of querying random examples from it?

import time
import numpy as np

from datasets import load_from_disk

dataset = load_from_disk(...) # or from load_dataset...

_start = time.time()
n = 100
for i in np.random.default_rng(42).integers(0, len(dataset), size=n):
    _ = dataset[i]
print(time.time() - _start)

If we see a significant speed difference between your two datasets, then it would mean that there's an issue somewhere.

@hwijeen
Author

hwijeen commented Apr 23, 2021

Hi @lhoestq, here is the result. I additionally measured time to load_from_disk:

  • 60GB
loading took:  22.618776321411133
random indexing 100 times took: 0.10214924812316895
  • 600GB
loading took:  1176.1764674186707
random indexing 100 times took: 2.853600025177002

Hmm... I double-checked that it's version 1.6.0. The difference seems quite big; could it be related to the running environment?

@lhoestq
Member

lhoestq commented Apr 23, 2021

I'm surprised by the speed change. Can you give more details about your dataset ?
The speed depends on the number of batches in the arrow tables and the distribution of the lengths of the batches.
You can access the batches by doing dataset.data.to_batches() (use only for debugging; it doesn't bring the data into memory).

Also can you explain what parameters you used if you used map calls ?
Also if you have some code that reproduces the issue I'd be happy to investigate it.
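For example, a quick way to look at the batch layout (a sketch, not from the original thread; it assumes dataset is the Dataset loaded with load_from_disk) is:

# Hedged sketch: inspect record batch sizes without bringing data into memory
batches = dataset.data.to_batches()  # list of pyarrow.RecordBatch
lengths = [batch.num_rows for batch in batches]
print(len(batches), "batches; min/mean/max rows per batch:",
      min(lengths), sum(lengths) / len(lengths), max(lengths))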

@lhoestq
Member

lhoestq commented Apr 23, 2021

Also if you could give us more info about your env like your OS, version of pyarrow and if you're using an HDD or a SSD

@hwijeen
Author

hwijeen commented Apr 26, 2021

Here are some details of my 600GB dataset. This is the dataset AFTER applying the map function, and once I load it I do not use map anymore during training. Regarding the distribution of lengths, it is almost uniform (90% of the examples are 512 tokens, and 10% are randomly shorter than that -- a typical setting for language modeling).

len(batches):
492763

batches[0]: 
pyarrow.RecordBatch
attention_mask: list<item: uint8>
  child 0, item: uint8
input_ids: list<item: int16>
  child 0, item: int16
special_tokens_mask: list<item: uint8>
  child 0, item: uint8
token_type_ids: list<item: uint8>
  child 0, item: uint8

Here are some parameters passed to the map function, in case they are relevant:

num_proc=1    # as multi processing is slower in my case
load_from_cache_file=False

@hwijeen
Author

hwijeen commented Apr 26, 2021

Regarding the environment, I am running the code on a cloud server. Here is some info:

Ubuntu 18.04.5 LTS   # cat /etc/issue
pyarrow                 3.0.0  # pip list | grep pyarrow

The data is stored on an SSD which is mounted to the machine via NFS (Network File System).

If you could point me to some of the commands to check the details of the environment, I would be happy to provide relevant information @lhoestq !

@hwijeen
Author

hwijeen commented Apr 26, 2021

I am not sure how I could provide you with reproducible code, since the problem only arises when the data is big. For the moment, I will share the part that I think is relevant. Feel free to ask me for more info.

import datasets
import pytorch_lightning
import torch
import transformers

class MyModel(pytorch_lightning.LightningModule):
    def setup(self, stage):
        self.dataset = datasets.load_from_disk(path)
        self.dataset.set_format("torch")

    def train_dataloader(self):
        collate_fn = transformers.DataCollatorForLanguageModeling(
            tokenizer=transformers.ElectraTokenizerFast.from_pretrained(tok_path)
        )
        dataloader = torch.utils.data.DataLoader(  # was torch.utils.DataLoader, which does not exist
            self.dataset,
            batch_size=32,
            collate_fn=collate_fn,
            num_workers=8,
            pin_memory=True,
        )
        return dataloader

@lhoestq
Member

lhoestq commented May 10, 2021

Hi! Sorry for the delay, I haven't had a chance to take a look at this yet. Are you still experiencing this issue?
I'm asking because the latest patch release, 1.6.2, fixed a few memory issues that could have led to slowdowns.

@hwijeen
Copy link
Author

hwijeen commented May 19, 2021

Hi! I just ran the same code with different datasets (one is 60 GB and another 600 GB), and the latter runs much slower. ETA differs by 10x.

@BenoitDalFerro

@lhoestq and @hwijeen

Despite upgrading to datasets 1.6.2, I am still experiencing extremely slow (2h00) loading of a 300GB local dataset (shard size 1.1GB) on a local HDD (40MB/s read speed). This corresponds almost exactly to the total data size divided by the read speed, implying that the entire dataset is read at each load.

Stack details:

GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration: GPU 0: GeForce GTX 1050
Nvidia driver version: 457.63
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2\bin\cudnn64_7.dll
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] datasets==1.6.2
[pip3] transformers==4.5.1
[pip3] numpy==1.19.1
[pip3] numpydoc==1.1.0
[pip3] pytorch-metric-learning==0.9.98
[pip3] torch==1.8.1
[pip3] torchaudio==0.8.1
[pip3] torchvision==0.2.2
[conda] blas 2.16 mkl conda-forge
[conda] cudatoolkit 10.2.89 hb195166_8 conda-forge
[conda] libblas 3.8.0 16_mkl conda-forge
[conda] libcblas 3.8.0 16_mkl conda-forge
[conda] liblapack 3.8.0 16_mkl conda-forge
[conda] liblapacke 3.8.0 16_mkl conda-forge
[conda] mkl 2020.1 216
[conda] numpy 1.19.1 py37hae9e721_0 conda-forge
[conda] numpydoc 1.1.0 py_1 conda-forge
[conda] pytorch 1.8.1 py3.7_cuda10.2_cudnn7_0 pytorch
[conda] pytorch-metric-learning 0.9.98 pyh39e3cac_0 metric-learning
[conda] torchaudio 0.8.1 py37 pytorch
[conda] torchvision 0.2.2 py_3 pytorch

@lhoestq
Member

lhoestq commented May 19, 2021

Hi @BenoitDalFerro, how do you load your dataset?

@BenoitDalFerro

BenoitDalFerro commented May 19, 2021

Hi @lhoestq, thanks for the quick turn-around. Actually I load it the plain vanilla way, without any particular knack or fashion; I tried to look into the documentation for some alternative but couldn't find any:

dataset = load_from_disk(dataset_path=os.path.join(datasets_dir,dataset_dir))

@tsproisl

I'm facing the same issue when loading a 900GB dataset (stored via save_to_disk): load_from_disk(path_to_dir) takes 1.5 hours and htop consistently shows high IO rates > 120 MB/s.

@BenoitDalFerro

BenoitDalFerro commented May 20, 2021

@tsproisl same here, smells like teen spirit: an intended generator inadvertently ending up as an iterator?

@lhoestq perhaps a way to locate the bug in the code is to track its signature via HD read usage monitoring. One option is to add a tracking decorator on top of each function and sequentially close all hatches from top to bottom; I suggest pySMART https://pypi.org/project/pySMART/, a Smartmontools implementation.

@lhoestq lhoestq self-assigned this May 20, 2021
@lhoestq
Member

lhoestq commented May 21, 2021

I wasn't able to reproduce this on a toy dataset of around 300GB:

import datasets as ds

s = ds.load_dataset("squad", split="train")
s4000 = ds.concatenate_datasets([s] * 4000)
print(ds.utils.size_str(s4000.data.nbytes))  # '295.48 GiB'

s4000.save_to_disk("tmp/squad_4000")
import psutil
import time
from datasets import load_from_disk

disk = "disk0"  # You may have to change your disk here
iocnt1 = psutil.disk_io_counters(perdisk=True)[disk]
time1 = time.time()

s4000_reloaded = load_from_disk("tmp/squad_4000")

time2 = time.time()
iocnt2 = psutil.disk_io_counters(perdisk=True)[disk]

print(f"Blocks read {iocnt2.read_count - iocnt1.read_count}")  # Blocks read 18
print(f"Elapsed time: {time2 - time1:.02f}s")  # Elapsed time: 14.60s

Could you run this on your side and tell me how much time it takes? Please run this when your machine is idle so that other processes don't interfere.

I got these results on my MacBook Pro with datasets 1.6.2.

@BenoitDalFerro

BenoitDalFerro commented May 21, 2021

@lhoestq thanks, test running as we speak, bear with me

@lhoestq
Member

lhoestq commented May 21, 2021

Just tried on Google Colab and got ~1min for a 15GB dataset (only 200 times SQuAD), while it should be instantaneous. The time is spent reading the Apache Arrow table from the memory-mapped file. This might come from a virtual disk management issue. I'm trying to see if I can still speed it up on Colab.

@BenoitDalFerro

BenoitDalFerro commented May 21, 2021

@lhoestq what is Google Colab's disk read speed? Is it possible to introspect it, including the drive type (SSD or HDD)?

@tsproisl

@lhoestq Thank you! The issue is getting more interesting. The second script is still running, but it's definitely taking much longer than 15 seconds.

@tsproisl

Okay, here's the output:
Blocks read 158396
Elapsed time: 529.10s

Also using datasets 1.6.2. Do you have any ideas how to pinpoint the problem?

@BenoitDalFerro

BenoitDalFerro commented May 21, 2021

@lhoestq, @tsproisl mmmh, still writing on my side, about 1h to go. Thinking about it: are your large datasets all monolithic (unsharded)? Mine is 335 shards of 1.18GB each.

@tsproisl

The 529.10s was a bit too optimistic. I had cancelled the reading process once before running it completely, so the hard drive cache probably did its work.

Here are three consecutive runs:
First run (freshly written to disk):
Blocks read 309702
Elapsed time: 1267.74s
Second run (immediately after):
Blocks read 113944
Elapsed time: 417.55s
Third run (immediately after):
Blocks read 42518
Elapsed time: 199.19s

@BenoitDalFerro

BenoitDalFerro commented May 21, 2021

@lhoestq
First test

elapsed time: 11219.05s

Second test running, bear with me. For Windows users, a slight trick is needed to modify the original "disk0" string:

First, find the relevant physical drive key in the dictionary:

import psutil
psutil.disk_io_counters(perdisk=True)

{'PhysicalDrive0': sdiskio(read_count=18453286, write_count=4075333, read_bytes=479546467840, write_bytes=161590275072, read_time=20659, write_time=2464),
'PhysicalDrive1': sdiskio(read_count=1495778, write_count=388781, read_bytes=548628622336, write_bytes=318234849280, read_time=426066, write_time=19085)}

In my case it's PhysicalDrive1

Then insert the relevant key string as the disk variable:

import time
import psutil
from datasets import load_from_disk

disk = 'PhysicalDrive1'  # You may have to change your disk here
iocnt1 = psutil.disk_io_counters(perdisk=True)[disk]
time1 = time.time()
s4000_reloaded = load_from_disk("your path here")
time2 = time.time()
iocnt2 = psutil.disk_io_counters(perdisk=True)[disk]
print(f"Blocks read {iocnt2.read_count - iocnt1.read_count}")
print(f"Elapsed time: {time2 - time1:.02f}s")

@BenoitDalFerro

@lhoestq
Second test

Blocks read 1265609
Elapsed time: 11216.55s

@BenoitDalFerro

@lhoestq any luck ?

@lhoestq
Member

lhoestq commented May 26, 2021

Unfortunately no. Thanks for running the benchmark though, it shows that your machine does a lot of read operations. This is not expected: on other machines it does almost no read operations, which enables very fast loading.

I did some tests on Google Colab and have the same issue. The first time the dataset's arrow file is memory-mapped, it always takes a lot of time (the time seems linear with respect to the dataset size). Reloading the dataset is then instantaneous, since the arrow file has already been memory-mapped.

I also tried using the Arrow IPC file format (see #1933) instead of the current streaming format that we use but it didn't help.

Memory mapping is handled by the OS and depends on the disk you're using, so I'm not sure we can do much about it. I'll continue to investigate anyway, because I still don't know why in some cases it would go through the entire file (high Blocks read as in your tests) and in other cases it would do almost no reading.

@BenoitDalFerro

@lhoestq thanks for the effort, let's stay in touch

@gurvindersingh

gurvindersingh commented Jul 20, 2021

Just want to say that I am seeing the same issue. The dataset size is 268GB and it takes 3 hours to load with load_from_disk, using datasets version 1.9.0. The filesystem underneath is Lustre.

@lhoestq
Member

lhoestq commented Dec 8, 2022

Cool !

but I'm curious why the dataset iteration is so slow. Why is this happening? I do some preprocessing with the Dataset.map() function and it works quite fast (for 34GB of text it takes around 1.5 hours to process the data (tokenization, chunk merging, etc.) with 6 workers), but then once I'm using this preprocessed dataset, the iteration is significantly slower, roughly 10x, e.g. 2.2 it/s vs 5 s/batch.

When you process an unshuffled dataset with map, you iterate over contiguous chunks of data, which is very fast. You get the best speed when you have an iterable dataset as well, when it's based on shards of contiguous data.

This is fast because internally Arrow simply iterates over the record batches.

On the other hand, if you use a map-style dataset in PyTorch, then PyTorch samples uniformly from the files on your disk. This is slower for your disk, and also requires an extra step to get the location of the examples from an index.
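A rough way to see this difference on your own data (a sketch, not from the thread; the path is a placeholder):

# Hedged sketch: contiguous slices vs. random access on the same Dataset
import time
import numpy as np
from datasets import load_from_disk

ds = load_from_disk("path/to/preprocessed_dataset")  # placeholder path

start = time.time()
for i in range(0, 3200, 32):  # contiguous slices, like an unshuffled map()
    _ = ds[i : i + 32]
print("contiguous slices:", time.time() - start)

start = time.time()
for i in np.random.default_rng(0).integers(0, len(ds), size=3200):  # random access, like a shuffled map-style DataLoader
    _ = ds[int(i)]
print("random access:", time.time() - start)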

@pauli31

pauli31 commented Dec 8, 2022

When you process an unshuffled dataset with map, you iterate over contiguous chunks of data, which is very fast. [...] On the other hand, if you use a map-style dataset in PyTorch, then PyTorch samples uniformly from the files on your disk. [...]

Now it makes sense. I thought that even a map-style unshuffled dataset would be processed iteratively (I mean from start to end, without any sampling). Great!

@JeanKaddour

Hey all,

I'm facing the same issue with the PILE.

 raw_datasets = load_dataset(
        "EleutherAI/the_pile_deduplicated",
        cache_dir=CACHE_DIR,
        ignore_verifications=True,
    )

takes ~1h 30min, although I already cached it.

Later in my code, I use

tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=psutil.cpu_count(),
            remove_columns=column_names,
            batch_size=100,
            load_from_cache_file=True,
            cache_file_names=cache_file_names,
        )

Is there any way I can instantiate the (also already cached) tokenized dataset directly without having to wait until raw_datasets is instantiated first?

@lhoestq
Member

lhoestq commented Jan 31, 2023

takes ~1h 30min, although I already cached it.

An alternative is to load the dataset as iterable but this is not implemented yet, see #5481

Is there any way I can instantiate the (also already cached) tokenized dataset directly without having to wait until raw_datasets is instantiated first?

If you want to skip that step, next time I'd recommend saving the dataset somewhere after tokenization (e.g. using .save_to_disk()) and reloading it from there instead of relying on the cache.

Though you could look for the cached arrow files in your cache and reload the data from there if you're adventurous. You can use Dataset.from_file to reload a file, and then concatenate_datasets to concatenate all the chunks.
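For reference, a minimal sketch of that adventurous route (the glob pattern is illustrative, not the actual cache layout):

# Hedged sketch: rebuild the tokenized dataset from its cached arrow chunks
import glob
from datasets import Dataset, concatenate_datasets

chunk_files = sorted(glob.glob("/path/to/hf_cache/tokenized-*.arrow"))  # hypothetical file pattern
tokenized = concatenate_datasets([Dataset.from_file(f) for f in chunk_files])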

@JeanKaddour

Thank you so much, you made my day.

The adventurous route worked out well; it now only takes ~5 min :)

@cakeislife100

@lhoestq I'm also still having this issue where load_from_disk becomes exponentially slower on larger datasets. I've tried the suggestion here to shard and then concatenate the dataset but it's still just as slow as using load_from_disk out of the box.

Do you have any other suggestions on what to try here? What is the best guidance on using the datasets library with >1TB datasets? I'm guessing it's not supposed to read in the entire dataset during a load_from_disk call.

In case it helps, this is my code for loading and concatenating.

import glob, os
from datasets import Dataset, concatenate_datasets

def concatenate_ds_shards(path):
    arrow_files = sorted(glob.glob(os.path.join(path, "*.arrow")))  # the arrow files generated by `save_to_disk(...)`
    shards = [Dataset.from_file(f) for f in arrow_files]
    return concatenate_datasets(shards)

@lhoestq
Member

lhoestq commented Aug 23, 2023

Can you try using IterableDataset.from_file instead ? This should be faster since it doesn't memory map the full dataset and doesn't read the arrow record batches metadata either.

You'll get an IterableDataset, which is more suited to bigger datasets. You can find some documentation about the differences between Dataset and IterableDataset here.
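A minimal sketch of this suggestion (the shard path is illustrative):

# Hedged sketch: stream one arrow shard without memory mapping the whole dataset
from datasets import IterableDataset

ids = IterableDataset.from_file("my_dataset/train/data-00000-of-00042.arrow")  # hypothetical shard path
for example in ids.take(3):
    print(example)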

@cakeislife100

cakeislife100 commented Aug 25, 2023

@lhoestq thank you for the suggestion! Can we also use IterableDataset directly within DDP setups (multi-GPU or multi-node)? Or does some special work need to be done?

@lhoestq
Member

lhoestq commented Aug 26, 2023

Yup using split_dataset_by_node for example (see docs here)
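For example (a sketch, assuming ids is your IterableDataset and rank/world_size come from your DDP setup):

# Hedged sketch: give each DDP worker its own subset of the iterable dataset
from datasets.distributed import split_dataset_by_node

ids_rank = split_dataset_by_node(ids, rank=rank, world_size=world_size)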

@RmZeta2718

Can you try using IterableDataset.from_file instead ? This should be faster since it doesn't memory map the full dataset and doesn't read the arrow record batches metadata either.

@lhoestq Why is memory mapping the full dataset slow? It is said to be memory efficient, so I assume load_dataset should also be fast (given that only a small amount of data is actually loaded).

I think (if I understood correctly) mmap should take nearly no time and return immediately, with the data read from disk later when it is actually needed. So, no pre-fetching. But this doesn't seem to be true for load_dataset, which actually reads a lot of data from disk at the very beginning.

Profiling (in this comment, and also in my own profiling) shows that RecordBatchStreamReader.read_all is the most time-consuming function call. So what is it reading? Is it really required to read this data in the memory-mapped scenario?

@lhoestq
Member

lhoestq commented Nov 21, 2023

We memory map the arrow files and use RecordBatchStreamReader.read_all indeed. This does not load the data in memory so it is generally super fast.

However if you have a slow disk and your dataset has thousands of arrow files with lots of record batches then it can take a few minutes to run. This is because of RecordBatchStreamReader.read_all: it doesn't read the data but it does read the record batches metadata (e.g. their length). This is used to know the size of the full dataset for example.

For iterable datasets we don't need to know the length of the dataset in advance, so we don't need to read the metadata, which makes things faster to load.
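Roughly what happens under the hood (a simplified sketch, not the library's exact code; the file name is a placeholder):

# Hedged sketch: memory map one arrow shard and read it with the streaming IPC reader.
# read_all() doesn't copy the data, but it does visit every record batch's metadata.
import pyarrow as pa

with pa.memory_map("data-00000-of-00042.arrow", "r") as source:  # placeholder file name
    table = pa.ipc.open_stream(source).read_all()
print(table.num_rows)  # knowing the total length requires having read all batch metadata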

@RmZeta2718

RmZeta2718 commented Nov 21, 2023

@lhoestq Thanks for your quick reply and detailed explanation. What proportion of the total data does the metadata typically account for?

I'd like to give more details on my experiments if you don't mind taking a look.

I'm basically calling load_dataset and then map on openwebtext, which is around 40GB (~80 cached arrow files of ~500MB each) for load_dataset, plus a single 17GB cached arrow file for map. Observed from iotop, both load_dataset and map read ~9GB from disk (~100MB/s for ~100 seconds each), ~18GB read combined. I believe this is abnormal: map reads 50% of its cache file.

For a better view of what's in the cache file: I specified the features of the cache file to be Features({'text': Sequence(Value("uint16"), length=2048)}) when calling map.

It's also interesting that load_dataset and map read almost the same amount of metadata. I don't know if that's a coincidence.

Environment:

  • python 3.10.12
  • datasets 2.14.6
  • pyarrow 12.0.1
  • ubuntu 22.04

@lhoestq
Member

lhoestq commented Nov 21, 2023

Thanks for investigating @RmZeta2718 , this is useful information and indeed abnormal. Not sure what would cause that but may I ask you to see if reducing the number of record batches in the mapped dataset helps ? You can try passing writer_batch_size=10_000 to do so (one record batch = 10k examples instead of the default 1k)
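i.e. something like this (a sketch; dataset and tokenize_function are placeholder names for your dataset and map function):

# Hedged sketch: write 10k examples per record batch instead of the default 1k
tokenized = dataset.map(
    tokenize_function,
    batched=True,
    writer_batch_size=10_000,
)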

@RmZeta2718

RmZeta2718 commented Nov 27, 2023

Sorry for the delay. @lhoestq

I found NO performance improvement after increasing writer_batch_size. But I found that setting batch_size=100_000 (100x the default) reduces the map time and disk read by half (~50 seconds and ~4GB). load_dataset doesn't have any of these options, though, so it remains slow.

And it's hard to test because once the data is loaded, it is cached in memory: any subsequent attempt to load the same data returns immediately.

Can you guys reproduce the problem? Testing just load_dataset is enough, I think. In my example, load_dataset read 25% of the cached data (9/40), which is also abnormal. I'm not sure what the data structure of metadata is, but I'm expecting it to be around several KB (if it is a length, then 8B per shard?).

@kkoutini
Contributor

I also faced the slow mmap when loading a large dataset (~2TB) using datasets.load_from_disk. The dataset was saved with around ~1100 shards of 2GB. It seems to depend on the file system. I have access to two clusters:

  1. In the first (smaller) cluster, the file storage servers are partitioned as ext4 and mounted as nfs4 on the compute node. datasets.load_from_disk loads the dataset in a couple of minutes. Monitoring the virtual size (VIRT) in htop, I can see that the Python process maps ~70GB of data per second.
  2. In the second cluster, the file storage is mounted as Lustre, and the dataset files are stored with 16 stripes (lfs getstripe). datasets.load_from_disk takes around 30–40 minutes. Monitoring the virtual size (VIRT) of the Python process in htop, I see a mapping speed of around 1-2GB per second.

Are there any recommendations for the stripe size or stripe count of Lustre, or for the shard size, to improve the loading speed?
I prefer not to use iterable datasets since, AFAIU, it's not possible to do weighted sampling from an iterable dataset.

@lhoestq
Member

lhoestq commented Nov 27, 2023

Can you guys reproduce the problem? Testing just load_dataset is enough, I think. [...]

Reproducing these issues is not easy on our side, given they depend on the setup.

For load_dataset it would be nice to be able to control the size of the batches written on disk. Feel free to open an issue if it's something you'd like to see, and we'll discuss there how to do it.

I also faced the slow mmap when loading a large dataset (~2TB) using datasets.load_from_disk. [...]

That's helpful information, thanks ! It seems like Lustre doesn't read at full speed with the memory mapping in datasets

Are there any recommendations for the stripe size or stripe count of Lustre, or for the shard size, to improve the loading speed?

I would try increasing the stripe size, in case the memory mapping does too much unnecessary readahead with the default value.

@RmZeta2718

Reproducing these issues is not easy on our side, given they depend on the setup.

Hey! This is what I do:

raw_datasets["train"] = load_dataset("openwebtext", split=f"train[5000:]")
raw_datasets["validation"] = load_dataset("openwebtext", split=f"train[:5000]")

I mean, there is nothing special in the code, so I believe the slowness issue is general and should be reproducible with a plain call to load_dataset. The environment is listed above. I'm glad to provide any other info you need.

@kkoutini
Contributor

kkoutini commented Nov 28, 2023

I managed to speed up the loading time (on the Lustre file system) by mmapping the arrow shards in parallel (python preload_mmap.py, see the script below) and relying on the OS to cache the mmap.

Here are some results:

  1. Without caching (on a fresh node):
# Sequentially using datasets.load_from_disk
 Loading dataset_name: 1865.329 seconds
  2. After calling python preload_mmap.py once to cache (the first time, it takes around 90 seconds with 16 processes):
# return only the length of the table from each worker
# python preload_mmap.py p # use processes and return only len of table from workers
Loading dataset_name using num of (returning len) processes=16: 42.837 seconds
# python preload_mmap.py t  # use threads and return only len of table from workers
Loading dataset_name using num of (returning len) threads=16: 105.167 seconds

# return the whole table from each worker
# python preload_mmap.py p table # use processes and return tables from workers
Loading dataset_name using num of (returning table) processes=16: 367.917 seconds
# python preload_mmap.py t  table # use threads and return tables from workers
Loading dataset_name using num of (returning table) threads=16: 260.434 seconds

# Sequentially using datasets.load_from_disk (the dataset has only one split)
Loading dataset_name: 397.046 seconds

It seems that preloading the files in processes (without returning the table) speeds up subsequent load_from_disk calls. However, the communication time to return the tables for concatenation is high (I am not sure how they are pickled).

Threads are slower to mmap the table but faster to communicate. If this works on other file systems, it may be worth it to have the option to load the shards in parallel here.

# preload_mmap.py
import datasets
import os
from datasets.table import MemoryMappedTable, concat_tables
import glob
import logging
from time import perf_counter
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import sys
import concurrent.futures
import functools

logger = datasets.logging.get_logger(__name__)


datasets.logging.set_verbosity_info()



class catchtime:
    # context to measure loading time: https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
    def __init__(self, debug_print="Time", logger=logger):
        self.debug_print = debug_print
        self.logger = logger

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        self.time = perf_counter() - self.start
        readout = f"{self.debug_print}: {self.time:.3f} seconds"
        self.logger.info(readout)


def load_file(f, return_len=False):
    with catchtime(f"Loading {f}", logger=logger):
        ds = MemoryMappedTable.from_file(f)
    if return_len:  # process pool is slow to serialize
        return len(ds)
    return ds


def load_files(
    files, debug_name="dataset_name", num_proc=16, use_threads=False, return_len=False
):
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_proc}
        debug_desc = "threads"
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_proc}
        debug_desc = "processes"
    if return_len:
        debug_desc = "(returning len) " + debug_desc
    else:
        debug_desc = "(returning table) " + debug_desc
    with catchtime(
        f"Loading {debug_name} using num of {debug_desc}={num_proc}", logger=logger
    ):
        with pool_cls(**pool_kwargs) as pool:
            result = list(
                pool.map(functools.partial(load_file, return_len=return_len), files)
            )
    return result


def main(use_threads, return_len):
    datasets.logging.set_verbosity_info()
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    logger.info("Starting")

    jc = "datset_name"
    local = "dataset_path"
    split = "train"
    files = glob.glob(os.path.join(local, jc, split, "*.arrow"))
    files = sorted(files)

    ds = load_files(files, jc, use_threads=use_threads, return_len=return_len)
    if not return_len:
        with catchtime(f"concat_ tables"):
            ds = concat_tables(ds)

    logger.info("done")


if __name__ == "__main__":
    use_threads = False
    return_len = True

    print(
        "Usage: \n Threads: python preload_mmap.py t \n Threads and concatenate datasets: python preload_mmap.py t c"
        "\n processes: python preload_mmap.py p \n processes and concatenate datasets: python preload_mmap.py p c "
    )
    if len(sys.argv) > 1 and sys.argv[1].startswith("t"):
        use_threads = True
    if len(sys.argv) > 2:
        return_len = False

    main(use_threads, return_len)

@wangzihe1996

I managed to speed up the loading time (on the Lustre file system) by mmapping the arrow shards in parallel (python preload_mmap.py, see the script above) and relying on the OS to cache the mmap. [...]

I'm very happy to report that I used this approach to accelerate the data load time from 15 minutes to 45 seconds. My dataset size is on the TB scale.

My file system is virtio-fs. I can't tell what the real underlying file system is, because my code runs in a virtual machine and I can't access the host machine; I guess it is a distributed file system similar to Lustre. I don't know why it is so slow or why multithreading accelerates it, but it really does speed up the load time, which is vital for me. Thanks for your code.

@lhoestq
Member

lhoestq commented Dec 1, 2023

Nice to see this method validated on multiple setups !

Would be cool to integrate multithreading when memory mapping the Arrow files then.
I think this could be added here (for load_dataset):

for f_dict in files:
    pa_table: Table = self._get_table_from_filename(f_dict, in_memory=in_memory)
    pa_tables.append(pa_table)

and here (for load_from_disk):

arrow_table = concat_tables(
    table_cls.from_file(posixpath.join(dest_dataset_path, data_file["filename"]))
    for data_file in state["_data_files"]
)

I can take some time next week to do it, but feel free to open a PR if you want to give it a try
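For reference, a rough sketch of the idea (not the actual patch that ended up in #6464):

# Hedged sketch: memory map the arrow shards from a thread pool, then concatenate
from concurrent.futures import ThreadPoolExecutor
from datasets.table import MemoryMappedTable, concat_tables

def load_tables_threaded(files, max_workers=16):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        tables = list(pool.map(MemoryMappedTable.from_file, files))
    return concat_tables(tables)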

@kkoutini
Contributor

kkoutini commented Dec 4, 2023

I can take some time next week to do it, but feel free to open a PR if you want to give it a try

Threading seems to work faster in arrow_dataset.py. However, changing arrow_reader.py may require changes in the higher-level API of DatasetBuilder.as_dataset, which is used in many places including load_dataset; see #6464.

kkoutini added a commit to kkoutini/datasets that referenced this issue Dec 6, 2023
kkoutini added a commit to kkoutini/datasets that referenced this issue Dec 6, 2023
lhoestq added a commit that referenced this issue Jan 26, 2024
* add threadmap to load_from_disk #2252

* Add threadmap to arrow_reader.read_files #2252

* remove old way of loading files

* sort imports

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>