From 9b7f9b9d60da4fe95ee455f29864cae556845499 Mon Sep 17 00:00:00 2001 From: Yiqing-Zhou Date: Sat, 6 May 2023 19:39:24 +0800 Subject: [PATCH] [feature] persistent_workers; new args accumulate_grad_batches strategy --- train.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index ef7ded8..1ebee3e 100644 --- a/train.py +++ b/train.py @@ -8,7 +8,6 @@ import datasets import pytorch_lightning as pl import torch import torchmetrics -from pytorch_lightning import strategies from torch.utils.data import ConcatDataset, DataLoader from transformers import ( AutoConfig, @@ -131,12 +130,24 @@ def parse_args(): help="Batch size of validating", default=16, ) + parser.add_argument( + "--accumulate_grad_batches", + type=int, + help="Accumulate grad batches", + default=32, + ) parser.add_argument( "--num_proc", type=str, help="Number of data processes", default=16, ) + parser.add_argument( + "--strategy", + type=str, + help="Name of pytorch lightning distribution strategy", + default='fsdp', + ) parser.add_argument( "--resume_from_ckpt_path", type=str, @@ -256,6 +267,7 @@ if __name__ == '__main__': batch_size=args.train_batch_size, num_workers=args.num_proc, collate_fn=DefaultDataCollator(), + persistent_workers=True, shuffle=True, ) val_dataloader = DataLoader( @@ -263,6 +275,7 @@ if __name__ == '__main__': batch_size=args.val_batch_size, num_workers=args.num_proc, collate_fn=DefaultDataCollator(), + persistent_workers=True, ) # trainer @@ -277,8 +290,8 @@ if __name__ == '__main__': accelerator='gpu', precision=precision, log_every_n_steps=5, - accumulate_grad_batches=32, - strategy=strategies.DeepSpeedStrategy(stage=3), + accumulate_grad_batches=args.accumulate_grad_batches, + strategy=args.strategy, ) lit_trainer.fit( lit_module,