[feature] persistent_workers; new args: accumulate_grad_batches, strategy

Yiqing-Zhou 2023-05-06 19:39:24 +08:00
parent 45fa065530
commit 9b7f9b9d60
1 changed file with 16 additions and 3 deletions
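With these changes, the gradient-accumulation count and the distributed strategy become command-line options instead of hard-coded values. A hypothetical invocation (the script's file name is not shown in this view and is assumed here):

    python train.py --strategy fsdp --accumulate_grad_batches 32 --val_batch_size 16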


@@ -8,7 +8,6 @@ import datasets
 import pytorch_lightning as pl
 import torch
 import torchmetrics
-from pytorch_lightning import strategies
 from torch.utils.data import ConcatDataset, DataLoader
 from transformers import (
     AutoConfig,
@@ -131,12 +130,24 @@ def parse_args():
         help="Batch size of validating",
         default=16,
     )
+    parser.add_argument(
+        "--accumulate_grad_batches",
+        type=int,
+        help="Number of batches to accumulate gradients over",
+        default=32,
+    )
     parser.add_argument(
         "--num_proc",
         type=str,
         help="Number of data processes",
         default=16,
     )
+    parser.add_argument(
+        "--strategy",
+        type=str,
+        help="Name of the PyTorch Lightning distributed strategy",
+        default='fsdp',
+    )
     parser.add_argument(
         "--resume_from_ckpt_path",
         type=str,
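For reference, a minimal self-contained sketch of how the two new arguments behave under argparse (assuming parse_args() wraps a plain argparse.ArgumentParser, as the diff suggests):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--accumulate_grad_batches", type=int, default=32)
    parser.add_argument("--strategy", type=str, default='fsdp')

    args = parser.parse_args([])  # no flags given, so the defaults apply
    assert args.accumulate_grad_batches == 32
    assert args.strategy == 'fsdp'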
@@ -256,6 +267,7 @@ if __name__ == '__main__':
         batch_size=args.train_batch_size,
         num_workers=args.num_proc,
         collate_fn=DefaultDataCollator(),
+        persistent_workers=True,
         shuffle=True,
     )
     val_dataloader = DataLoader(
@@ -263,6 +275,7 @@
         batch_size=args.val_batch_size,
         num_workers=args.num_proc,
         collate_fn=DefaultDataCollator(),
+        persistent_workers=True,
     )

     # trainer
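persistent_workers=True keeps the DataLoader's worker processes alive between epochs instead of respawning them each time the loader is re-iterated; DataLoader raises a ValueError if it is set with num_workers=0. (Note that --num_proc above is declared type=str, so a value supplied on the command line would reach num_workers as a string; only the int default of 16 works as-is.) A minimal sketch with a toy dataset:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(8))
    # Workers are created once and reused across epochs.
    dl = DataLoader(ds, batch_size=4, num_workers=2, persistent_workers=True)
    for _ in range(2):  # the second pass reuses the same worker processes
        for batch in dl:
            pass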
@@ -277,8 +290,8 @@
         accelerator='gpu',
         precision=precision,
         log_every_n_steps=5,
-        accumulate_grad_batches=32,
-        strategy=strategies.DeepSpeedStrategy(stage=3),
+        accumulate_grad_batches=args.accumulate_grad_batches,
+        strategy=args.strategy,
     )
     lit_trainer.fit(
         lit_module,
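Passing strategy as a string relies on Lightning's strategy registry, which resolves names such as 'ddp', 'fsdp', and 'deepspeed_stage_3'; the last of these matches the removed strategies.DeepSpeedStrategy(stage=3). A hedged sketch, assuming a Lightning version that registers the 'fsdp' alias:

    import pytorch_lightning as pl

    # 'deepspeed_stage_3' is the string form of DeepSpeedStrategy(stage=3),
    # so the old behavior stays reachable via --strategy deepspeed_stage_3.
    trainer = pl.Trainer(
        accelerator='gpu',
        strategy='fsdp',  # the new default
        accumulate_grad_batches=32,
    )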