The tested speed is not as fast as expected. #60

Open
tikboaHIT opened this issue Mar 7, 2024 · 20 comments

Labels: bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@tikboaHIT

🐛 Bug

The tested speed is not as fast as expected.

Code sample

import os
import torch
import numpy as np
from tqdm import tqdm
from torchvision.transforms import Compose, Lambda
from litdata import StreamingDataset, StreamingDataLoader

from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo

input_dir = 's3://extract_frames/'
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

class ImagenetStreamingDataset(StreamingDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transform = Compose(
            [
                Lambda(lambda x: x / 255.0),
                NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
                # ShortSideScale(size=224),
                CenterCropVideo(224),
            ]
        )
    
    def __getitem__(self, index):
        data = super().__getitem__(index)
        video_data = []
        for i in range(8):
            frame = np.array(data["image"][i])
            video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
        video_data = torch.stack(video_data, dim=1)
        video_data = self.transform(video_data)
        return video_data

dataset = ImagenetStreamingDataset(input_dir, shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=8)
for batch in tqdm(dataloader, total=len(dataloader)):
    pass

Expected behavior

There are approximately 200,000 data points, each consisting of 8 extracted frames. Given the benchmark numbers, loading should be very fast, but in practice it is not.

[Screenshot 2024-03-07 at 20 42 20]

The measured speed is approximately as follows:
[Screenshot 2024-03-07 at 20 48 30]

Environment

  • PyTorch Version: 2.2.1
  • OS: Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.9
  • CUDA/cuDNN version: 11.6

tikboaHIT added the bug (Something isn't working) and help wanted (Extra attention is needed) labels on Mar 7, 2024

@tchaton
Collaborator

tchaton commented Mar 7, 2024

Hey @tikboaHIT,

The benchmarks are fully reproducible for ImageNet, so you can verify the numbers yourself.

For your custom use case, there are a lot of possible optimizations, especially around your transforms and the round trip through numpy. Would you be open to creating a reproducible Studio on https://lightning.ai/?

In the meantime, you can pass profile_batches=10 to the StreamingDataLoader to check where the time is spent. When you can, please share the trace with me so I can help you optimize it.
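
For reference, a minimal sketch of that suggestion (assuming the S3 path from the original report; the exact trace file name and location are discussed in the next comments):

from litdata import StreamingDataset, StreamingDataLoader

# Sketch only: enable profiling of the first batches via profile_batches.
dataset = StreamingDataset('s3://extract_frames/', shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=8, profile_batches=10)
for batch in dataloader:
    pass  # the profiling trace should be written where the script is run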

@tikboaHIT
Author

Thanks @tchaton!
Where should I find the generated result.json? After running StreamingDataLoader(dataset, batch_size=64, num_workers=16, profile_batches=5), I couldn't find the corresponding file.

@tchaton
Collaborator

tchaton commented Mar 8, 2024

It should appear in the directory where you run the command. Maybe reduce the batch size and the number of workers.

Could you also provide a synthetic example for me to debug? That helps tremendously when optimizing these things. Here is another user's synthetic script as a reference: #62 (comment).

@tikboaHIT
Author

Sure:

import torch
from tqdm import tqdm
from litdata import optimize, StreamingDataset

def generate_images(video_path):
    data = {
        "name": video_path,
        "image": torch.rand((3, 8, 320, 568)),
    }
    return data

optimize(
    fn=generate_images,
    inputs=list(range(100)),
    output_dir="/root/data/example_data/chunk_cache",
    num_workers=1,
    chunk_bytes="256MB",
)

input_dir = '/root/data/example_data/chunk_cache'
dataset = StreamingDataset(input_dir, shuffle=True)
for data in tqdm(dataset):
    pass

input_dir = 's3://pzhao/data/example_data/chunk_cache'
dataset = StreamingDataset(input_dir, shuffle=True)
for data in tqdm(dataset):
    pass

The speed when loading from local storage is as follows:
[Screenshot 2024-03-08 at 19 48 10]

The speed when loading from S3 is as follows:
[Screenshot 2024-03-08 at 19 48 29]

@tchaton
Collaborator

tchaton commented Mar 8, 2024

Hey @tikboaHIT, thanks. Are you streaming from S3 to your local machine? If so, you might also be bottlenecked by your own internet connection.

Additionally, this is very unoptimized: you are storing the full video as a raw tensor. That is usually 10-100x larger than JPEG encoding, and around 1000x larger than the AV1 format.

def generate_images(video_path):
    data = {
        "name": video_path,
        "image": torch.rand((3, 8, 320, 568)),
    }
    return data
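
As a rough size check of the point above (simple arithmetic on the shapes from the snippet, not a measurement of the chunk files):

# Back-of-the-envelope size of one raw sample from the snippet above:
# a float32 tensor of shape (3, 8, 320, 568).
num_elements = 3 * 8 * 320 * 568   # 4,362,240 values
raw_bytes = num_elements * 4       # float32 is 4 bytes -> ~17.4 MB per sample
print(f"{raw_bytes / 1e6:.1f} MB per raw sample")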

@tchaton
Collaborator

tchaton commented Mar 8, 2024

Hey @tikboaHIT, I ran the exact same code on a Lightning AI A10G.

import os
import torch
from tqdm import tqdm
from litdata import optimize, StreamingDataset, StreamingDataLoader

input_dir = '/teamspace/datasets/videos3'
dataset = StreamingDataset(input_dir, shuffle=True)
dataloader = StreamingDataLoader(dataset, batch_size=1, num_workers=1)  # os.cpu_count()
for data in tqdm(dataset):  # note: iterates the dataset directly, not the dataloader
    pass

It took 21 seconds for me, so it is definitely your internet connection. Usually, a streaming dataset is much faster when you stream within the cloud provider where the data is stored.

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:21<00:00,  4.59it/s]

@tchaton
Collaborator

tchaton commented Mar 8, 2024

A more efficient approach is to encode the images as JPEG, as follows:

import os
import torch
from tqdm import tqdm
from litdata import optimize, StreamingDataset
from PIL import Image
from io import BytesIO

def generate_images(video_path):
    images = []
    for _ in range(8):
        random_image = torch.randint(0, 255, (320, 568, 3), dtype=torch.uint8).numpy()
        buff = BytesIO()
        Image.fromarray(random_image).save(buff, quality=90, format='JPEG') # You can implement a better resizing logic
        buff.seek(0)
        img = buff.read()
        images.append(Image.open(BytesIO(img)))
    return {
        "name": video_path,
        "image": images,
    }

optimize(
    fn=generate_images,
    inputs=list(range(100)),
    output_dir="/teamspace/datasets/videos5",
    num_workers=1,  # os.cpu_count()
    chunk_bytes="64MB",
)

When streaming it from the cloud, it now takes 1 second.

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 50.08it/s]

Additionally, I recommend using torchvision.transforms.v2, which is roughly 40% faster at resizing images, etc.
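
A rough sketch of what a v2 version of the transform pipeline from the original report could look like (an illustration under assumed shapes, not a drop-in replacement; requires a recent torchvision with the v2 API):

import torch
from torchvision.transforms import v2

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

# v2 transforms expect (..., C, H, W), so keep the clip as (T, C, H, W)
transform = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),  # replaces the manual x / 255.0
    v2.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
    v2.CenterCrop(224),
])

clip = torch.randint(0, 255, (8, 3, 320, 568), dtype=torch.uint8)  # 8 frames, channels-first
out = transform(clip)  # float32, normalized, center-cropped to (8, 3, 224, 224)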

Alternatively, we support videos through torchvision's video support: https://pytorch.org/audio/stable/build.ffmpeg.html. If you convert your clips to the AV1 format, they should become very small; you should be able to stream them easily and deserialize them faster. Worth exploring.

@tchaton
Collaborator

tchaton commented Mar 8, 2024

@tikboaHIT I also fixed chunk_bytes not being handled correctly by the optimize operator.

@tikboaHIT
Author

tikboaHIT commented Mar 8, 2024

@tchaton Thank you for your suggestions and the quick bug fix. Regarding the logic for saving frames, I can provide more context. I mainly extract and save frames through the following logic, storing them in tensor and PIL.Image formats. I've found that the loading speed is very slow.

cmd = f"./preprocess/data_preprocess/get_frames.sh {video_name} 8 {frame_save_dir}"
os.popen(cmd).read()

imgs = {}
for index, img_path in enumerate(glob(f"{frame_save_dir}/*.jpg")):
    img = Image.fromarray(np.array(Image.open(img_path)))
    imgs[index] = img

data = {
    "name": video_name,
    "image": imgs,
}

This should be similar to the solution you mentioned; both involve storing images in JPEG format.

You mentioned:

for _ in range(8):
    random_image = torch.randint(0, 255, (320, 568, 3), dtype=torch.uint8).numpy()
    buff = BytesIO()
    Image.fromarray(random_image).save(buff, quality=90, format='JPEG') # You can implement a better resizing logic
    buff.seek(0)
    img = buff.read()
    images.append(Image.open(BytesIO(img)))

@tchaton
Collaborator

tchaton commented Mar 8, 2024

Hey @tikboaHIT. It isn't fully equivalent. In my case, I add compression by explicitly converting the frames to JPEG with quality 90. You can see the difference by checking how many chunk files are created: in my case there was only one, while with your approach there were dozens.

Would you mind trying the code above? If you don't see any difference, it means you are bottlenecked by your internet connection. I would recommend trying out Lightning AI to get the full speed.
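
A hypothetical adjustment of the extraction loop from the earlier comment, following the re-encoding approach shown above (video_name and frame_save_dir are assumed from that snippet):

from glob import glob
from io import BytesIO
from PIL import Image

# Sketch: re-encode each extracted frame as JPEG (quality 90) instead of
# wrapping a decoded numpy array, so the optimized chunks store compressed bytes.
imgs = {}
for index, img_path in enumerate(glob(f"{frame_save_dir}/*.jpg")):
    buff = BytesIO()
    Image.open(img_path).save(buff, quality=90, format='JPEG')
    buff.seek(0)
    imgs[index] = Image.open(buff)

data = {
    "name": video_name,
    "image": imgs,
}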

@tikboaHIT
Author

@tchaton Indeed, using JPEG compression can significantly improve reading speed, but overall, the process is not particularly smooth due to intermittent stuttering. Is this primarily limited by my AWS network?

The recording is as follows:

[Screen.Recording.2024-03-11.at.13.09.02.mov]

@tchaton
Collaborator

tchaton commented Mar 11, 2024

Hey @tikboaHIT. Yes, I suspect this is the bottleneck, and you would hit the same issue no matter which library you use.

But there might be some extra optimizations to make for low-bandwidth machines. Would you be free for a pair-debugging session this week?

Best,
T.C

@tikboaHIT
Author

Hey @tchaton, that sounds great. Let's continue our conversation on Discord and find a time that works for both of us to schedule a meeting.

@tikboaHIT
Author

Hey @tchaton, when I use litdata to load data, I suddenly encounter this issue when training reaches around 80% of the first epoch.

[Screenshot 2024-03-12 at 20 58 08]

@tchaton
Collaborator

tchaton commented Mar 13, 2024

Hey @tikboaHIT, are you using the latest version of litdata? I think this bug was resolved on main. Otherwise, would you mind sharing a reproduction script?

@tikboaHIT
Author

@tchaton Yes, version 2.2 is being used. Currently, the issue mainly occurs within a subset of the data. For context, I have divided several thousand videos into 10 parts, with each part forming its own chunks according to the frame-extraction logic above. There were no errors in the first part of the data, but this problem arose in the second part. Due to privacy policy reasons, the reproduction script involves data that is not convenient to share.

Is there a way for me to debug on my own to find out the specific cause?

Additionally, there is another minor issue when I combine litdata with PyTorch Lightning. When the data volume is small [during debugging], training proceeds completely normally. However, when the data volume reaches around 500k and I follow a train->validation->train schedule, training of the second epoch blocks indefinitely and then this error occurs: "watchdog caught collective operation timeout: WorkNCCL(SeqNum=27131, OpType=BROADCAST, NumelIn=274, NumelOut=274, Timeout(ms)=1800000) ran for 1800324 milliseconds before timing out." The issue does not arise if I skip validation in between.

The partial code snippet is shown below.
[Screenshot 2024-03-13 at 23 20 03]

@tchaton
Collaborator

tchaton commented Mar 14, 2024

Hey @tikboaHIT,

Due to privacy policy reasons, the reproducing script involves some data, which is not convenient to share.

Do you think you could try to reproduce the bug with synthetically generated data or even an open-source dataset, so I can debug it on my end?

train->validation->train

Are you using DDP? Normally, each rank should get the same quantity of data, but it is possible there is a bug somewhere. If the lengths were different, training would hang.
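
A small hypothetical check for the hang described above: print the dataset length each DDP rank sees; if the values differ, one rank runs out of batches first and the collective operation times out. train_dataset and val_dataset are placeholder names for the StreamingDataset instances used by the dataloaders.

import torch.distributed as dist

# Sketch only: compare per-rank dataset lengths (e.g. from inside the LightningModule).
if dist.is_available() and dist.is_initialized():
    rank = dist.get_rank()
    print(f"rank={rank} len(train)={len(train_dataset)} len(val)={len(val_dataset)}")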

@tikboaHIT
Author

tikboaHIT commented Mar 18, 2024

Hey @tchaton, I've pinpointed that the issue lies with a particular piece of data within a chunk. A normal read is shown below.

[Screenshot 2024-03-18 at 21 58 24]

An abnormal read looks as follows:
[Screenshot 2024-03-18 at 21 57 37]

I'm not sure whether this helps with your debugging.

@tchaton
Collaborator

tchaton commented Apr 11, 2024

Hey @tikboaHIT. This means the chunk wasn't fully copied over when opened.
