
IterableDataset and Dataset return different batch sizes when using Trainer with multiple GPUs #5506

Closed
kheyer opened this issue Feb 6, 2023 · 4 comments



kheyer commented Feb 6, 2023

Describe the bug

I am training a Roberta model using 2 GPUs and the Trainer API with a per-device batch size of 256.

Initially I used a standard Dataset, but had issues with slow data loading. After reading this issue, I swapped to loading my dataset as contiguous shards and passing those to an IterableDataset. I then observed an unexpected drop in GPU memory utilization and found that the batch size seen by the model had been cut in half.

When using Trainer with 2 GPUs and a per-device batch size of 256, Dataset yields a batch of size 512 (256 per GPU), while IterableDataset yields a batch of size 256 in total. My guess is that the IterableDataset path isn't accounting for multiple GPUs.

Steps to reproduce the bug

import datasets
from datasets import IterableDataset

from transformers import RobertaConfig
from transformers import RobertaTokenizerFast
from transformers import RobertaForMaskedLM

from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

use_iterable_dataset = True

def gen_from_shards(shards):
    # Stream examples from each shard in turn.
    for shard in shards:
        for example in shard:
            yield example

dataset = datasets.load_from_disk('my_dataset.hf')

if use_iterable_dataset:
    # Split the dataset into shards and wrap them in an IterableDataset.
    n_shards = 100
    shards = [dataset.shard(num_shards=n_shards, index=i) for i in range(n_shards)]
    dataset = IterableDataset.from_generator(gen_from_shards, gen_kwargs={"shards": shards})

tokenizer = RobertaTokenizerFast.from_pretrained("./my_tokenizer", max_len=160, use_fast=True)

config = RobertaConfig(
    vocab_size=8248,
    max_position_embeddings=256,
    num_attention_heads=8,
    num_hidden_layers=6,
    type_vocab_size=1)

model = RobertaForMaskedLM(config=config)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    per_device_train_batch_size=256
    # other args removed for brevity 
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()

Expected behavior

Expected Dataset and IterableDataset to have the same batch size behavior. If the current behavior is intentional, the batch size printout at the start of training should be updated. Currently, both dataset classes result in Trainer printing the same total batch size, even though the batch sizes actually sent to the GPUs are different.
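
As a rough check (a sketch only, using the script above; it assumes the dataset is already tokenized so batches contain input_ids), the batch that actually reaches the training step can be inspected via the dataloader the Trainer builds:

# Sketch: inspect the per-process batch produced by the Trainer's dataloader.
dataloader = trainer.get_train_dataloader()
batch = next(iter(dataloader))
print(batch["input_ids"].shape)  # leading dim: 512 with Dataset (2 GPUs x 256) vs. 256 with IterableDataset, per the numbers above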

Environment info

datasets 2.7.1
transformers 4.25.1


lhoestq commented Feb 7, 2023

Hi! datasets doesn't do batching: the PyTorch DataLoader created by the Trainer does. Do you pass any other arguments to training_args with respect to data loading?
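
Roughly, the flow looks like this (a sketch only, not the Trainer's actual code; 256 stands in for whatever batch size the Trainer settles on for the process):

# Sketch: batching happens in the DataLoader the Trainer creates, not in `datasets`.
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,                  # map-style Dataset or IterableDataset
    batch_size=256,           # the batch size the Trainer settles on for this process
    collate_fn=data_collator, # pads and applies MLM masking to build one batch dict
)
batch = next(iter(loader))    # a dict of tensors, e.g. batch["input_ids"]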

Also, we recently released .to_iterable_dataset, which does pretty much what you implemented, but uses contiguous shards to get better speed:

if use_iterable_dataset:
    num_shards = 100
    dataset = dataset.to_iterable_dataset(num_shards=num_shards)


kheyer commented Feb 7, 2023

This is the full set of training args passed. No training args were changed when switching dataset types.

training_args = TrainingArguments(
    output_dir="./checkpoints",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=256,
    save_steps=2000,
    save_total_limit=4,
    prediction_loss_only=True,
    report_to='none',
    gradient_accumulation_steps=6,
    fp16=True,
    max_steps=60000,
    lr_scheduler_type='linear',
    warmup_ratio=0.1,
    logging_steps=100,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
    learning_rate=1e-4
)


lhoestq commented Feb 8, 2023

I think the issue comes from transformers: huggingface/transformers#21444


kheyer commented Feb 8, 2023

Makes sense. Given that it's a transformers issue and already being tracked, I'll close this out.

kheyer closed this as completed on Feb 8, 2023