
Will CycleIterator forward to dataset on resume for pretrain? #1386

Open
calvintwr opened this issue May 5, 2024 · 3 comments

Comments

@calvintwr

calvintwr commented May 5, 2024

When resuming finetuning, I see that the CycleIterator is fast-forwarded through the batches seen before the checkpoint, so iteration continues where it left off:

# resume data loader state by fast-forwarding through all seen batches
if resume:
    resume_t0 = time.perf_counter()
    for resume_iter in range(initial_iter):
        next(train_iterator)
        if resume_iter % 1000 == 0:
            fabric.print(f"Resuming dataset: {resume_iter} / {initial_iter}")
    fabric.barrier()
    fabric.print(
        f"Resuming data loader finished. Took {time.perf_counter() - resume_t0:.1f} seconds to reach iteration"
        f" {initial_iter}."
    )
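To make the effect of that loop concrete, here is a minimal, self-contained sketch. The `CycleIterator` below is a simplified stand-in written for illustration (not litgpt's actual source), `resume_iterator` is a hypothetical helper, and plain strings stand in for batches from an unshuffled `DataLoader`. Fast-forwarding lands on the same batch an uninterrupted run would see at that iteration; skipping the loop restarts at the first batch.

```python
from typing import Any, Iterable, Iterator, List, Optional


class CycleIterator:
    """Simplified stand-in for litgpt's CycleIterator (illustrative only):
    yields items from `iterable` forever, restarting it when exhausted."""

    def __init__(self, iterable: Iterable[Any]) -> None:
        self.iterable = iterable
        self.epoch = 0
        self._iterator: Optional[Iterator[Any]] = None

    def __iter__(self) -> "CycleIterator":
        return self

    def __next__(self) -> Any:
        if self._iterator is None:
            self._iterator = iter(self.iterable)
        try:
            return next(self._iterator)
        except StopIteration:
            # One full pass finished: start a new epoch from the beginning.
            self._iterator = iter(self.iterable)
            self.epoch += 1
            return next(self._iterator)


def resume_iterator(dataset: Iterable[Any], initial_iter: int, fast_forward: bool) -> CycleIterator:
    """Hypothetical helper: build a fresh iterator after a restart,
    optionally replaying `initial_iter` batches like the finetuning code does."""
    train_iterator = CycleIterator(dataset)
    if fast_forward:
        # Consume and discard every batch seen before the checkpoint.
        for _ in range(initial_iter):
            next(train_iterator)
    return train_iterator


batches: List[str] = ["b0", "b1", "b2", "b3", "b4"]  # hypothetical, unshuffled batches
initial_iter = 7  # iteration count restored from a checkpoint

# Reference: the batch an uninterrupted run would see at iteration 7.
uninterrupted = CycleIterator(batches)
expected = [next(uninterrupted) for _ in range(initial_iter + 1)][-1]

resumed = resume_iterator(batches, initial_iter, fast_forward=True)
restarted = resume_iterator(batches, initial_iter, fast_forward=False)

print(expected, next(resumed), next(restarted))  # b2 b2 b0
```

Replaying costs one loader step per skipped iteration, and it only reproduces the original batch order if any shuffling is seeded identically on restart.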

However, pretrain has no such step, so training seems to resume from the beginning of the dataset:

litgpt/litgpt/pretrain.py

Lines 217 to 271 in f334378

def fit(
    fabric: L.Fabric,
    devices: int,
    state: dict,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    out_dir: Path,
    tokenizer_dir: Optional[Path],
    train: TrainArgs,
    eval: EvalArgs,
) -> None:
    model = state["model"]
    optimizer = state["optimizer"]
    if eval.initial_validation:
        val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)
        val_loss = f"{val_loss:.3f}"
    else:
        validate(fabric, model, val_dataloader, max_iters=2)  # sanity check
        val_loss = "n/a"
    throughput = ThroughputMonitor(fabric, window_size=5)
    with torch.device("meta"):
        meta_model = GPT(model.config)
        x = torch.randint(0, 1, (train.micro_batch_size, meta_model.max_seq_length))
        model_fwd = lambda: meta_model(x)
        model_loss = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
        measured_flops = measure_flops(meta_model, model_fwd, model_loss)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x
    max_tokens_per_device = train.max_tokens // fabric.world_size
    tokens_per_iter = train.micro_batch_size * model.max_seq_length
    max_iters = max_tokens_per_device // tokens_per_iter
    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices)
    initial_iter = state["iter_num"]
    train_iterator = CycleIterator(train_dataloader)
    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
        fabric.device
    )
    fabric.barrier()
    total_t0 = time.perf_counter()
    warmup_iters = train.warmup_iters(devices, max_iters, train_dataloader)
    for train_data in train_iterator:
        if state["iter_num"] >= max_iters:
            break
        # determine and set the learning rate for this iteration
        lr = get_lr(train.learning_rate, state["iter_num"], warmup_iters, max_iters, train.min_lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

Can I check: in this case, when resuming, will pretraining start again from the beginning of the dataset instead of being fast-forwarded?

@awaelchli
Member

The pretraining code uses a stateful dataloader from LitData:

litgpt/litgpt/pretrain.py

Lines 192 to 198 in f334378

state = {
    "model": model,
    "optimizer": optimizer,
    "train_dataloader": train_dataloader,
    "iter_num": 0,
    "step_count": 0,
}
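Because `train_dataloader` sits in `state` next to the model and optimizer, its position can be saved and restored with the checkpoint instead of being replayed, assuming the checkpoint code calls `state_dict()` / `load_state_dict()` on objects that define them. The sketch below illustrates the idea with a hypothetical `StatefulLoader` class; it is not LitData's implementation, and a plain dict stands in for the checkpoint file.

```python
from typing import Any, Dict, Iterator, List


class StatefulLoader:
    """Hypothetical minimal stateful loader (not LitData's StreamingDataLoader):
    remembers how many batches it has yielded so a restart can skip ahead directly."""

    def __init__(self, batches: List[Any]) -> None:
        self.batches = batches
        self._position = 0  # index of the next batch to yield

    def __iter__(self) -> Iterator[Any]:
        # Continue from the saved position; once a pass is exhausted, reset for the next epoch.
        while self._position < len(self.batches):
            batch = self.batches[self._position]
            self._position += 1
            yield batch
        self._position = 0

    def state_dict(self) -> Dict[str, int]:
        return {"position": self._position}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self._position = state["position"]


batches = ["b0", "b1", "b2", "b3", "b4"]  # hypothetical batches
loader = StatefulLoader(batches)
it = iter(loader)
seen = [next(it) for _ in range(3)]  # train for 3 iterations
checkpoint = {"train_dataloader": loader.state_dict(), "iter_num": 3}

# After a restart: fresh object, state restored from the checkpoint, no replay needed.
restored = StatefulLoader(batches)
restored.load_state_dict(checkpoint["train_dataloader"])
print(seen, next(iter(restored)))  # ['b0', 'b1', 'b2'] b3
```

In this sketch, restoring the position is a single assignment rather than one loader step per earlier batch, which is why a stateful loader makes an explicit fast-forward loop unnecessary.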

@calvintwr
Author

I see, thank you so much!

calvintwr reopened this May 24, 2024
@calvintwr
Author

Sorry, @awaelchli, a quick question: in that case, is there a reason a stateful dataloader is not used for finetuning as well?
